-- Lua front end for the sproto.core native module.
-- Native binding; every function in this file delegates to it.
local core = require("sproto.core")
-- Keep the global `assert` in a local so the functions below use an upvalue.
local assert = assert

-- Method tables: `sproto` for schema objects, `host` for objects built by sproto:host().
local sproto, host = {}, {}

-- Metatable for the per-object caches: both keys and values are weak.
local weak_mt = { __mode = "kv" }
-- Two metatables expose the same methods; only sproto_mt receives a __gc hook further down.
local sproto_mt = { __index = sproto }
local sproto_nogc = { __index = sproto }
local host_mt = { __index = host }
--- Finalizer for objects that use sproto_mt.
-- Passes the wrapped `__cobj` to core.deleteproto. Objects that use
-- sproto_nogc share the methods but do not get this hook.
sproto_mt.__gc = function(self)
  core.deleteproto(self.__cobj)
end
--- Create a schema object from `bin`.
-- `bin` is forwarded unchanged to core.newproto; a falsy result raises via
-- assert (using core.newproto's second result as the message, if any).
-- The returned table uses sproto_mt, so its `__cobj` is handed to
-- core.deleteproto when the table is finalized.
-- @param bin value accepted by core.newproto (presumably a compiled schema; confirm against sproto.core)
-- @return table holding `__cobj` plus two empty weak caches, `__tcache` and `__pcache`
function sproto.new(bin)
  -- Create the native object first: if this raises, no table with a __gc hook exists yet.
  local handle = assert(core.newproto(bin))
  local type_cache = setmetatable({}, weak_mt)
  local proto_cache = setmetatable({}, weak_mt)
  return setmetatable({
    __cobj = handle,
    __tcache = type_cache,
    __pcache = proto_cache,
  }, sproto_mt)
end
--- Wrap an already existing native object without taking over its cleanup.
-- Same layout as the table built by sproto.new, but the metatable is
-- sproto_nogc, which has no __gc, so `cobj` is never passed to
-- core.deleteproto by this wrapper.
-- @param cobj native object to wrap; stored as-is, not validated here
-- @return table holding `__cobj` plus two empty weak caches
function sproto.sharenew(cobj)
  local obj = setmetatable({}, sproto_nogc)
  obj.__cobj = cobj
  obj.__tcache = setmetatable({}, weak_mt)
  obj.__pcache = setmetatable({}, weak_mt)
  return obj
end
--- Build a schema object from schema text.
-- Loads the "sprotoparser" module on demand, runs its `parse` on `ptext`
-- and passes the first result to sproto.new.
-- @param ptext input for sprotoparser.parse
-- @return whatever sproto.new returns
function sproto.parse(ptext)
  local compile = require("sprotoparser").parse
  -- Extra parentheses: only the first result of the parser is forwarded.
  return sproto.new((compile(ptext)))
end
--- Create a host object bound to this schema.
-- Looks up the header type through core.querytype and raises
-- "type package not found" when the lookup yields a falsy value.
-- @param packagename name of the header type; defaults to "package"
-- @return table with `__proto` (this schema), `__package` (header type)
--         and an empty `__session` table, using host_mt as metatable
function sproto:host( packagename )
  local header_type = core.querytype(self.__cobj, packagename or "package")
  return setmetatable({
    __proto = self,
    __package = assert(header_type, "type package not found"),
    __session = {},
  }, host_mt)
end
--- Fetch the type object for `typename`, consulting `self.__tcache` first.
-- On a cache miss the value comes from core.querytype and is stored in the
-- cache; a falsy lookup result raises "type not found".
-- @param self schema object (needs `__cobj` and `__tcache`)
-- @param typename key passed to core.querytype
-- @return the cached or freshly looked-up type object
local function querytype(self, typename)
  local cache = self.__tcache
  local st = cache[typename]
  if st then
    return st
  end
  st = assert(core.querytype(self.__cobj, typename), "type not found")
  cache[typename] = st
  return st
end
--- Tell whether `typename` is known, without raising.
-- A hit in `__tcache` answers immediately; otherwise the answer is whether
-- core.querytype returns something other than nil. The cache is not updated.
-- @param typename key passed to core.querytype
-- @return boolean
function sproto:exist_type(typename)
  if self.__tcache[typename] then
    return true
  end
  return core.querytype(self.__cobj, typename) ~= nil
end
--- Run core.encode on `tbl` with the type object registered under `typename`.
-- Raises "type not found" (from the type lookup) when the name is unknown.
-- @param typename type name
-- @param tbl value to encode
-- @return the results of core.encode
function sproto:encode(typename, tbl)
  return core.encode(querytype(self, typename), tbl)
end
--- Run core.decode with the type object registered under `typename`.
-- Raises "type not found" (from the type lookup) when the name is unknown.
-- @param typename type name
-- @param ... forwarded unchanged to core.decode after the type object
-- @return the results of core.decode
function sproto:decode(typename, ...)
  return core.decode(querytype(self, typename), ...)
end
--- Encode `tbl` as `typename`, then feed the result to core.pack.
-- Everything core.encode returns is passed on to core.pack.
-- @param typename type name
-- @param tbl value to encode
-- @return the results of core.pack
function sproto:pencode(typename, tbl)
  return core.pack(core.encode(querytype(self, typename), tbl))
end
--- Unpack the arguments with core.unpack, then decode them as `typename`.
-- Everything core.unpack returns is passed to core.decode after the type object.
-- @param typename type name
-- @param ... forwarded unchanged to core.unpack
-- @return the results of core.decode
function sproto:pdecode(typename, ...)
  -- Resolve the type first so an unknown name is reported before unpacking.
  local type_obj = querytype(self, typename)
  return core.decode(type_obj, core.unpack(...))
end
--- Look up a protocol and memoise the result in `self.__pcache`.
-- The record is stored under both its name and its tag, so a later lookup by
-- either key is served from the cache.
-- @param self schema object (needs `__cobj` and `__pcache`)
-- @param pname protocol identifier passed to core.protocol; when
--        `tonumber(pname)` succeeds it is treated as the tag
-- @return table with fields `request`, `response`, `name` and `tag`
-- Raises `pname .. " not found"` when core.protocol yields no first result.
local function queryproto(self, pname)
  local cache = self.__pcache
  local v = cache[pname]
  if v then
    return v
  end

  local tag, req, resp = core.protocol(self.__cobj, pname)
  if not tag then
    -- Build the message only on failure instead of on every cache miss.
    -- Level 0 adds no position prefix, so the text matches a plain assert.
    error(pname .. " not found", 0)
  end
  if tonumber(pname) then
    -- Queried by tag: the first result is then taken as the name, so swap.
    pname, tag = tag, pname
  end

  v = {
    request = req,
    response = resp,
    name = pname,
    tag = tag,
  }
  cache[pname] = v
  cache[tag] = v
  return v
end

-- Also exposed as a method/field of the module table.
sproto.queryproto = queryproto
--- Tell whether a protocol is known, without raising for unknown ones.
-- A hit in `__pcache` answers immediately; otherwise the answer is whether
-- the first result of core.protocol differs from nil. The cache is not updated.
-- @param pname protocol identifier passed to core.protocol
-- @return boolean
function sproto:exist_proto(pname)
  if self.__pcache[pname] then
    return true
  end
  return core.protocol(self.__cobj, pname) ~= nil
end
--- Encode the request part of a protocol.
-- @param protoname protocol identifier (resolved through queryproto, which raises if unknown)
-- @param tbl request arguments
-- @return first result of core.encode, or "" when the protocol has no request type
-- @return the protocol's tag
function sproto:request_encode(protoname, tbl)
  local p = queryproto(self, protoname)
  local req_type = p.request
  if not req_type then
    return "", p.tag
  end
  return core.encode(req_type, tbl), p.tag
end
--- Encode the response part of a protocol.
-- @param protoname protocol identifier (resolved through queryproto, which raises if unknown)
-- @param tbl response values
-- @return the results of core.encode, or "" when the protocol has no response type
function sproto:response_encode(protoname, tbl)
  local resp_type = queryproto(self, protoname).response
  if not resp_type then
    return ""
  end
  return core.encode(resp_type, tbl)
end
--- Decode the request part of a protocol.
-- @param protoname protocol identifier (resolved through queryproto, which raises if unknown)
-- @param ... forwarded unchanged to core.decode after the request type
-- @return first result of core.decode, or nil when the protocol has no request type
-- @return the protocol's name
function sproto:request_decode(protoname, ...)
  local p = queryproto(self, protoname)
  local req_type = p.request
  if not req_type then
    return nil, p.name
  end
  return core.decode(req_type, ...), p.name
end
--- Decode the response part of a protocol.
-- @param protoname protocol identifier (resolved through queryproto, which raises if unknown)
-- @param ... forwarded unchanged to core.decode after the response type
-- @return the results of core.decode; nothing at all when the protocol has no response type
function sproto:response_decode(protoname, ...)
  local resp_type = queryproto(self, protoname).response
  if not resp_type then
    return
  end
  return core.decode(resp_type, ...)
end
-- Re-export the native pack/unpack functions on the module table.
sproto.pack, sproto.unpack = core.pack, core.unpack
--- Return core.default for a type, or for one side of a protocol.
-- With `type == nil`, `typename` names a type. Otherwise `typename` names a
-- protocol and `type` selects which of its sides to use.
-- @param typename type name, or protocol identifier when `type` is given
-- @param type nil, "REQUEST" or "RESPONSE"
-- @return the results of core.default; nothing when the selected side of the
--         protocol does not exist
-- Raises "Invalid type" for any other `type`. The protocol lookup runs first,
-- so an unknown protocol is reported before an invalid `type`.
function sproto:default(typename, type)
  if type == nil then
    return core.default(querytype(self, typename))
  end

  local p = queryproto(self, typename)
  local side
  if type == "REQUEST" then
    side = p.request
  elseif type == "RESPONSE" then
    side = p.response
  else
    error "Invalid type"
  end

  if side then
    return core.default(side)
  end
end
-- Scratch table for header fields, shared by the host functions in this file.
local header_tmp = {}

--- Build the reply function for one incoming request.
-- The returned closure encodes a header carrying `session` and the caller's
-- `ud` with `type` cleared, appends the encoded `args` when `response` is set,
-- and returns the result of core.pack.
-- @param self host object (its `__package` is the header type)
-- @param response response type object, or a falsy value when there is none
-- @param session session value to put in the header
-- @return function(args, ud)
local function gen_response(self, response, session)
  return function(args, ud)
    header_tmp.type, header_tmp.session, header_tmp.ud = nil, session, ud
    local packet = core.encode(self.__package, header_tmp)
    if response then
      packet = packet .. core.encode(response, args)
    end
    return core.pack(packet)
  end
end
--- Unpack one incoming message and classify it as a request or a response.
-- The arguments go to core.unpack; the header is decoded with the host's
-- `__package` type and the remainder of the buffer is treated as the body.
-- A header with a `type` field is a request, anything else a response.
-- @param ... forwarded unchanged to core.unpack
-- @return for requests: "REQUEST", protocol name, decoded body (nil when the
--         protocol has no request type), reply function (nil when the header
--         has no session), header.ud
-- @return for responses: "RESPONSE", session, decoded body (nil when the
--         pending entry is `true`), header.ud
-- Raises "session not found" / "Unknown session" for a response that has no
-- session or no matching pending entry.
function host:dispatch(...)
  local bin = core.unpack(...)

  -- Wipe leftovers from earlier calls. header_tmp is then passed to
  -- core.decode, and the session is read back from it afterwards.
  header_tmp.type, header_tmp.session, header_tmp.ud = nil, nil, nil
  local header, size = core.decode(self.__package, bin, header_tmp)
  local content = bin:sub(size + 1)

  local tag = header.type
  if not tag then
    -- response
    local session = assert(header_tmp.session, "session not found")
    local pending = assert(self.__session[session], "Unknown session")
    -- The pending entry is consumed before the body is decoded.
    self.__session[session] = nil
    if pending == true then
      return "RESPONSE", session, nil, header.ud
    end
    local result = core.decode(pending, content)
    return "RESPONSE", session, result, header.ud
  end

  -- request
  local proto = queryproto(self.__proto, tag)
  local result
  if proto.request then
    result = core.decode(proto.request, content)
  end
  local reply
  if header_tmp.session then
    reply = gen_response(self, proto.response, header_tmp.session)
  end
  return "REQUEST", proto.name, result, reply, header.ud
end
--- Build the request-sending function for a remote schema.
-- The returned closure looks up protocol `name` in `sp`, encodes a header
-- (tag, session, ud) with this host's `__package` type, appends the encoded
-- `args` when the protocol has a request type, and returns the result of
-- core.pack. When `session` is given, the protocol's response type (or `true`
-- if it has none) is stored in `self.__session[session]` for host:dispatch.
-- @param sp schema object used to resolve protocol names
-- @return function(name, args, session, ud)
function host:attach(sp)
  return function(name, args, session, ud)
    local proto = queryproto(sp, name)
    header_tmp.type = proto.tag
    header_tmp.session = session
    header_tmp.ud = ud
    local packet = core.encode(self.__package, header_tmp)
    if proto.request then
      packet = packet .. core.encode(proto.request, args)
    end

    -- Record the pending session only after both encodes succeeded, so an
    -- error raised while encoding `args` does not leave a stale entry behind
    -- for a packet that was never produced.
    if session then
      self.__session[session] = proto.response or true
    end
    return core.pack(packet)
  end
end
-- Module result: the method table that sproto_mt and sproto_nogc use as __index.
return sproto