123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577 |
- local internal = require "http.internal"
- local socket = require "skynet.socket"
- local crypt = require "skynet.crypt"
- local httpd = require "http.httpd"
- local skynet = require "skynet"
- local sockethelper = require "http.sockethelper"
- local socket_error = sockethelper.socket_error
- local logger = require "logger"
-- GUID from RFC 6455 appended to Sec-WebSocket-Key when computing
-- Sec-WebSocket-Accept.
local GLOBAL_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
-- Size cap (256 KiB) for a single server-side frame and for a reassembled
-- fragmented message.
local MAX_FRAME_SIZE = 256 * 1024

-- Localize frequently used globals.
local assert, pairs, error, tonumber = assert, pairs, error, tonumber
local xpcall, debug = xpcall, debug
local string, table = string, table

-- Module table returned at the end of the file.
local M = {}
-- Live websocket objects, keyed by socket id.
local ws_pool = {}
--- Unregister a websocket object from ws_pool and close its transport.
-- Asserts that the pool entry for ws_obj.id is exactly this object.
-- @param ws_obj websocket object previously stored in ws_pool
local function _close_websocket(ws_obj)
    local registered = ws_pool[ws_obj.id]
    assert(registered == ws_obj)
    ws_pool[ws_obj.id] = nil
    ws_obj.close()
end
--- Return true when socket `id` has no websocket object in ws_pool.
local function _isws_closed(id)
    if ws_pool[id] then
        return false
    end
    return true
end
--- Serve already-buffered bytes before falling back to the real reader.
-- Temporarily replaces self.read with a wrapper that hands out `payload`
-- first. The original self.read is restored once the buffered bytes are
-- used up: on a read of nil, a read of exactly the remaining size, or a read
-- larger than what remains (the shortfall comes from the original reader).
-- @param self    object with a `read(sz)` field
-- @param payload string of bytes already received; "" leaves self.read untouched
local function reader_with_payload(self, payload)
    if #payload == 0 then
        return
    end
    local original_read = self.read
    local pending = payload
    local pending_len = #pending
    self.read = function(sz)
        if sz ~= nil and sz ~= pending_len then
            if sz < pending_len then
                local head = pending:sub(1, sz)
                pending = pending:sub(sz + 1)
                pending_len = #pending
                return head
            end
            self.read = original_read
            return pending .. original_read(sz - pending_len)
        end
        self.read = original_read
        return pending
    end
end
--- Perform the client side of the WebSocket opening handshake.
-- Sends a GET upgrade request with a random Sec-WebSocket-Key through
-- internal.request and validates the server's 101 response. Bytes received
-- after the response headers are pushed back into self.read via
-- reader_with_payload so the frame reader sees them.
-- @param self   client websocket object (read/write/guid fields)
-- @param host   host passed to internal.request
-- @param url    request target
-- @param header optional table of extra request headers; it must not
--               redefine any of the four handshake headers set here
-- @raise error string on any handshake failure
local function write_handshake(self, host, url, header)
    local key = crypt.base64encode(crypt.randomkey() .. crypt.randomkey())
    local request_header = {
        ["Upgrade"] = "websocket",
        ["Connection"] = "Upgrade",
        ["Sec-WebSocket-Version"] = "13",
        ["Sec-WebSocket-Key"] = key
    }
    if header then
        for k, v in pairs(header) do
            -- Extra headers may be added but must not override the handshake ones.
            assert(request_header[k] == nil, k)
            request_header[k] = v
        end
    end
    local recvheader = {}
    local code, payload = internal.request(self, "GET", host, url, recvheader, request_header)
    if code ~= 101 then
        error(string.format("websocket handshake error: code[%s] info:%s", code, payload))
    end
    -- Anything past the response headers belongs to the frame stream.
    reader_with_payload(self, payload)
    local upgrade = recvheader["upgrade"]
    if not upgrade or upgrade:lower() ~= "websocket" then
        error("websocket handshake upgrade must websocket")
    end
    -- RFC 6455 4.1: Connection must contain the "Upgrade" token
    -- (case-insensitive), possibly among others, e.g. "Upgrade, Keep-Alive".
    local connection = recvheader["connection"]
    local has_upgrade = false
    if connection then
        for token in connection:lower():gmatch("[^%s,]+") do
            if token == "upgrade" then
                has_upgrade = true
                break
            end
        end
    end
    if not has_upgrade then
        error("websocket handshake connection must upgrade")
    end
    local sw_key = recvheader["sec-websocket-accept"]
    if not sw_key then
        error("websocket handshake need Sec-WebSocket-Accept")
    end
    -- Expected accept value is base64(sha1(key .. guid)); compare the decoded form.
    if crypt.base64decode(sw_key) ~= crypt.sha1(key .. self.guid) then
        error("websocket handshake invalid Sec-WebSocket-Accept")
    end
