sprotoparser.lua 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. local lpeg = require "lpeg"
  2. local table = require "table"
  -- Little-endian encoders shared by the packers below.
  --   packbytes(str): 4-byte length prefix followed by the raw bytes of str.
  --   packvalue(v):   a 2-byte sproto field value, stored as (v + 1) * 2.
  -- string.pack (Lua 5.3+) is selected by feature detection instead of by
  -- parsing _VERSION as a number, which is fragile (a two-digit minor version
  -- would compare wrong, and a backported string.pack would be ignored).
  local packbytes
  local packvalue
  if string.pack then
    function packbytes(str)
      return string.pack("<s4", str)
    end
    function packvalue(id)
      id = (id + 1) * 2
      return string.pack("<I2", id)
    end
  else
    function packbytes(str)
      local size = #str
      local a = size % 256
      size = math.floor(size / 256)
      local b = size % 256
      size = math.floor(size / 256)
      local c = size % 256
      size = math.floor(size / 256)
      local d = size -- string.char raises if the length does not fit in 4 bytes
      return string.char(a, b, c, d) .. str
    end
    function packvalue(id)
      id = (id + 1) * 2
      assert(id >= 0 and id < 65536, "sproto value out of 16-bit range")
      local a = id % 256
      local b = math.floor(id / 256)
      return string.char(a, b)
    end
  end
  -- Short aliases for the LPeg constructors used by the grammar below.
  local P, S, R, C = lpeg.P, lpeg.S, lpeg.R, lpeg.C
  local Ct, Cg, Cc, V = lpeg.Ct, lpeg.Cg, lpeg.Cc, lpeg.V
  --- Match-time callback run on every newline.
  -- Increments parser_state.line only when `pos` is beyond the last recorded
  -- position, so re-matching the same newline (e.g. inside a predicate or
  -- after backtracking) is not counted twice. Always succeeds at `pos`.
  local function count_lines(_, pos, parser_state)
    if pos > parser_state.pos then
      parser_state.pos = pos
      parser_state.line = parser_state.line + 1
    end
    return pos
  end
  -- Fallback alternative tried when the schema does not parse: raises a
  -- syntax error naming the file and the last line reached by count_lines.
  local exception = lpeg.Cmt(lpeg.Carg(1), function(_, _pos, parser_state)
    local where = parser_state.file or ""
    error(string.format("syntax error at [%s] line (%d)", where, parser_state.line))
  end)
  -- Lexical building blocks of the schema language.
  local eof = P(-1)
  -- A newline; also advances the line counter in the state table passed to
  -- lpeg.match as extra argument 1 (see count_lines).
  local newline = lpeg.Cmt((P"\n" + P"\r\n") * lpeg.Carg(1), count_lines)
  local line_comment = P"#" * (1 - newline)^0 * (newline + eof)
  local blank = S" \t" + newline + line_comment
  local blank0, blanks = blank^0, blank^1
  local alpha = R("az", "AZ") + P"_"
  local alnum = alpha + R"09"
  local word = alpha * alnum^0
  local name = C(word)
  -- Possibly dotted type name, e.g. `foo.bar`.
  local typename = C(word * (P"." * word)^0)
  local tag = R"09"^1 / tonumber
  -- `(word)` or `()`: main key of an array; the empty form declares a map.
  local mainkey = P"(" * blank0 * C(word^0) * blank0 * P")"
  -- `(digits)`: C(tag) yields the digits as a string followed by tag's own
  -- numeric capture, i.e. two captures.
  local decimal = P"(" * blank0 * C(tag) * blank0 * P")"
  --- Match zero or more `pat`, separated by blanks and surrounded by optional
  -- blanks; all their captures are collected into a single table.
  local function multipat(pat)
    local separated = (pat * blanks)^0
    return Ct(blank0 * separated * pat^0 * blank0)
  end

  --- Wrap the captures of `pat` in a table whose `type` field is `name`.
  local function namedpat(name, pat)
    local type_field = Cg(Cc(name), "type")
    return Ct(type_field * Cg(pat))
  end
  -- Grammar for a whole schema. Captures are positional and are read by
  -- index in convert.type / convert.protocol.
  local typedef = P {
    "ALL",
    -- Top level: any mix of `.type { ... }` and `proto tag { ... }`.
    ALL = multipat(V"TYPE" + V"PROTOCOL"),
    -- `.name { ... }` -> { type = "type", name, { fields / nested types } }
    TYPE = namedpat("type", P"." * name * blank0 * V"STRUCT"),
    STRUCT = P"{" * multipat(V"FIELD" + V"TYPE") * P"}",
    -- `name tag : [*]typename[(key) | (decimal)]`
    --   -> { type = "field", name, tag, ["*"], typename, [annotation captures] }
    FIELD = namedpat("field",
      name * blanks * tag * blank0 * P":" * blank0
        * (C"*")^-1 * typename * (mainkey + decimal)^0),
    -- `name tag { request X response Y }` -> { type = "protocol", name, tag, {...} }
    PROTOCOL = namedpat("protocol",
      name * blanks * tag * blank0 * P"{" * multipat(V"SUBPROTO") * P"}"),
    -- `request|response` plus a type name or an inline struct -> { kind, type }
    SUBPROTO = Ct((C"request" + C"response") * blanks * (typename + V"STRUCT")),
  }
  -- Whole document: the schema with optional surrounding blanks.
  local proto = blank0 * typedef * blank0
  -- Converters from parse-tree nodes to the schema tables used by the
  -- checkers and packers; adjust() dispatches on node.type.
  local convert = {}

  --- Convert a protocol node { name, tag, { subprotocol... } }.
  -- Each subprotocol is { "request" | "response", typename-or-struct }.
  -- An inline struct is registered in all.type as "<protocol>.<kind>";
  -- `response nil` sets confirm = true instead of naming a response type.
  -- @return { tag = ..., request = ..., response = ..., confirm = ... }
  function convert.protocol(all, obj)
    local pname = obj[1]
    local result = { tag = obj[2] }
    -- Kinds already declared. Tracked separately because `response nil`
    -- leaves no `response` key in result.
    local seen = {}
    for _, p in ipairs(obj[3]) do
      local pt = p[1]
      if seen[pt] then
        error(string.format("redefine %s in protocol %s", pt, pname))
      end
      seen[pt] = true
      local typename = p[2]
      if type(typename) == "table" then
        local struct = typename
        typename = pname .. "." .. pt
        -- Never silently replace a user type with the same full name.
        assert(all.type[typename] == nil, "redefined " .. typename)
        all.type[typename] = convert.type(all, { typename, struct })
      end
      if typename == "nil" then
        if pt == "response" then
          result.confirm = true
        end
      else
        result[pt] = typename
      end
    end
    return result
  end

  -- Builtin types usable as map keys.
  -- NOTE(review): not referenced by any other code in this file.
  local map_keytypes = {
    integer = true,
    string = true,
  }

  --- Convert a type node { name, { field-or-nested-type... } } into an array
  -- of field records sorted by tag. Nested types are registered in all.type
  -- under "<name>.<nested>".
  -- Field node captures: name, tag, ["*"], typename, [annotation captures].
  function convert.type(all, obj)
    local result = {}
    local typename = obj[1]
    local tags = {}
    local names = {}
    for _, f in ipairs(obj[2]) do
      if f.type == "field" then
        local name = f[1]
        if names[name] then
          error(string.format("redefine %s in type %s", name, typename))
        end
        names[name] = true
        local tag = f[2]
        if tags[tag] then
          error(string.format("redefine tag %d in type %s", tag, typename))
        end
        tags[tag] = true
        local field = { name = name, tag = tag }
        table.insert(result, field)
        -- The optional "*" capture shifts the later captures by one, so the
        -- annotation is located relative to it, not at a fixed index.
        local fieldtype = f[3]
        local annotation = f[4]
        if fieldtype == "*" then
          field.array = true
          fieldtype = f[4]
          annotation = f[5]
        end
        if annotation then
          if fieldtype == "integer" then
            -- integer(N): decimal annotation, stored as a number
            local decimal = tonumber(annotation)
            if decimal == nil then
              error(string.format("Invalid decimal (%s) for %s in type %s", annotation, name, typename))
            end
            field.decimal = decimal
          else
            -- *T(key) or *T(): main key / map, only valid on arrays
            if not field.array then
              error(string.format("Key (%s) requires an array field: %s in type %s", annotation, name, typename))
            end
            field.key = annotation
          end
        end
        field.typename = fieldtype
      else
        assert(f.type == "type") -- nest type
        local nesttypename = typename .. "." .. f[1]
        f[1] = nesttypename
        assert(all.type[nesttypename] == nil, "redefined " .. nesttypename)
        all.type[nesttypename] = convert.type(all, f)
      end
    end
    table.sort(result, function(a, b) return a.tag < b.tag end)
    return result
  end
  --- Turn the raw parse tree into { type = {...}, protocol = {...} }, keyed by
  -- name. Each node is converted by convert[node.type]; a duplicated
  -- top-level name is rejected.
  local function adjust(r)
    local result = { type = {}, protocol = {} }
    for _, node in ipairs(r) do
      local kind = node.type
      local bucket = result[kind]
      local node_name = node[1]
      assert(bucket[node_name] == nil, "redefined " .. node_name)
      bucket[node_name] = convert[kind](result, node)
    end
    return result
  end
  -- Builtin type name -> builtin id written into a field's `buildin` slot.
  local buildin_types = {
    integer = 0,
    boolean = 1,
    string = 2,
    binary = 2, -- binary is a sub type of string
    double = 3,
  }

  --- Resolve type name `t` as referenced from inside type `ptype`.
  -- Builtin names resolve to themselves. Otherwise "<scope>.<t>" is tried for
  -- `ptype` and then each enclosing scope (innermost first), and finally the
  -- bare name. Returns the resolved name, or nothing when `t` is undefined.
  local function checktype(types, ptype, t)
    if buildin_types[t] then
      return t
    end
    local scope = ptype
    while scope do
      local candidate = scope .. "." .. t
      if types[candidate] then
        return candidate
      end
      scope = scope:match "(.+)%..+$" -- drop the last dotted component
    end
    if types[t] then
      return t
    end
  end
  --- Validate the protocols of a converted schema.
  -- Raises an error when two protocols share a tag, or when a protocol names
  -- a request/response type that is not defined. Returns `r` unchanged.
  local function check_protocol(r)
    -- Named `types`, not `type`, to avoid shadowing the builtin type().
    local types = r.type
    -- Visit protocols in name order so the reported conflict is reproducible
    -- (pairs order is unspecified).
    local names = {}
    for pname in pairs(r.protocol) do
      names[#names + 1] = pname
    end
    table.sort(names)
    local by_tag = {}
    for _, pname in ipairs(names) do
      local v = r.protocol[pname]
      local tag = v.tag
      local request = v.request
      local response = v.response
      if by_tag[tag] then
        error(string.format("redefined protocol tag %d at %s", tag, pname))
      end
      if request and not types[request] then
        error(string.format("Undefined request type %s in protocol %s", request, pname))
      end
      if response and not types[response] then
        error(string.format("Undefined response type %s in protocol %s", response, pname))
      end
      by_tag[tag] = v
    end
    return r
  end
  --- Replace every field's typename with its fully resolved name (see
  -- checktype). Raises on an unknown type. Mutates and returns `r`.
  local function flattypename(r)
    local types = r.type
    for owner, fields in pairs(types) do
      for _, field in pairs(fields) do
        local resolved = checktype(types, owner, field.typename)
        if resolved == nil then
          error(string.format("Undefined type %s in type %s", field.typename, owner))
        end
        field.typename = resolved
      end
    end
    return r
  end
  --- Parse schema `text` and return the checked, name-resolved schema
  -- { type = ..., protocol = ... }. `filename` labels syntax errors.
  local function parser(text, filename)
    -- Passed to the grammar as Carg(1); count_lines and exception use it.
    local parse_state = { file = filename, pos = 0, line = 1 }
    -- If the schema does not match, the `exception` alternative raises.
    local ast = lpeg.match(proto * -1 + exception, text, 1, parse_state)
    local schema = adjust(ast)
    schema = check_protocol(schema)
    return flattypename(schema)
  end
  --[[
  -- The protocol of sproto: layout of the bundle this module emits, written
  -- in sproto's own schema syntax.
  .type {
    .field {
      name 0 : string
      buildin 1 : integer
      type 2 : integer # type index; for builtin types, extra info (decimal, binary subtype)
      tag 3 : integer
      array 4 : boolean
      key 5 : integer # tag of the main-key field; if key exists, array must be true
      map 6 : boolean # If true, decode an array of two-field structs as a map
    }
    name 0 : string
    fields 1 : *field
  }
  .protocol {
    name 0 : string
    tag 1 : integer
    request 2 : integer # index of the request type
    response 3 : integer # index of the response type
    confirm 4 : boolean # true means the response is declared as nil
  }
  .group {
    type 0 : *type
    protocol 1 : *protocol
  }
  ]]
  --- Encode one field description as a sproto `.field` record.
  -- `f` is the scratch record built by packtype: name, tag, buildin (builtin
  -- id) or type (user type index), extra (builtin detail), array, key (tag of
  -- the main-key field), map. Returns the record wrapped by packbytes.
  local function packfield(f)
    -- Header = number of slots written: 4 base slots (name, buildin, type,
    -- tag), plus one each for array, key and map when present.
    local header
    if not f.array then
      header = "\4\0"
    elseif not f.key then
      header = "\5\0"
    elseif f.map then
      header = "\7\0"
    else
      header = "\6\0"
    end
    local out = { header, "\0\0" } -- name (tag = 0, ref an object)
    if f.buildin then
      out[#out + 1] = packvalue(f.buildin) -- buildin (tag = 1)
      if f.extra then
        out[#out + 1] = packvalue(f.extra) -- extra info in the type slot (tag = 2)
      else
        out[#out + 1] = "\1\0" -- skip (tag = 2)
      end
    else
      out[#out + 1] = "\1\0" -- skip (tag = 1)
      out[#out + 1] = packvalue(f.type) -- type (tag = 2)
    end
    out[#out + 1] = packvalue(f.tag) -- tag (tag = 3)
    if f.array then
      out[#out + 1] = packvalue(1) -- array = true (tag = 4)
      if f.key then
        out[#out + 1] = packvalue(f.key) -- key tag (tag = 5)
        if f.map then
          out[#out + 1] = packvalue(f.map) -- map (tag = 6)
        end
      end
    end
    out[#out + 1] = packbytes(f.name) -- external object (name)
    return packbytes(table.concat(out))
  end
  --- Encode one user-defined type as a sproto `.type` record.
  -- @param name      fully qualified type name
  -- @param t         array of converted fields (name, tag, typename, array, key, decimal)
  -- @param alltypes  name -> { id = type index, fields = { [field name] = { tag, buildin } } }
  -- @return the encoded record, wrapped by packbytes
  local function packtype(name, t, alltypes)
    -- math.maxinteger exists only on Lua 5.3+; this file also supports older
    -- Lua (string.pack fallback), so use math.huge there.
    local max_tag = math.maxinteger or math.huge
    local fields = {}
    local tmp = {} -- scratch record reused for every field; every key is reset below
    for _, f in ipairs(t) do
      tmp.array = f.array
      tmp.name = f.name
      tmp.tag = f.tag
      tmp.extra = f.decimal
      tmp.buildin = buildin_types[f.typename]
      if f.typename == "binary" then
        tmp.extra = 1 -- binary is sub type of string
      end
      local subtype
      if not tmp.buildin then
        subtype = assert(alltypes[f.typename])
        tmp.type = subtype.id
      else
        tmp.type = nil
      end
      tmp.map = nil
      tmp.key = nil
      -- Local copy: resolving a map key must not mutate the caller's field.
      local key = f.key
      if key then
        assert(f.array)
        if subtype == nil then
          -- A key names a field of a struct; builtin element types have none.
          error(string.format("Invalid map index :%s (%s is an array of builtin type %s)",
            key, tmp.name, f.typename))
        end
        if key == "" then
          -- `()` declares a map: the field with the smallest tag is the key.
          tmp.map = 1
          local count = 0
          local min_tag = max_tag
          for fname, sf in pairs(subtype.fields) do
            count = count + 1
            if sf.tag < min_tag then
              min_tag = sf.tag
              key = fname
            end
          end
          if count ~= 2 then
            error(string.format("Invalid map definition: %s, must only have two fields", tmp.name))
          end
        end
        local stfield = subtype.fields[key]
        if not stfield or not stfield.buildin then
          error("Invalid map index :" .. key)
        end
        tmp.key = stfield.tag
      end
      table.insert(fields, packfield(tmp))
    end
    local data
    if #fields == 0 then
      data = {
        "\1\0", -- 1 fields
        "\0\0", -- name (id = 0, ref = 0)
        packbytes(name),
      }
    else
      data = {
        "\2\0", -- 2 fields
        "\0\0", -- name (tag = 0, ref = 0)
        "\0\0", -- field[] (tag = 1, ref = 1)
        packbytes(name),
        packbytes(table.concat(fields)),
      }
    end
    return packbytes(table.concat(data))
  end
  --- Encode one protocol as a sproto `.protocol` record.
  -- @param name      protocol name
  -- @param p         { tag, request?, response?, confirm? } from convert.protocol
  -- @param alltypes  name -> { id = type index, ... }
  -- @return the encoded record, wrapped by packbytes
  local function packproto(name, p, alltypes)
    -- Look up a referenced type's index; fail with a readable message.
    local function typeid(kind, tname)
      local info = alltypes[tname]
      if info == nil then
        error(string.format("Protocol %s %s type %s not found", name, kind, tname))
      end
      return info.id
    end
    -- Resolve both references up front so a missing type is always reported clearly.
    local request_id = p.request and typeid("request", p.request)
    local response_id = p.response and typeid("response", p.response)
    local tmp = {
      "\4\0", -- 4 fields
      "\0\0", -- name (id=0, ref=0)
      packvalue(p.tag), -- tag (tag=1)
    }
    if p.request == nil and p.response == nil and p.confirm == nil then
      tmp[1] = "\2\0" -- only two fields
    else
      if request_id then
        table.insert(tmp, packvalue(request_id)) -- request type index (tag=2)
      else
        table.insert(tmp, "\1\0") -- skip this field (request)
      end
      if response_id then
        table.insert(tmp, packvalue(response_id)) -- response type index (tag=3)
      elseif p.confirm then
        tmp[1] = "\5\0" -- add confirm field
        table.insert(tmp, "\1\0") -- skip this field (response)
        table.insert(tmp, packvalue(1)) -- confirm = true (tag=4)
      else
        tmp[1] = "\3\0" -- only three fields
      end
    end
    table.insert(tmp, packbytes(name))
    return packbytes(table.concat(tmp))
  end
  --- Encode the whole schema as a sproto `.group` record.
  -- @param t  type name -> converted field list
  -- @param p  protocol name -> { tag, request?, response?, confirm? }
  --           (each entry gets a `name` field as a side effect)
  -- @return the encoded bundle (not length-prefixed)
  local function packgroup(t, p)
    if next(t) == nil then
      assert(next(p) == nil)
      return "\0\0"
    end
    -- alltypes holds the sorted type names in its array part and, once ids
    -- are assigned, name -> { id, fields } in its hash part.
    local alltypes = {}
    for tname in pairs(t) do
      alltypes[#alltypes + 1] = tname
    end
    table.sort(alltypes) -- make result stable
    for idx, tname in ipairs(alltypes) do
      local index = {}
      for _, fdef in ipairs(t[tname]) do
        index[fdef.name] = {
          tag = fdef.tag,
          buildin = buildin_types[fdef.typename],
        }
      end
      alltypes[tname] = { id = idx - 1, fields = index }
    end

    local packed_types = {}
    for _, tname in ipairs(alltypes) do
      packed_types[#packed_types + 1] = packtype(tname, t[tname], alltypes)
    end
    local type_blob = packbytes(table.concat(packed_types))

    if next(p) == nil then
      -- 1 field: type[] (id = 0, ref = 0)
      return table.concat({ "\1\0", "\0\0", type_blob })
    end

    local ordered = {}
    for pname, info in pairs(p) do
      ordered[#ordered + 1] = info
      info.name = pname
    end
    table.sort(ordered, function(a, b) return a.tag < b.tag end)
    local packed_protos = {}
    for _, info in ipairs(ordered) do
      packed_protos[#packed_protos + 1] = packproto(info.name, info, alltypes)
    end
    local proto_blob = packbytes(table.concat(packed_protos))
    -- 2 fields: type[] (id = 0, ref = 0), protocol[] (id = 1, ref = 1)
    return table.concat({ "\2\0", "\0\0", "\0\0", type_blob, proto_blob })
  end
  --- Encode a checked schema { type = ..., protocol = ... } as a sproto bundle.
  local function encodeall(r)
    local types, protocols = r.type, r.protocol
    return packgroup(types, protocols)
  end
  local sparser = {}

  --- Print `str` as a hex dump: 16 bytes per line, with "- " after the 8th
  -- byte of each line. A final (possibly empty) line is always printed.
  function sparser.dump(str)
    local line = {}
    for i = 1, #str do
      line[#line + 1] = string.format("%02X ", str:byte(i))
      if i % 16 == 0 then
        print(table.concat(line))
        line = {}
      elseif i % 8 == 0 then
        line[#line + 1] = "- "
      end
    end
    print(table.concat(line))
  end
  --- Parse sproto schema `text` and return its encoded binary bundle.
  -- `name` labels syntax errors (defaults to "=text").
  function sparser.parse(text, name)
    local schema = parser(text, name or "=text")
    return (encodeall(schema))
  end

  -- Module table exposing dump and parse.
  return sparser