-- sproto.lua (5.6 KB)
  -- Low-level sproto module providing the schema/encode/decode primitives.
  local core = require("sproto.core")
  -- Keep the global in an upvalue for cheaper access.
  local assert = assert

  -- Method tables: `sproto` for schema objects, `host` for RPC hosts.
  local sproto, host = {}, {}

  -- Metatable shared by the per-object lookup caches: weak keys and values.
  local weak_mt = { __mode = "kv" }

  -- sproto_mt is used by sproto.new and gets a __gc finalizer later in this
  -- file; sproto_nogc (used by sproto.sharenew) has no finalizer.
  local sproto_mt = { __index = sproto }
  local sproto_nogc = { __index = sproto }
  local host_mt = { __index = host }
  -- Finalizer for objects created by sproto.new: releases the core proto
  -- handle stored in __cobj.
  sproto_mt.__gc = function(self)
    core.deleteproto(self.__cobj)
  end
  --- Create a sproto object from a compiled schema binary.
  -- The object uses sproto_mt, so core.deleteproto is called on its handle
  -- when it is collected.
  -- @param bin schema binary passed to core.newproto (sproto.parse feeds it
  --   the output of sprotoparser)
  -- @return sproto object
  -- @raise if core.newproto returns a false/nil value
  function sproto.new(bin)
    local cobj = assert(core.newproto(bin))
    return setmetatable({
      __cobj = cobj,
      __tcache = setmetatable({}, weak_mt),
      __pcache = setmetatable({}, weak_mt),
    }, sproto_mt)
  end
  --- Wrap an existing core proto handle without taking ownership of it.
  -- The object uses sproto_nogc, so collecting it never calls
  -- core.deleteproto (presumably the handle is owned elsewhere).
  -- @param cobj core proto handle
  -- @return sproto object
  function sproto.sharenew(cobj)
    local obj = {}
    obj.__cobj = cobj
    obj.__tcache = setmetatable({}, weak_mt)
    obj.__pcache = setmetatable({}, weak_mt)
    return setmetatable(obj, sproto_nogc)
  end
  --- Compile schema text and build an owning sproto object from it.
  -- The parser module is only required when this function is called.
  -- @param ptext schema source text
  -- @return sproto object (see sproto.new)
  function sproto.parse(ptext)
    local bin = require("sprotoparser").parse(ptext)
    return sproto.new(bin)
  end
  --- Create an RPC host bound to this schema.
  -- @param packagename name of the header type wrapped around every message;
  --   defaults to "package"
  -- @return host object
  -- @raise "type <packagename> not found" if the schema lacks that type
  function sproto:host(packagename)
    packagename = packagename or "package"
    local package_type = core.querytype(self.__cobj, packagename)
    -- Report the name actually looked up (the message used to always say
    -- "package", which was wrong for a custom packagename).
    assert(package_type, "type " .. tostring(packagename) .. " not found")
    return setmetatable({
      __proto = self,
      __package = package_type,
      __session = {},
    }, host_mt)
  end
  -- Resolve a type name to its core type handle, memoized in self.__tcache.
  -- Raises "type <typename> not found" when the schema has no such type,
  -- mirroring queryproto's "<pname> not found" message.
  local function querytype(self, typename)
    local v = self.__tcache[typename]
    if not v then
      v = core.querytype(self.__cobj, typename)
      assert(v, "type " .. tostring(typename) .. " not found")
      self.__tcache[typename] = v
    end
    return v
  end
  --- Check whether the schema defines `typename`, without raising.
  -- A cache hit answers immediately; otherwise core.querytype is asked.
  -- @return boolean
  function sproto:exist_type(typename)
    if self.__tcache[typename] then
      return true
    end
    return core.querytype(self.__cobj, typename) ~= nil
  end
  --- Encode `tbl` as type `typename`; returns whatever core.encode returns.
  function sproto:encode(typename, tbl)
    return core.encode(querytype(self, typename), tbl)
  end
  --- Decode data as type `typename`; extra arguments go to core.decode.
  function sproto:decode(typename, ...)
    return core.decode(querytype(self, typename), ...)
  end
  --- Encode `tbl` as type `typename`, then pass every encode result to core.pack.
  function sproto:pencode(typename, tbl)
    return core.pack(core.encode(querytype(self, typename), tbl))
  end
  --- Unpack the arguments with core.unpack, then decode them as `typename`.
  function sproto:pdecode(typename, ...)
    local schema = querytype(self, typename)
    return core.decode(schema, core.unpack(...))
  end
  -- Look up a protocol by name or by numeric tag.
  -- Returns { request = ..., response = ..., name = ..., tag = ... }, memoized
  -- in self.__pcache under both the name and the tag.
  -- Raises "<pname> not found" for an unknown protocol.
  local function queryproto(self, pname)
    local cached = self.__pcache[pname]
    if cached then
      return cached
    end
    local other, req, resp = core.protocol(self.__cobj, pname)
    assert(other, pname .. " not found")
    -- For a numeric pname, the first value from core.protocol is treated as
    -- the name and pname itself as the tag; otherwise the reverse.
    local name, tag = pname, other
    if tonumber(pname) then
      name, tag = other, pname
    end
    local entry = {
      request = req,
      response = resp,
      name = name,
      tag = tag,
    }
    self.__pcache[name] = entry
    self.__pcache[tag] = entry
    return entry
  end
  sproto.queryproto = queryproto
  --- Check whether a protocol (by name or tag) exists, without raising.
  -- @return boolean
  function sproto:exist_proto(pname)
    if self.__pcache[pname] then
      return true
    end
    return core.protocol(self.__cobj, pname) ~= nil
  end
  --- Encode the request body of protocol `protoname` (name or tag).
  -- @return encoded body ("" when the protocol has no request type), tag
  function sproto:request_encode(protoname, tbl)
    local p = queryproto(self, protoname)
    local body = ""
    if p.request then
      body = core.encode(p.request, tbl)
    end
    return body, p.tag
  end
  --- Encode the response body of protocol `protoname` (name or tag).
  -- @return result of core.encode, or "" when there is no response type
  function sproto:response_encode(protoname, tbl)
    local response = queryproto(self, protoname).response
    if not response then
      return ""
    end
    return core.encode(response, tbl)
  end
  --- Decode a request body of protocol `protoname` (name or tag).
  -- @return decoded request (nil when there is no request type), protocol name
  function sproto:request_decode(protoname, ...)
    local p = queryproto(self, protoname)
    if not p.request then
      return nil, p.name
    end
    return core.decode(p.request, ...), p.name
  end
  --- Decode a response body of protocol `protoname` (name or tag).
  -- @return results of core.decode, or nothing when there is no response type
  function sproto:response_decode(protoname, ...)
    local response = queryproto(self, protoname).response
    if not response then
      return
    end
    return core.decode(response, ...)
  end
  -- Re-export the core packing helpers as plain (non-method) functions.
  sproto.pack, sproto.unpack = core.pack, core.unpack
  --- Build a default-valued table for a type, or for one side of a protocol.
  -- @param typename type name, or (when `kind` is given) protocol name or tag
  -- @param kind nil for a plain type; "REQUEST" or "RESPONSE" to select that
  --   side of the protocol
  -- @return result of core.default, or nothing when the selected protocol side
  --   has no type
  -- @raise "Invalid type" (at the caller's level) for any other `kind`
  function sproto:default(typename, kind)
    -- The second parameter was formerly named `type`, which shadowed the
    -- global type() builtin; Lua parameters are positional, so callers are
    -- unaffected.
    if kind == nil then
      return core.default(querytype(self, typename))
    end
    -- Validate the selector before looking the protocol up, so a bad `kind`
    -- is reported as such rather than as an unrelated lookup failure.
    if kind ~= "REQUEST" and kind ~= "RESPONSE" then
      error("Invalid type", 2)
    end
    local p = queryproto(self, typename)
    local st
    if kind == "REQUEST" then
      st = p.request
    else
      st = p.response
    end
    if st then
      return core.default(st)
    end
  end
  -- Header scratch table reused across messages (not allocated per message).
  -- Every user sets or clears its type/session/ud fields before encoding or
  -- decoding with it.
  local header_tmp = {}

  -- Build the reply function handed back by host:dispatch for a request that
  -- carried a session. Calling reply(args, ud) encodes a header holding
  -- `session` and `ud`, appends `args` encoded as `response` (header only if
  -- the protocol has no response type), and returns the packed result.
  local function gen_response(self, response, session)
    return function(args, ud)
      header_tmp.type = nil
      header_tmp.session = session
      header_tmp.ud = ud
      local packet = core.encode(self.__package, header_tmp)
      if response then
        packet = packet .. core.encode(response, args)
      end
      return core.pack(packet)
    end
  end
  --- Decode one incoming packed message.
  -- @return for a request: "REQUEST", protocol name, decoded request (or nil),
  --   reply function (nil when no session was sent), header ud
  -- @return for a response: "RESPONSE", session, decoded response (or nil),
  --   header ud
  -- @raise "session not found" / "Unknown session" for a response without a
  --   session or with a session this host is not waiting on
  function host:dispatch(...)
    local bin = core.unpack(...)
    header_tmp.type = nil
    header_tmp.session = nil
    header_tmp.ud = nil
    local header, size = core.decode(self.__package, bin, header_tmp)
    local content = bin:sub(size + 1)

    if not header.type then
      -- No type in the header: a response to a request sent via attach.
      local session = assert(header_tmp.session, "session not found")
      local pending = assert(self.__session[session], "Unknown session")
      self.__session[session] = nil
      local result
      -- `true` is stored by attach when the protocol has no response type.
      if pending ~= true then
        result = core.decode(pending, content)
      end
      return "RESPONSE", session, result, header.ud
    end

    -- Typed header: an incoming request.
    local proto = queryproto(self.__proto, header.type)
    local result
    if proto.request then
      result = core.decode(proto.request, content)
    end
    local reply
    if header_tmp.session then
      reply = gen_response(self, proto.response, header_tmp.session)
    end
    return "REQUEST", proto.name, result, reply, header.ud
  end
  --- Build a request sender that encodes protocols from schema `sp` inside
  -- this host's header type.
  -- @param sp sproto object whose protocols will be sent
  -- @return function(name, args, session, ud) returning the packed request.
  --   When `session` is given, the host records it so dispatch can match the
  --   response.
  function host:attach(sp)
    return function(name, args, session, ud)
      local proto = queryproto(sp, name)
      header_tmp.type = proto.tag
      header_tmp.session = session
      header_tmp.ud = ud
      local packet = core.encode(self.__package, header_tmp)
      if proto.request then
        packet = packet .. core.encode(proto.request, args)
      end
      -- Register the pending session only after encoding succeeded; doing it
      -- first left a stale __session entry whenever encoding raised (e.g. on
      -- bad args).
      if session then
        -- `true` means a reply is expected but has no body to decode.
        self.__session[session] = proto.response or true
      end
      return core.pack(packet)
    end
  end
  -- Module export: the sproto method table (hosts come from sproto:host()).
  return sproto