end
--- Perform the server side of the WebSocket opening handshake.
-- Either reads and parses the HTTP request from self.read, or uses an
-- already-parsed request in `upgrade_ops` ({header=, method=, url=}).
-- It then validates the upgrade headers and writes the 101 response.
-- @param self       server websocket object (read/write/guid fields)
-- @param upgrade_ops optional pre-parsed request
-- @return on failure: HTTP status code and an optional message;
--         on success: nil, header table, url
local function read_handshake(self, upgrade_ops)
    local header, method, url
    if upgrade_ops then
        header, method, url = upgrade_ops.header, upgrade_ops.method, upgrade_ops.url
    else
        local tmpline = {}
        local payload = internal.recvheader(self.read, tmpline, "")
        if not payload then
            return 413
        end
        -- Bytes read past the request headers belong to the frame stream.
        reader_with_payload(self, payload)
        -- The request line comes from the peer: reject malformed input with a
        -- 400 response instead of raising an assertion error.
        local request = tmpline[1]
        if not request then
            return 400, "invalid request line"
        end
        local httpver
        method, url, httpver = request:match "^(%a+)%s+(.-)%s+HTTP/([%d%.]+)$"
        if not (method and url and httpver) then
            return 400, "invalid request line"
        end
        if method ~= "GET" then
            return 400, "need GET method"
        end
        httpver = tonumber(httpver)
        if not httpver then
            return 400, "invalid HTTP version"
        end
        if httpver < 1.1 then
            return 505 -- HTTP Version not supported
        end
        header = internal.parseheader(tmpline, 2, {})
    end
    if not header then
        return 400 -- Bad request
    end
    if not header["upgrade"] or header["upgrade"]:lower() ~= "websocket" then
        return 426, "Upgrade Required"
    end
    if not header["host"] then
        return 400, "host Required"
    end
    if not header["connection"] or not header["connection"]:lower():find("upgrade", 1, true) then
        return 400, "Connection must Upgrade"
    end
    local sw_key = header["sec-websocket-key"]
    if not sw_key then
        return 400, "Sec-WebSocket-Key Required"
    else
        -- The key must decode to exactly 16 bytes.
        local raw_key = crypt.base64decode(sw_key)
        if #raw_key ~= 16 then
            return 400, "Sec-WebSocket-Key invalid"
        end
    end
    if not header["sec-websocket-version"] or header["sec-websocket-version"] ~= "13" then
        return 400, "Sec-WebSocket-Version must 13"
    end
    -- If the client offers subprotocols, "chat" must be among them; it is then echoed back.
    local sw_protocol = header["sec-websocket-protocol"]
    local sub_pro = ""
    if sw_protocol then
        local has_chat = false
        for sub_protocol in string.gmatch(sw_protocol, "[^%s,]+") do
            if sub_protocol == "chat" then
                sub_pro = "Sec-WebSocket-Protocol: chat\r\n"
                has_chat = true
                break
            end
        end
        if not has_chat then
            return 400, "Sec-WebSocket-Protocol need include chat"
        end
    end
    -- read 'x-real-ip' header from nginx
    self.real_ip = header["x-real-ip"]
    -- response handshake: accept = base64(sha1(key .. guid))
    local accept = crypt.base64encode(crypt.sha1(sw_key .. self.guid))
    local resp = "HTTP/1.1 101 Switching Protocols\r\n" ..
        "Upgrade: websocket\r\n" ..
        "Connection: Upgrade\r\n" ..
        string.format("Sec-WebSocket-Accept: %s\r\n", accept) ..
        sub_pro ..
        "\r\n"
    self.write(resp)
    return nil, header, url
end
--- Call the callback named `method` in self.handle, if there is one.
-- The callback receives the socket id followed by any extra arguments.
local function try_handle(self, method, ...)
    local callbacks = self.handle
    if not callbacks then
        return
    end
    local callback = callbacks[method]
    if callback then
        callback(self.id, ...)
    end
