msgserver.lua 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. local skynet = require "skynet"
  2. local gateserver = require "snax.gateserver"
  3. local netpack = require "skynet.netpack"
  4. local crypt = require "skynet.crypt"
  5. local socketdriver = require "skynet.socketdriver"
  6. local assert = assert
  7. local b64encode = crypt.base64encode
  8. local b64decode = crypt.base64decode
  9. --[[
  10. Protocol:
  11. All numbers are big-endian.
  12. Handshake (the first package)
  13. Client -> Server :
  14. base64(uid)@base64(server)#base64(subid):index:base64(hmac)
  15. Server -> Client
  16. XXX ErrorCode
  17. 404 User Not Found
  18. 403 Index Expired
  19. 401 Unauthorized
  20. 400 Bad Request
  21. 200 OK
  22. Req-Resp
  23. Client -> Server : Request
  24. word size (not including itself)
  25. string content (size-4)
  26. dword session
  27. Server -> Client : Response
  28. word size (not including itself)
  29. string content (size-5)
  30. byte ok (1 is ok, 0 is error)
  31. dword session
  32. API:
  33. server.userid(username)
  34. return uid, subid, server
  35. server.username(uid, subid, server)
  36. return username
  37. server.login(username, secret)
  38. register the user as online with the given secret (the user must not already be online)
  39. server.logout(username)
  40. user logout
  41. server.ip(username)
  42. return the client ip once the connection is established, or nil
  43. server.start(conf)
  44. start server
  45. Supported skynet command:
  46. kick username (may be used by loginserver)
  47. login username secret (used by loginserver)
  48. logout username (used by agent)
  49. Config for server.start:
  50. conf.expired_number : the number of response messages kept cached after they are sent out (default is 128)
  51. conf.login_handler(uid, secret) -> subid : called when a new user logs in; allocates a subid for it. (may be called by the login server)
  52. conf.logout_handler(uid, subid) : called when a user logs out. (may be called by the agent)
  53. conf.kick_handler(uid, subid) : called when a user is kicked. (may be called by the login server)
  54. conf.request_handler(username, msg) : called when a new request arrives; msg is the request content without the trailing session, and the returned string (nil means empty) is the response content.
  55. conf.register_handler(servername) : called when the gate opens
  56. conf.disconnect_handler(username) : called when a connection disconnects (afk)
  57. ]]
  -- Online users keyed by username; each value is the table built by server.login.
  local user_online = {}
  -- Sockets that connected but have not sent their handshake yet: fd -> address.
  local handshake = {}
  -- Authenticated sockets: fd -> the user_online entry bound to that fd.
  local connection = {}

  -- Module table returned at the end of the file.
  local server = {}

  -- Register the "client" protocol type (skynet.PTYPE_CLIENT) with skynet.
  skynet.register_protocol({
    id = skynet.PTYPE_CLIENT,
    name = "client",
  })
  --- Split a username back into its components.
  -- The username layout is base64(uid)@base64(server)#base64(subid).
  -- @param username string produced by server.username
  -- @return uid, subid, servername, each base64-decoded
  function server.userid(username)
    local pattern = "([^@]*)@([^#]*)#(.*)"
    local enc_uid, enc_server, enc_subid = username:match(pattern)
    local uid = b64decode(enc_uid)
    local subid = b64decode(enc_subid)
    local servername = b64decode(enc_server)
    return uid, subid, servername
  end
  --- Build a username from its components (inverse of server.userid).
  -- Produces base64(uid)@base64(servername)#base64(tostring(subid)).
  function server.username(uid, subid, servername)
    local encoded_uid = b64encode(uid)
    local encoded_server = b64encode(servername)
    local encoded_subid = b64encode(tostring(subid))
    return encoded_uid .. "@" .. encoded_server .. "#" .. encoded_subid
  end
  --- Mark a user offline and close the socket currently bound to it, if any.
  -- A username that is not online is ignored. Previously this indexed a nil
  -- entry and raised, e.g. on a double logout or a logout after a kick.
  -- @param username string
  function server.logout(username)
    local u = user_online[username]
    if u == nil then
      return
    end
    user_online[username] = nil
    local fd = u.fd
    if fd and connection[fd] then
      gateserver.closeclient(fd)
      connection[fd] = nil
    end
  end
  --- Register `username` as online with the given secret.
  -- Raises (via assert) if the user is already online.
  -- Fields: version = last accepted handshake index, index = counter
  -- advanced per response sent, response = per-session response cache.
  function server.login(username, secret)
    local existing = user_online[username]
    assert(existing == nil)
    local entry = {}
    entry.username = username
    entry.secret = secret
    entry.version = 0
    entry.index = 0
    entry.response = {} -- response cache
    user_online[username] = entry
  end
  --- Address recorded at handshake time for `username`.
  -- Returns nothing when the user is not online or has no bound fd.
  function server.ip(username)
    local u = user_online[username]
    if not (u and u.fd) then
      return
    end
    return u.ip
  end
  --- Start the message server on top of snax.gateserver.
  -- @param conf table. Required: login_handler, logout_handler, kick_handler,
  --   request_handler. Optional: expired_number (default 128),
  --   disconnect_handler. register_handler is called from handler.open.
  -- @return whatever gateserver.start(handler) returns
  function server.start(conf)
    local expired_number = conf.expired_number or 128

    local handler = {}

    -- skynet commands forwarded to the user-supplied hooks.
    local CMD = {
      login = assert(conf.login_handler),
      logout = assert(conf.logout_handler),
      kick = assert(conf.kick_handler),
    }

    function handler.command(cmd, source, ...)
      local f = assert(CMD[cmd])
      return f(...)
    end

    function handler.open(source, gateconf)
      local servername = assert(gateconf.servername)
      return conf.register_handler(servername)
    end

    -- New socket: remember its address until the handshake packet arrives.
    function handler.connect(fd, addr)
      handshake[fd] = addr
      gateserver.openclient(fd)
    end

    -- Socket closed (also used for socket errors).
    function handler.disconnect(fd)
      handshake[fd] = nil
      local c = connection[fd]
      if c == nil then
        return
      end
      if c.fd ~= fd then
        -- Stale socket: the user has since authenticated on another fd.
        -- Do not report the user as disconnected and do not touch c.fd,
        -- which belongs to the newer connection.
        connection[fd] = nil
        gateserver.closeclient(fd)
        return
      end
      if conf.disconnect_handler then
        conf.disconnect_handler(c.username)
      end
      -- double check, conf.disconnect_handler may close fd
      if connection[fd] then
        -- If disconnect_handler yielded, the user may have re-authenticated
        -- on a new fd meanwhile; only clear c.fd if it is still this socket.
        if c.fd == fd then
          c.fd = nil
        end
        connection[fd] = nil
        gateserver.closeclient(fd)
      end
    end

    handler.error = handler.disconnect

    -- Validate a handshake packet of the form
    --   base64(uid)@base64(server)#base64(subid):index:base64(hmac)
    -- and bind fd to the user on success.
    -- Returns nil on success, otherwise an error status line.
    -- atomic , no yield
    local function do_auth(fd, message, addr)
      local username, index, hmac = string.match(message, "([^:]*):([^:]*):([^:]*)")
      local u = user_online[username]
      if u == nil then
        return "404 User Not Found"
      end
      local idx = assert(tonumber(index))
      hmac = b64decode(hmac)

      -- Each handshake must use an index greater than the last accepted one.
      if idx <= u.version then
        return "403 Index Expired"
      end

      local text = string.format("%s:%s", username, index)
      local v = crypt.hmac_hash(u.secret, text) -- equivalent to crypt.hmac64(crypt.hashkey(text), u.secret)
      if v ~= hmac then
        return "401 Unauthorized"
      end

      u.version = idx
      u.fd = fd
      u.ip = addr
      connection[fd] = u
    end

    -- Handle the first packet on fd: authenticate, reply with a status
    -- line, and close the socket unless the reply is "200 OK".
    local function auth(fd, addr, msg, sz)
      local message = netpack.tostring(msg, sz)
      local ok, result = pcall(do_auth, fd, message, addr)
      if not ok then
        skynet.error(result)
        result = "400 Bad Request"
      end

      local close = result ~= nil

      if result == nil then
        result = "200 OK"
      end

      socketdriver.send(fd, netpack.pack(result))

      if close then
        gateserver.closeclient(fd)
      end
    end

    local request_handler = assert(conf.request_handler)

    -- u.response is a struct { return_fd , response, version, index }
    -- Once u.index reaches 2 * expired_number, drop completed responses whose
    -- index is below expired_number, shift the remaining indices down by
    -- expired_number, and restart u.index just past the largest survivor.
    local function retire_response(u)
      if u.index >= expired_number * 2 then
        local max = 0
        local response = u.response
        for k, p in pairs(response) do
          if p[1] == nil then
            -- request complete, check expired
            if p[4] < expired_number then
              response[k] = nil
            else
              p[4] = p[4] - expired_number
              if p[4] > max then
                max = p[4]
              end
            end
          end
        end
        u.index = max + 1
      end
    end

    -- Handle one request packet: content followed by a big-endian dword session.
    -- Responses are cached per session so a client can resend a request after
    -- reconnecting and receive the cached response instead of re-executing it.
    local function do_request(fd, message)
      local u = assert(connection[fd], "invalid fd")
      local session = string.unpack(">I4", message, -4)
      message = message:sub(1, -5)
      local p = u.response[session]
      if p then
        -- session can be reuse in the same connection
        if p[3] == u.version then
          if p[2] == nil then
            -- The earlier request with this session is still pending. Keep
            -- its cache entry so its response is not lost, then reject.
            local error_msg = string.format("Conflict session %08x", session)
            skynet.error(error_msg)
            error(error_msg)
          end
          -- The earlier response is complete: reuse the session.
          u.response[session] = nil
          p = nil
        end
      end

      if p == nil then
        p = { fd }
        u.response[session] = p
        local ok, result = pcall(request_handler, u.username, message)
        -- NOTICE: YIELD here, socket may close.
        result = result or ""
        if not ok then
          skynet.error(result)
          result = string.pack(">BI4", 0, session)
        else
          result = result .. string.pack(">BI4", 1, session)
        end

        p[2] = string.pack(">s2", result)
        p[3] = u.version
        p[4] = u.index
      else
        -- update version/index, change return fd.
        -- resend response.
        p[1] = fd
        p[3] = u.version
        p[4] = u.index
        if p[2] == nil then
          -- already request, but response is not ready
          return
        end
      end
      u.index = u.index + 1
      -- the return fd is p[1] (fd may change by multi request) check connect
      fd = p[1]
      if connection[fd] then
        socketdriver.send(fd, p[2])
      end
      p[1] = nil
      retire_response(u)
    end

    -- Run do_request; on failure log the packet and close the socket.
    local function request(fd, msg, sz)
      local message = netpack.tostring(msg, sz)
      local ok, err = pcall(do_request, fd, message)
      -- not atomic, may yield
      if not ok then
        skynet.error(string.format("Invalid package %s : %s", err, message))
        if connection[fd] then
          gateserver.closeclient(fd)
        end
      end
    end

    -- The first packet on a socket is the handshake; later ones are requests.
    function handler.message(fd, msg, sz)
      local addr = handshake[fd]
      if addr then
        auth(fd, addr, msg, sz)
        handshake[fd] = nil
      else
        request(fd, msg, sz)
      end
    end

    return gateserver.start(handler)
  end

  return server