gateserver.lua 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. local skynet = require "skynet"
  2. local netpack = require "skynet.netpack"
  3. local socketdriver = require "skynet.socketdriver"
  --- Public module table; openclient/closeclient/start are attached to it.
  local gateserver = {}

  -- Mutable service state shared with the closures built in gateserver.start.
  -- socket    : id of the listening socket (nil while not listening)
  -- queue     : netpack queue, replaced on every socket message
  -- maxclient : connection cap, set by CMD.open
  local socket, queue, maxclient
  -- Incremented for every fd seen by MSG.open, decremented in MSG.close.
  local client_number = 0
  -- When true, socketdriver.nodelay is applied to each new client fd.
  local nodelay = false

  -- Command table for the "lua" protocol. When the table is collected, its
  -- __gc hands whatever is left in the netpack queue to netpack.clear.
  local CMD = setmetatable({}, {
    __gc = function()
      netpack.clear(queue)
    end,
  })

  -- Per-client state, keyed by socket fd:
  --   true  : connected
  --   nil   : closed
  --   false : close read (read side closed by MSG.close)
  local connection = {}
  --- Begin receiving data on a client fd.
  -- Only acts when the fd's entry in `connection` is truthy; fds that are
  -- unknown (nil) or read-closed (false) are left alone.
  -- @tparam integer fd client socket id
  function gateserver.openclient(fd)
    local state = connection[fd]
    if not state then
      return
    end
    socketdriver.start(fd)
  end
  --- Stop tracking a client fd and close its socket.
  -- Does nothing when the fd has no entry at all (nil). An fd whose entry is
  -- false (read side already closed) is still removed and closed.
  -- @tparam integer fd client socket id
  function gateserver.closeclient(fd)
    if connection[fd] == nil then
      return
    end
    connection[fd] = nil
    socketdriver.close(fd)
  end
  --- Turn the calling service into a gate server driven by `handler`.
  --
  -- Registers the "socket" protocol: raw socket events go through
  -- netpack.filter and are dispatched to the MSG table below. Also installs
  -- (via init) a "lua" dispatcher that routes requests to CMD or to
  -- handler.command.
  --
  -- @tparam table handler callbacks and flags:
  --   message(fd, msg, sz)        required; a packet from a connected client
  --   connect(fd, msg)            required; a client fd was admitted (msg is the
  --                               payload of the open event, presumably the
  --                               peer address -- confirm)
  --   open(source, conf)          optional; called at the end of CMD.open, and
  --                               its results are returned to the caller
  --   disconnect(fd)              optional; a client fd closed
  --   error(fd, msg)              optional; a client fd reported an error
  --   warning(fd, size)           optional; forwarded socket warning
  --   command(cmd, address, ...)  optional; handles "lua" commands not in CMD
  --   embed                       optional flag; when truthy, init runs now
  --                               instead of through skynet.start
  function gateserver.start(handler)
    assert(handler.message, "gateserver: handler.message is required")
    assert(handler.connect, "gateserver: handler.connect is required")

    -- Non-nil only while CMD.open waits for the listen socket's "init" event.
    -- It holds the waiting coroutine (co), the expected fd, and the
    -- addr/port that MSG.init fills in.
    local listen_context

    --- Start listening.
    -- conf.address defaults to "0.0.0.0", conf.port is required,
    -- conf.maxclient defaults to 1024, and conf.nodelay is optional.
    -- On return, conf.address and conf.port hold the values delivered by the
    -- "init" event for the listen socket.
    function CMD.open(source, conf)
      assert(not socket)
      local address = conf.address or "0.0.0.0"
      local port = assert(conf.port)
      maxclient = conf.maxclient or 1024
      nodelay = conf.nodelay
      skynet.error(string.format("Listen on %s:%d", address, port))
      socket = socketdriver.listen(address, port)
      -- Create a fresh context on every call. Once MSG.close has reset
      -- `socket` after CMD.close, the gate can be opened again; reusing a
      -- single table that is nil'ed below would index nil on that reopen.
      local ctx = { co = coroutine.running(), fd = socket }
      listen_context = ctx
      skynet.wait(ctx.co)
      conf.address = ctx.addr
      conf.port = ctx.port
      listen_context = nil
      socketdriver.start(socket)
      if handler.open then
        return handler.open(source, conf)
      end
    end

    --- Close the listen socket. `socket` itself is cleared later, by
    -- MSG.close, when the close event for the listen fd is dispatched.
    function CMD.close()
      assert(socket)
      socketdriver.close(socket)
    end

    local MSG = {}

    -- Deliver one packet to handler.message while the fd is still marked
    -- connected. Otherwise log the dropped payload.
    local function dispatch_msg(fd, msg, sz)
      if connection[fd] then
        handler.message(fd, msg, sz)
      else
        skynet.error(string.format("Drop message from fd (%d) : %s", fd, netpack.tostring(msg, sz)))
      end
    end

    MSG.data = dispatch_msg

    -- Drain the netpack queue. A new drainer is forked before dispatching,
    -- so the queue keeps moving even if handler.message blocks. If the
    -- handler never blocks, the queue is empty by the time the forked drainer
    -- runs, and it exits after one pop.
    local function dispatch_queue()
      local fd, msg, sz = netpack.pop(queue)
      if fd then
        skynet.fork(dispatch_queue)
        dispatch_msg(fd, msg, sz)
        for qfd, qmsg, qsz in netpack.pop, queue do
          dispatch_msg(qfd, qmsg, qsz)
        end
      end
    end

    MSG.more = dispatch_queue

    function MSG.open(fd, msg)
      client_number = client_number + 1
      -- `>` admits exactly `maxclient` concurrent clients. `>=` admitted at
      -- most maxclient - 1, and none at all when maxclient is 1.
      if client_number > maxclient then
        -- The rejected fd stays counted here; MSG.close decrements it when
        -- its close event arrives (presumably after this shutdown -- confirm).
        socketdriver.shutdown(fd)
        return
      end
      if nodelay then
        socketdriver.nodelay(fd)
      end
      connection[fd] = true
      handler.connect(fd, msg)
    end

    function MSG.close(fd)
      if fd == socket then
        -- The listen socket closed; clearing it lets CMD.open run again.
        socket = nil
        return
      end
      client_number = client_number - 1
      if connection[fd] then
        connection[fd] = false -- close read
      end
      -- Runs for every client fd, including fds rejected in MSG.open that
      -- never reached handler.connect.
      if handler.disconnect then
        handler.disconnect(fd)
      end
    end

    function MSG.error(fd, msg)
      if fd == socket then
        skynet.error("gateserver accept error:", msg)
      else
        -- NOTE(review): client_number is only decremented in MSG.close;
        -- confirm the socket layer also delivers "close" for errored fds.
        socketdriver.shutdown(fd)
        if handler.error then
          handler.error(fd, msg)
        end
      end
    end

    function MSG.warning(fd, size)
      if handler.warning then
        handler.warning(fd, size)
      end
    end

    -- Record the addr/port carried by the "init" event for the listen fd and
    -- wake the coroutine parked in CMD.open. Ignored when no open is waiting.
    function MSG.init(id, addr, port)
      local ctx = listen_context
      if ctx and ctx.co then
        assert(id == ctx.fd)
        ctx.addr = addr
        ctx.port = port
        skynet.wakeup(ctx.co)
        ctx.co = nil
      end
    end

    skynet.register_protocol {
      name = "socket",
      id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
      unpack = function(msg, sz)
        return netpack.filter(queue, msg, sz)
      end,
      -- Receives what netpack.filter returned: the queue (stored back into
      -- `queue`), then an event type and that event's arguments.
      dispatch = function(_, _, q, msg_type, ...)
        queue = q
        if msg_type then
          MSG[msg_type](...)
        end
      end,
    }

    -- Install the "lua" dispatcher. CMD entries are called as f(address, ...);
    -- anything else goes to handler.command(cmd, address, ...).
    local function init()
      skynet.dispatch("lua", function(_, address, cmd, ...)
        local f = CMD[cmd]
        if f then
          skynet.ret(skynet.pack(f(address, ...)))
        elseif handler.command then
          skynet.ret(skynet.pack(handler.command(cmd, address, ...)))
        else
          error(string.format("gateserver: unknown command %s", tostring(cmd)))
        end
      end)
    end

    if handler.embed then
      init()
    else
      skynet.start(init)
    end
  end
  -- Module export: the table carrying openclient, closeclient and start.
  return gateserver