end
-- Two-way opcode map: name -> numeric opcode and numeric opcode -> name.
-- "frame" (0x00) is the continuation opcode.
local op_code = {}
for name, value in pairs({
    frame = 0x00, text = 0x01, binary = 0x02,
    close = 0x08, ping = 0x09, pong = 0x0A,
}) do
    op_code[name] = value
    op_code[value] = name
end
--- Serialize one unfragmented (FIN=1) frame and send it with a single write.
-- @param self         websocket object; only self.write is used
-- @param op           opcode name, a key of op_code ("text", "binary", "close", ...)
-- @param payload_data optional payload string (nil is sent as empty)
-- @param masking_key  optional integer; when given, the MASK bit is set, the
--                     key is sent as 4 big-endian bytes and the payload is
--                     XOR-masked with those bytes
local function write_frame(self, op, payload_data, masking_key)
    payload_data = payload_data or ""
    local payload_len = #payload_data
    local v1 = 0x80 | assert(op_code[op]) -- FIN bit plus opcode
    local mask = masking_key and 0x80 or 0x00
    local parts = {}
    -- Length encoding: 7-bit inline, 126 + 16-bit, or 127 + 64-bit big-endian.
    if payload_len < 126 then
        parts[1] = string.pack("I1I1", v1, mask | payload_len)
    elseif payload_len <= 0xffff then
        parts[1] = string.pack("I1I1>I2", v1, mask | 126, payload_len)
    else
        parts[1] = string.pack("I1I1>I8", v1, mask | 127, payload_len)
    end
    if masking_key then
        local key = string.pack(">I4", masking_key)
        parts[#parts + 1] = key
        payload_data = crypt.xor_str(payload_data, key)
    end
    parts[#parts + 1] = payload_data
    -- Emit header, key and payload together instead of up to three writes.
    self.write(table.concat(parts))
end
--- Parse the payload of a received close frame.
-- The first two bytes, when present, are a big-endian status code and any
-- remaining bytes are the reason. A bare 2-byte payload (code only) is valid
-- and yields the code with an empty reason.
-- @param payload_data unmasked close payload
-- @return code (integer) and reason (string), or nil, nil when the payload
--         is shorter than two bytes
local function read_close(payload_data)
    if #payload_data < 2 then
        return nil, nil
    end
    local code = string.unpack(">I2", payload_data)
    return code, payload_data:sub(3)
end
--- Read and decode one frame from self.read.
-- @param self websocket object (read field; mode == "server" enables the
--             MAX_FRAME_SIZE limit)
-- @return fin (boolean), opcode name (from op_code), unmasked payload string
-- @raise "payload_len is invalid" when the 64-bit length has its top bit set,
--        "payload_len is too large" in server mode above MAX_FRAME_SIZE,
--        and an assertion error for an unknown opcode
local function read_frame(self)
    local s = self.read(2)
    local v1, v2 = string.unpack("I1I1", s)
    local fin = (v1 & 0x80) ~= 0
    -- RSV1-3 bits (0x40 / 0x20 / 0x10) are not inspected.
    local op = v1 & 0x0f
    local mask = (v2 & 0x80) ~= 0
    local payload_len = v2 & 0x7f
    if payload_len == 126 then
        s = self.read(2)
        payload_len = string.unpack(">I2", s)
    elseif payload_len == 127 then
        s = self.read(8)
        payload_len = string.unpack(">I8", s)
        -- RFC 6455 5.2: the top bit of the 64-bit length MUST be 0. Unpacking
        -- ">I8" maps such values to negative integers, which would slip past
        -- the size check below and be consumed as an empty payload,
        -- desynchronizing the stream.
        if payload_len < 0 then
            error("payload_len is invalid")
        end
    end
    if self.mode == "server" and payload_len > MAX_FRAME_SIZE then
        error("payload_len is too large")
    end
    local masking_key = mask and self.read(4) or false
    local payload_data = payload_len > 0 and self.read(payload_len) or ""
    payload_data = masking_key and crypt.xor_str(payload_data, masking_key) or payload_data
    return fin, assert(op_code[op]), payload_data
end
--- Server-side driver for one accepted websocket connection.
-- Fires the "connect" handler and runs the opening handshake. If the
-- handshake is rejected, it writes an HTTP error response and fires "close".
-- Otherwise it reads frames until a close frame arrives or the socket is
-- removed from ws_pool. Handlers are invoked through try_handle and receive
-- the socket id first.
-- @param self    server websocket object created by _new_server_ws
-- @param options optional table; options.upgrade is passed to read_handshake
--                as an already-parsed request
-- @raise errors from read_frame / write_frame, the error returned by
--        httpd.write_response, and "payload_len is too large" when a
--        reassembled message exceeds MAX_FRAME_SIZE
local function resolve_accept(self, options)
    try_handle(self, "connect")
    local code, err, url = read_handshake(self, options and options.upgrade)
    if code then
        -- Handshake rejected: `code` is the HTTP status, `err` an optional message.
        local ok, s = httpd.write_response(self.write, code, err)
        if not ok then
            error(s)
        end
        try_handle(self, "close")
        return
    end
    -- On success read_handshake returns nil, header, url.
    local header = err
    try_handle(self, "handshake", header, url)
    -- State for reassembling fragmented (FIN=0) data messages.
    local recv_count = 0
    local recv_buf = {}
    local first_op
    while true do
        -- Stop once the socket has been removed from ws_pool (e.g. via M.close).
        if _isws_closed(self.id) then
            try_handle(self, "close")
            return
        end
        local fin, op, payload_data = read_frame(self)
        if op == "close" then
            local code, reason = read_close(payload_data)
            -- Reply with an empty close frame (the status code is not echoed).
            write_frame(self, "close")
            try_handle(self, "close", code, reason)
            break
        elseif op == "ping" then
            write_frame(self, "pong", payload_data)
            try_handle(self, "ping")
        elseif op == "pong" then
            try_handle(self, "pong")
        else
            -- Data frame: "text", "binary" or continuation ("frame").
            if fin and #recv_buf == 0 then
                -- Unfragmented message: deliver it directly.
                try_handle(self, "message", payload_data, op)
            else
                recv_buf[#recv_buf+1] = payload_data
                recv_count = recv_count + #payload_data
                -- Cap the total size of a reassembled message.
                if recv_count > MAX_FRAME_SIZE then
                    error("payload_len is too large")
                end
                -- The delivered message type comes from the first fragment.
                -- NOTE(review): a continuation frame with no preceding start
                -- frame yields first_op "frame", and a new text/binary frame
                -- arriving mid-message is appended rather than rejected; RFC 6455
                -- 5.4 treats both as protocol errors. Confirm the intended handling.
                first_op = first_op or op
                if fin then
                    local s = table.concat(recv_buf)
                    try_handle(self, "message", s, first_op)
                    recv_buf = {} -- clear recv_buf
                    recv_count = 0
                    first_op = nil
                end
            end
        end
    end
end
-- Client-side TLS context, created on the first wss connection and then reused.
local SSLCTX_CLIENT = nil

--- Wrap a connected socket in a client websocket object and register it in ws_pool.
-- For "wss" the TLS client handshake is run here before the object is built.
-- @param socket_id connected socket id
-- @param protocol  "ws" (plain) or "wss" (TLS)
-- @param hostname  optional name forwarded to tls.newtls
-- @return object with close/read/write/readall plus mode, id and guid fields
-- @raise "invalid websocket protocol:..." for any other protocol
local function _new_client_ws(socket_id, protocol, hostname)
    local obj
    if protocol == "wss" then
        local tls = require "http.tlshelper"
        if not SSLCTX_CLIENT then
            SSLCTX_CLIENT = tls.newctx()
        end
        local tls_ctx = tls.newtls("client", SSLCTX_CLIENT, hostname)
        tls.init_requestfunc(socket_id, tls_ctx)()
        obj = {
            read = tls.readfunc(socket_id, tls_ctx),
            write = tls.writefunc(socket_id, tls_ctx),
            readall = tls.readallfunc(socket_id, tls_ctx),
            close = function()
                socket.close(socket_id)
                tls.closefunc(tls_ctx)()
            end,
        }
    elseif protocol == "ws" then
        obj = {
            read = sockethelper.readfunc(socket_id),
            write = sockethelper.writefunc(socket_id),
            readall = function()
                return socket.readall(socket_id)
            end,
            close = function()
                socket.close(socket_id)
            end,
        }
    else
        error(string.format("invalid websocket protocol:%s", tostring(protocol)))
    end
    obj.id = assert(socket_id)
    obj.mode = "client"
    obj.guid = GLOBAL_GUID
    ws_pool[socket_id] = obj
    return obj
end
-- Server-side TLS context, created (with its certificate) on the first wss
-- accept and then reused.
local SSLCTX_SERVER = nil

--- Wrap an accepted socket in a server websocket object and register it in ws_pool.
-- For "wss" the TLS server handshake is run here before the object is built.
-- @param socket_id accepted socket id
-- @param handle    callback table used by try_handle (may be nil)
-- @param protocol  "ws" (plain) or "wss" (TLS)
-- @return object with close/read/write plus mode, id, handle and guid fields
-- @raise "invalid websocket protocol:..." for any other protocol
local function _new_server_ws(socket_id, handle, protocol)
    local obj
    if protocol == "wss" then
        local tls = require "http.tlshelper"
        if not SSLCTX_SERVER then
            SSLCTX_SERVER = tls.newctx()
            -- A self-signed pair can be generated with:
            -- openssl req -x509 -newkey rsa:2048 -days 3650 -nodes -keyout server-key.pem -out server-cert.pem
            local certfile = skynet.getenv("certfile") or "./server-cert.pem"
            local keyfile = skynet.getenv("keyfile") or "./server-key.pem"
            SSLCTX_SERVER:set_cert(certfile, keyfile)
        end
        local tls_ctx = tls.newtls("server", SSLCTX_SERVER)
        tls.init_responsefunc(socket_id, tls_ctx)()
        obj = {
            read = tls.readfunc(socket_id, tls_ctx),
            write = tls.writefunc(socket_id, tls_ctx),
            close = function()
                socket.close(socket_id)
                tls.closefunc(tls_ctx)()
            end,
        }
    elseif protocol == "ws" then
        obj = {
            read = sockethelper.readfunc(socket_id),
            write = sockethelper.writefunc(socket_id),
            close = function()
                socket.close(socket_id)
            end,
        }
    else
        error(string.format("invalid websocket protocol:%s", tostring(protocol)))
    end
    obj.id = assert(socket_id)
    obj.mode = "server"
    obj.handle = handle
    obj.guid = GLOBAL_GUID
    ws_pool[socket_id] = obj
    return obj
end
-- handle interface
-- connect / handshake / message / ping / pong / close / error
--- Serve a websocket connection on `socket_id` until it ends.
-- Starts the socket unless options.upgrade carries an already-read request,
-- wraps it with _new_server_ws and runs resolve_accept under xpcall. The
-- socket is closed on return if it is still registered in ws_pool.
-- A "warning" callback in `handle` is wired to socket.warning and receives
-- the websocket object (not the id) plus the size argument.
-- @param socket_id socket id
-- @param handle    callback table (names listed above)
-- @param protocol  "ws" (default) or "wss"
-- @param addr      peer address stored as ws_obj.addr
-- @param options   optional; options.upgrade = { header =, method =, url = }
-- @return true on normal completion or on a socket error (the "close" or
--         "error" handler is fired); false, traceback on any other error
function M.accept(socket_id, handle, protocol, addr, options)
    local upgraded = options and options.upgrade
    if not upgraded then
        socket.start(socket_id)
    end
    local ws_obj = _new_server_ws(socket_id, handle, protocol or "ws")
    ws_obj.addr = addr
    local on_warning = handle and handle.warning
    if on_warning then
        socket.warning(socket_id, function(_, sz)
            on_warning(ws_obj, sz)
        end)
    end
    local ok, err = xpcall(resolve_accept, debug.traceback, ws_obj, options)
    local closed = _isws_closed(socket_id)
    if not closed then
        _close_websocket(ws_obj)
    end
    if ok then
        return true
    end
    if err ~= socket_error then
        return false, err
    end
    -- Socket error: report "close" if the socket had already left the pool.
    try_handle(ws_obj, closed and "close" or "error")
    return true
end
--- Open a client websocket connection and perform the opening handshake.
-- @param url     "ws://host[:port][path]" or "wss://host[:port][path]"
-- @param header  optional table of extra handshake request headers
-- @param timeout optional value forwarded to sockethelper.connect
-- @return socket id identifying the connection in this module's API
-- @raise on an invalid URL, a connect failure or a handshake failure
function M.connect(url, header, timeout)
    local protocol, host, uri = string.match(url, "^(wss?)://([^/]+)(.*)$")
    if protocol ~= "wss" and protocol ~= "ws" then
        error(string.format("invalid protocol: %s", protocol))
    end
    assert(host)
    local host_addr, host_port = string.match(host, "^([^:]+):?(%d*)$")
    assert(host_addr and host_port)
    if host_port == "" then
        -- Default ports: 80 for ws, 443 for wss.
        host_port = protocol == "ws" and 80 or 443
    end
    -- hostname is handed to tls.newtls (presumably as the TLS server name;
    -- confirm in http.tlshelper). It stays nil only for dotted IPv4 literals.
    -- A trailing digit alone is not enough: DNS names such as "node1" still
    -- need it.
    local hostname
    if not host_addr:match("^%d+%.%d+%.%d+%.%d+$") then
        hostname = host_addr
    end
    uri = uri == "" and "/" or uri
    local socket_id = sockethelper.connect(host_addr, host_port, timeout)
    local ws_obj = _new_client_ws(socket_id, protocol, hostname)
    ws_obj.addr = host
    write_handshake(ws_obj, host_addr, uri, header)
    return socket_id
end
--- Block until one complete data message arrives on socket `id`.
-- Ping frames are answered with a pong and pong frames are skipped.
-- Fragmented messages are concatenated.
-- @return the message payload string; or false plus the raw close payload
--         when a close frame arrives (the websocket is then closed)
function M.read(id)
    local ws_obj = assert(ws_pool[id])
    local fragments
    while true do
        local fin, op, payload_data = read_frame(ws_obj)
        if op == "close" then
            _close_websocket(ws_obj)
            return false, payload_data
        end
        if op == "ping" then
            write_frame(ws_obj, "pong", payload_data)
        elseif op ~= "pong" then -- data frame: "frame", "text" or "binary"
            if fin and fragments == nil then
                return payload_data
            end
            fragments = fragments or {}
            fragments[#fragments + 1] = payload_data
            if fin then
                return table.concat(fragments)
            end
        end
    end
end
--- Send `data` on socket `id` as a single text (default) or binary frame.
-- @param masking_key optional integer key; when given the frame is masked
function M.write(id, data, fmt, masking_key)
    local ws_obj = assert(ws_pool[id], id)
    local frame_type = fmt or "text"
    assert(frame_type == "text" or frame_type == "binary")
    write_frame(ws_obj, frame_type, data, masking_key)
end
--- Send an empty ping frame on socket `id`.
function M.ping(id)
    write_frame(assert(ws_pool[id]), "ping")
end
--- Return the address recorded for socket `id` (ws_obj.addr).
function M.addrinfo(id)
    return assert(ws_pool[id]).addr
end
--- Return the X-Real-IP value captured during the server handshake, if any.
function M.real_ip(id)
    return assert(ws_pool[id]).real_ip
end
--- Send a close frame on socket `id` (best effort) and close the websocket.
-- When `code` is given, the close payload is the big-endian 16-bit code
-- followed by `reason`. Errors while sending are logged via skynet.error,
-- and the socket is closed either way. Unknown ids are ignored.
function M.close(id, code, reason)
    local ws_obj = ws_pool[id]
    if not ws_obj then
        return
    end
    local function send_close_frame()
        reason = reason or ""
        local payload_data = nil
        if code then
            payload_data = string.pack(string.format(">I2c%d", #reason), code, reason)
        end
        write_frame(ws_obj, "close", payload_data)
    end
    local ok, err = xpcall(send_close_frame, debug.traceback)
    _close_websocket(ws_obj)
    if not ok then
        skynet.error(err)
    end
end
-- Public alias: true when the socket id has no live websocket object.
M.is_close = _isws_closed

--- Wrap socket `fd` in a server websocket object (empty handle table) and
-- record `addr`. No handshake or frame loop is run here.
-- @param protocol "ws" (default) or "wss"
function M.forward(fd, protocol, addr)
    local ws_obj = _new_server_ws(fd, {}, protocol or "ws")
    ws_obj.addr = addr
end
--- Remove socket `id` from the pool without closing its transport.
function M.clear_pool(id)
    if ws_pool[id] then
        ws_pool[id] = nil
    end
end

return M
|