-- ws_gate.lua
-- Websocket gate service: accepts websocket clients and forwards their
-- messages either to a bound agent or to the watchdog service.
  1. local skynet = require "skynet"
  2. require "skynet.manager"
  3. local socket = require "skynet.socket"
  4. local websocket = require "http.websocket"
  5. local socketdriver = require "skynet.socketdriver"
  6. local logger = require "logger"
  7. local traceback = debug.traceback
  8. local stringify = require "stringify"
  9. local nodelay
  10. local watchdog
  11. local connection = {} -- fd -> connection : { fd , client, agent , ip, mode }
  12. local forwarding = {} -- agent -> connection
  13. local client_number = 0
  14. local maxclient -- max client
  15. local function unforward(c)
  16. if c.agent then
  17. forwarding[c.agent] = nil
  18. c.agent = nil
  19. c.client = nil
  20. end
  21. end
  22. local function close_fd(fd)
  23. local c = connection[fd]
  24. if c then
  25. socketdriver.close(fd)
  26. unforward(c)
  27. connection[fd] = nil
  28. client_number = client_number - 1
  29. end
  30. end
  31. local handle = {}
  32. function handle.connect(fd)
  33. -- print("ws connect from: " .. tostring(fd))
  34. if client_number >= maxclient then
  35. socketdriver.close(fd)
  36. return
  37. end
  38. if nodelay then
  39. socketdriver.nodelay(fd)
  40. end
  41. client_number = client_number + 1
  42. local addr = websocket.addrinfo(fd)
  43. local c = {
  44. fd = fd,
  45. ip = addr,
  46. }
  47. connection[fd] = c
  48. skynet.send(watchdog, "lua", "socket", "open", fd, addr)
  49. end
  50. function handle.handshake(fd, header, url)
  51. -- local addr = websocket.addrinfo(fd)
  52. -- print("ws handshake from: " .. tostring(fd), "url", url, "addr:", addr)
  53. -- print("----header-----")
  54. -- for k,v in pairs(header) do
  55. -- print(k,v)
  56. -- end
  57. -- print("--------------")
  58. end
  59. function handle.message(fd, msg)
  60. -- print("ws ping from: " .. tostring(fd), msg.."\n")
  61. local sz = #msg
  62. -- recv a package, forward it
  63. local c = connection[fd]
  64. local agent = c.agent
  65. if agent then
  66. -- It's safe to redirect msg directly , gateserver framework will not free msg.
  67. skynet.redirect(agent, c.client, "client", fd, msg, sz)
  68. else
  69. skynet.send(watchdog, "lua", "socket", "data", fd, msg)
  70. -- skynet.tostring will copy msg to a string, so we must free msg here.
  71. skynet.trash(msg,sz)
  72. end
  73. end
  74. function handle.ping(fd)
  75. -- print("ws ping from: " .. tostring(fd) .. "\n")
  76. end
  77. function handle.pong(fd)
  78. -- print("ws pong from: " .. tostring(fd))
  79. end
  80. function handle.close(fd, code, reason)
  81. -- print("ws close from: " .. tostring(fd), code, reason)
  82. close_fd(fd)
  83. skynet.send(watchdog, "lua", "socket", "close", fd)
  84. end
  85. function handle.error(fd)
  86. -- print("ws error from: " .. tostring(fd))
  87. close_fd(fd)
  88. skynet.send(watchdog, "lua", "socket", "error", fd)
  89. end
  90. local CMD = {}
  91. function CMD.open(source, conf)
  92. watchdog = conf.watchdog or source
  93. local address = conf.address or "0.0.0.0"
  94. local port = assert(conf.port)
  95. local protocol = conf.protocol or "ws"
  96. maxclient = conf.maxclient or 1024
  97. nodelay = conf.nodelay
  98. local fd = socket.listen(address, port)
  99. logger.trace("Listen websocket port:%s protocol:%s", port, protocol)
  100. socket.start(fd, function(fd, addr)
  101. logger.trace(string.format("accept client socket_fd: %s addr:%s", fd, addr))
  102. websocket.accept(fd, handle, protocol, addr)
  103. end)
  104. end
  105. function CMD.forward(source, fd, client, address)
  106. local c = assert(connection[fd])
  107. unforward(c)
  108. c.client = client or 0
  109. c.agent = address or source
  110. forwarding[c.agent] = c
  111. end
  112. function CMD.accept(source, fd)
  113. local c = assert(connection[fd])
  114. unforward(c)
  115. end
  116. function CMD.kick(source, fd)
  117. websocket.close(fd)
  118. end
  119. skynet.register_protocol {
  120. name = "client",
  121. id = skynet.PTYPE_CLIENT,
  122. }
  123. skynet.start(function()
  124. skynet.dispatch("lua", function(session, source, cmd, ...)
  125. local f = CMD[cmd]
  126. if not f then
  127. skynet.error("simplewebsocket can't dispatch cmd ".. (cmd or nil))
  128. skynet.ret(skynet.pack({ok=false}))
  129. return
  130. end
  131. if session == 0 then
  132. f(source, ...)
  133. else
  134. skynet.ret(skynet.pack(f(source, ...)))
  135. end
  136. end)
  137. skynet.register(".ws_gate")
  138. skynet.error("ws_gate booted...")
  139. end)