websocket.lua 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. local internal = require "http.internal"
  2. local socket = require "skynet.socket"
  3. local crypt = require "skynet.crypt"
  4. local httpd = require "http.httpd"
  5. local skynet = require "skynet"
  6. local sockethelper = require "http.sockethelper"
  7. local socket_error = sockethelper.socket_error
  8. local logger = require "logger"
  9. local GLOBAL_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  10. local MAX_FRAME_SIZE = 256 * 1024 -- max frame is 256K
  11. local assert = assert
  12. local pairs = pairs
  13. local error = error
  14. local string = string
  15. local xpcall = xpcall
  16. local debug = debug
  17. local table = table
  18. local tonumber = tonumber
  19. local M = {}
  20. local ws_pool = {}
  21. local function _close_websocket(ws_obj)
  22. local id = ws_obj.id
  23. assert(ws_pool[id] == ws_obj)
  24. ws_pool[id] = nil
  25. ws_obj.close()
  26. end
  27. local function _isws_closed(id)
  28. return not ws_pool[id]
  29. end
  30. local function reader_with_payload(self, payload)
  31. local sz_payload = #payload
  32. if sz_payload == 0 then
  33. return
  34. end
  35. local read = self.read
  36. function self.read (sz)
  37. if sz == nil or sz == sz_payload then
  38. self.read = read
  39. return payload
  40. end
  41. if sz < sz_payload then
  42. local ret = payload:sub(1, sz)
  43. payload = payload:sub(sz + 1)
  44. sz_payload = #payload
  45. return ret
  46. end
  47. self.read = read
  48. return payload .. read(sz - sz_payload)
  49. end
  50. end
  51. local function write_handshake(self, host, url, header)
  52. local key = crypt.base64encode(crypt.randomkey()..crypt.randomkey())
  53. local request_header = {
  54. ["Upgrade"] = "websocket",
  55. ["Connection"] = "Upgrade",
  56. ["Sec-WebSocket-Version"] = "13",
  57. ["Sec-WebSocket-Key"] = key
  58. }
  59. if header then
  60. for k,v in pairs(header) do
  61. assert(request_header[k] == nil, k)
  62. request_header[k] = v
  63. end
  64. end
  65. local recvheader = {}
  66. local code, payload = internal.request(self, "GET", host, url, recvheader, request_header)
  67. if code ~= 101 then
  68. error(string.format("websocket handshake error: code[%s] info:%s", code, payload))
  69. end
  70. reader_with_payload(self, payload)
  71. if not recvheader["upgrade"] or recvheader["upgrade"]:lower() ~= "websocket" then
  72. error("websocket handshake upgrade must websocket")
  73. end
  74. if not recvheader["connection"] or recvheader["connection"]:lower() ~= "upgrade" then
  75. error("websocket handshake connection must upgrade")
  76. end
  77. local sw_key = recvheader["sec-websocket-accept"]
  78. if not sw_key then
  79. error("websocket handshake need Sec-WebSocket-Accept")
  80. end
  81. local guid = self.guid
  82. sw_key = crypt.base64decode(sw_key)
  83. if sw_key ~= crypt.sha1(key .. guid) then
  84. error("websocket handshake invalid Sec-WebSocket-Accept")
  85. end
  86. end
  87. local function read_handshake(self, upgrade_ops)
  88. local header, method, url
  89. if upgrade_ops then
  90. header, method, url = upgrade_ops.header, upgrade_ops.method, upgrade_ops.url
  91. else
  92. local tmpline = {}
  93. local payload = internal.recvheader(self.read, tmpline, "")
  94. if not payload then
  95. return 413
  96. end
  97. reader_with_payload(self, payload)
  98. local request = assert(tmpline[1])
  99. local httpver
  100. method, url, httpver = request:match "^(%a+)%s+(.-)%s+HTTP/([%d%.]+)$"
  101. assert(method and url and httpver)
  102. if method ~= "GET" then
  103. return 400, "need GET method"
  104. end
  105. httpver = assert(tonumber(httpver))
  106. if httpver < 1.1 then
  107. return 505 -- HTTP Version not supported
  108. end
  109. header = internal.parseheader(tmpline, 2, {})
  110. end
  111. if not header then
  112. return 400 -- Bad request
  113. end
  114. if not header["upgrade"] or header["upgrade"]:lower() ~= "websocket" then
  115. return 426, "Upgrade Required"
  116. end
  117. if not header["host"] then
  118. return 400, "host Required"
  119. end
  120. if not header["connection"] or not header["connection"]:lower():find("upgrade", 1,true) then
  121. return 400, "Connection must Upgrade"
  122. end
  123. local sw_key = header["sec-websocket-key"]
  124. if not sw_key then
  125. return 400, "Sec-WebSocket-Key Required"
  126. else
  127. local raw_key = crypt.base64decode(sw_key)
  128. if #raw_key ~= 16 then
  129. return 400, "Sec-WebSocket-Key invalid"
  130. end
  131. end
  132. if not header["sec-websocket-version"] or header["sec-websocket-version"] ~= "13" then
  133. return 400, "Sec-WebSocket-Version must 13"
  134. end
  135. local sw_protocol = header["sec-websocket-protocol"]
  136. local sub_pro = ""
  137. if sw_protocol then
  138. local has_chat = false
  139. for sub_protocol in string.gmatch(sw_protocol, "[^%s,]+") do
  140. if sub_protocol == "chat" then
  141. sub_pro = "Sec-WebSocket-Protocol: chat\r\n"
  142. has_chat = true
  143. break
  144. end
  145. end
  146. if not has_chat then
  147. return 400, "Sec-WebSocket-Protocol need include chat"
  148. end
  149. end
  150. -- read 'x-real-ip' header from nginx
  151. self.real_ip = header["x-real-ip"]
  152. -- response handshake
  153. local accept = crypt.base64encode(crypt.sha1(sw_key .. self.guid))
  154. local resp = "HTTP/1.1 101 Switching Protocols\r\n"..
  155. "Upgrade: websocket\r\n"..
  156. "Connection: Upgrade\r\n"..
  157. string.format("Sec-WebSocket-Accept: %s\r\n", accept)..
  158. sub_pro ..
  159. "\r\n"
  160. self.write(resp)
  161. return nil, header, url
  162. end
  163. local function try_handle(self, method, ...)
  164. local handle = self.handle
  165. local f = handle and handle[method]
  166. if f then
  167. f(self.id, ...)
  168. end
  169. end
  170. local op_code = {
  171. ["frame"] = 0x00,
  172. ["text"] = 0x01,
  173. ["binary"] = 0x02,
  174. ["close"] = 0x08,
  175. ["ping"] = 0x09,
  176. ["pong"] = 0x0A,
  177. [0x00] = "frame",
  178. [0x01] = "text",
  179. [0x02] = "binary",
  180. [0x08] = "close",
  181. [0x09] = "ping",
  182. [0x0A] = "pong",
  183. }
  184. local function write_frame(self, op, payload_data, masking_key)
  185. payload_data = payload_data or ""
  186. local payload_len = #payload_data
  187. local op_v = assert(op_code[op])
  188. local v1 = 0x80 | op_v -- fin is 1 with opcode
  189. local s
  190. local mask = masking_key and 0x80 or 0x00
  191. -- mask set to 0
  192. if payload_len < 126 then
  193. s = string.pack("I1I1", v1, mask | payload_len)
  194. elseif payload_len <= 0xffff then
  195. s = string.pack("I1I1>I2", v1, mask | 126, payload_len)
  196. else
  197. s = string.pack("I1I1>I8", v1, mask | 127, payload_len)
  198. end
  199. self.write(s)
  200. -- write masking_key
  201. if masking_key then
  202. s = string.pack(">I4", masking_key)
  203. self.write(s)
  204. payload_data = crypt.xor_str(payload_data, s)
  205. end
  206. if payload_len > 0 then
  207. self.write(payload_data)
  208. end
  209. end
  210. local function read_close(payload_data)
  211. local code, reason
  212. local payload_len = #payload_data
  213. if payload_len > 2 then
  214. local fmt = string.format(">I2c%d", payload_len - 2)
  215. code, reason = string.unpack(fmt, payload_data)
  216. end
  217. return code, reason
  218. end
  219. local function read_frame(self)
  220. local s = self.read(2)
  221. local v1, v2 = string.unpack("I1I1", s)
  222. local fin = (v1 & 0x80) ~= 0
  223. -- unused flag
  224. -- local rsv1 = (v1 & 0x40) ~= 0
  225. -- local rsv2 = (v1 & 0x20) ~= 0
  226. -- local rsv3 = (v1 & 0x10) ~= 0
  227. local op = v1 & 0x0f
  228. local mask = (v2 & 0x80) ~= 0
  229. local payload_len = (v2 & 0x7f)
  230. if payload_len == 126 then
  231. s = self.read(2)
  232. payload_len = string.unpack(">I2", s)
  233. elseif payload_len == 127 then
  234. s = self.read(8)
  235. payload_len = string.unpack(">I8", s)
  236. end
  237. if self.mode == "server" and payload_len > MAX_FRAME_SIZE then
  238. error("payload_len is too large")
  239. end
  240. -- print(string.format("fin:%s, op:%s, mask:%s, payload_len:%s", fin, op_code[op], mask, payload_len))
  241. local masking_key = mask and self.read(4) or false
  242. local payload_data = payload_len>0 and self.read(payload_len) or ""
  243. payload_data = masking_key and crypt.xor_str(payload_data, masking_key) or payload_data
  244. return fin, assert(op_code[op]), payload_data
  245. end
  246. local function resolve_accept(self, options)
  247. try_handle(self, "connect")
  248. local code, err, url = read_handshake(self, options and options.upgrade)
  249. if code then
  250. local ok, s = httpd.write_response(self.write, code, err)
  251. if not ok then
  252. error(s)
  253. end
  254. try_handle(self, "close")
  255. return
  256. end
  257. local header = err
  258. try_handle(self, "handshake", header, url)
  259. local recv_count = 0
  260. local recv_buf = {}
  261. local first_op
  262. while true do
  263. if _isws_closed(self.id) then
  264. try_handle(self, "close")
  265. return
  266. end
  267. local fin, op, payload_data = read_frame(self)
  268. if op == "close" then
  269. local code, reason = read_close(payload_data)
  270. write_frame(self, "close")
  271. try_handle(self, "close", code, reason)
  272. break
  273. elseif op == "ping" then
  274. write_frame(self, "pong", payload_data)
  275. try_handle(self, "ping")
  276. elseif op == "pong" then
  277. try_handle(self, "pong")
  278. else
  279. if fin and #recv_buf == 0 then
  280. try_handle(self, "message", payload_data, op)
  281. else
  282. recv_buf[#recv_buf+1] = payload_data
  283. recv_count = recv_count + #payload_data
  284. if recv_count > MAX_FRAME_SIZE then
  285. error("payload_len is too large")
  286. end
  287. first_op = first_op or op
  288. if fin then
  289. local s = table.concat(recv_buf)
  290. try_handle(self, "message", s, first_op)
  291. recv_buf = {} -- clear recv_buf
  292. recv_count = 0
  293. first_op = nil
  294. end
  295. end
  296. end
  297. end
  298. end
  299. local SSLCTX_CLIENT = nil
  300. local function _new_client_ws(socket_id, protocol, hostname)
  301. local obj
  302. if protocol == "ws" then
  303. obj = {
  304. close = function ()
  305. socket.close(socket_id)
  306. end,
  307. read = sockethelper.readfunc(socket_id),
  308. write = sockethelper.writefunc(socket_id),
  309. readall = function ()
  310. return socket.readall(socket_id)
  311. end,
  312. }
  313. elseif protocol == "wss" then
  314. local tls = require "http.tlshelper"
  315. SSLCTX_CLIENT = SSLCTX_CLIENT or tls.newctx()
  316. local tls_ctx = tls.newtls("client", SSLCTX_CLIENT, hostname)
  317. local init = tls.init_requestfunc(socket_id, tls_ctx)
  318. init()
  319. obj = {
  320. close = function ()
  321. socket.close(socket_id)
  322. tls.closefunc(tls_ctx)()
  323. end,
  324. read = tls.readfunc(socket_id, tls_ctx),
  325. write = tls.writefunc(socket_id, tls_ctx),
  326. readall = tls.readallfunc(socket_id, tls_ctx),
  327. }
  328. else
  329. error(string.format("invalid websocket protocol:%s", tostring(protocol)))
  330. end
  331. obj.mode = "client"
  332. obj.id = assert(socket_id)
  333. obj.guid = GLOBAL_GUID
  334. ws_pool[socket_id] = obj
  335. return obj
  336. end
  337. local SSLCTX_SERVER = nil
  338. local function _new_server_ws(socket_id, handle, protocol)
  339. local obj
  340. if protocol == "ws" then
  341. obj = {
  342. close = function ()
  343. socket.close(socket_id)
  344. end,
  345. read = sockethelper.readfunc(socket_id),
  346. write = sockethelper.writefunc(socket_id),
  347. }
  348. elseif protocol == "wss" then
  349. local tls = require "http.tlshelper"
  350. if not SSLCTX_SERVER then
  351. SSLCTX_SERVER = tls.newctx()
  352. -- gen cert and key
  353. -- openssl req -x509 -newkey rsa:2048 -days 3650 -nodes -keyout server-key.pem -out server-cert.pem
  354. local certfile = skynet.getenv("certfile") or "./server-cert.pem"
  355. local keyfile = skynet.getenv("keyfile") or "./server-key.pem"
  356. SSLCTX_SERVER:set_cert(certfile, keyfile)
  357. end
  358. local tls_ctx = tls.newtls("server", SSLCTX_SERVER)
  359. local init = tls.init_responsefunc(socket_id, tls_ctx)
  360. init()
  361. obj = {
  362. close = function ()
  363. socket.close(socket_id)
  364. tls.closefunc(tls_ctx)()
  365. end,
  366. read = tls.readfunc(socket_id, tls_ctx),
  367. write = tls.writefunc(socket_id, tls_ctx),
  368. }
  369. else
  370. error(string.format("invalid websocket protocol:%s", tostring(protocol)))
  371. end
  372. obj.mode = "server"
  373. obj.id = assert(socket_id)
  374. obj.handle = handle
  375. obj.guid = GLOBAL_GUID
  376. ws_pool[socket_id] = obj
  377. return obj
  378. end
  379. -- handle interface
  380. -- connect / handshake / message / ping / pong / close / error
  381. function M.accept(socket_id, handle, protocol, addr, options)
  382. if not (options and options.upgrade) then
  383. socket.start(socket_id)
  384. end
  385. protocol = protocol or "ws"
  386. local ws_obj = _new_server_ws(socket_id, handle, protocol)
  387. ws_obj.addr = addr
  388. local on_warning = handle and handle["warning"]
  389. if on_warning then
  390. socket.warning(socket_id, function (id, sz)
  391. on_warning(ws_obj, sz)
  392. end)
  393. end
  394. local ok, err = xpcall(resolve_accept, debug.traceback, ws_obj, options)
  395. local closed = _isws_closed(socket_id)
  396. if not closed then
  397. _close_websocket(ws_obj)
  398. end
  399. if not ok then
  400. if err == socket_error then
  401. if closed then
  402. try_handle(ws_obj, "close")
  403. else
  404. try_handle(ws_obj, "error")
  405. end
  406. else
  407. -- error(err)
  408. return false, err
  409. end
  410. end
  411. return true
  412. end
  413. function M.connect(url, header, timeout)
  414. local protocol, host, uri = string.match(url, "^(wss?)://([^/]+)(.*)$")
  415. if protocol ~= "wss" and protocol ~= "ws" then
  416. error(string.format("invalid protocol: %s", protocol))
  417. end
  418. assert(host)
  419. local host_addr, host_port = string.match(host, "^([^:]+):?(%d*)$")
  420. assert(host_addr and host_port)
  421. if host_port == "" then
  422. host_port = protocol == "ws" and 80 or 443
  423. end
  424. local hostname
  425. if not host_addr:match(".*%d+$") then
  426. hostname = host_addr
  427. end
  428. uri = uri == "" and "/" or uri
  429. local socket_id = sockethelper.connect(host_addr, host_port, timeout)
  430. local ws_obj = _new_client_ws(socket_id, protocol, hostname)
  431. ws_obj.addr = host
  432. write_handshake(ws_obj, host_addr, uri, header)
  433. return socket_id
  434. end
  435. function M.read(id)
  436. local ws_obj = assert(ws_pool[id])
  437. local recv_buf
  438. while true do
  439. local fin, op, payload_data = read_frame(ws_obj)
  440. if op == "close" then
  441. _close_websocket(ws_obj)
  442. return false, payload_data
  443. elseif op == "ping" then
  444. write_frame(ws_obj, "pong", payload_data)
  445. elseif op ~= "pong" then -- op is frame, text binary
  446. if fin and not recv_buf then
  447. return payload_data
  448. else
  449. recv_buf = recv_buf or {}
  450. recv_buf[#recv_buf+1] = payload_data
  451. if fin then
  452. local s = table.concat(recv_buf)
  453. return s
  454. end
  455. end
  456. end
  457. end
  458. end
  459. function M.write(id, data, fmt, masking_key)
  460. local ws_obj = assert(ws_pool[id], id)
  461. fmt = fmt or "text"
  462. assert(fmt == "text" or fmt == "binary")
  463. write_frame(ws_obj, fmt, data, masking_key)
  464. end
  465. function M.ping(id)
  466. local ws_obj = assert(ws_pool[id])
  467. write_frame(ws_obj, "ping")
  468. end
  469. function M.addrinfo(id)
  470. local ws_obj = assert(ws_pool[id])
  471. return ws_obj.addr
  472. end
  473. function M.real_ip(id)
  474. local ws_obj = assert(ws_pool[id])
  475. return ws_obj.real_ip
  476. end
  477. function M.close(id, code ,reason)
  478. local ws_obj = ws_pool[id]
  479. if not ws_obj then
  480. return
  481. end
  482. local ok, err = xpcall(function ()
  483. reason = reason or ""
  484. local payload_data
  485. if code then
  486. local fmt =string.format(">I2c%d", #reason)
  487. payload_data = string.pack(fmt, code, reason)
  488. end
  489. write_frame(ws_obj, "close", payload_data)
  490. end, debug.traceback)
  491. _close_websocket(ws_obj)
  492. if not ok then
  493. skynet.error(err)
  494. end
  495. end
  496. M.is_close = _isws_closed
  497. function M.forward(fd, protocol, addr)
  498. protocol = protocol or "ws"
  499. local ws_obj = _new_server_ws(fd, {}, protocol)
  500. ws_obj.addr = addr
  501. end
  502. function M.clear_pool(id)
  503. local ws_obj = ws_pool[id]
  504. if not ws_obj then
  505. return
  506. end
  507. ws_pool[id] = nil
  508. end
  509. return M