123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527 |
- local lpeg = require "lpeg"
- local table = require "table"
-- Wire-format primitives used by the encoder below.
--   packbytes(str): 4-byte little-endian length prefix followed by `str`.
--   packvalue(id):  16-bit little-endian encoding of (id + 1) * 2.
-- Lua 5.3+ provides string.pack; older interpreters use manual byte math.
local packbytes
local packvalue
local version = _VERSION:match "5.*"
local has_string_pack = version and tonumber(version) >= 5.3

if has_string_pack then
	packbytes = function(str)
		return string.pack("<s4", str)
	end
	packvalue = function(id)
		local encoded = (id + 1) * 2
		return string.pack("<I2", encoded)
	end
else
	packbytes = function(str)
		local remaining = #str
		local digits = {}
		for i = 1, 3 do
			digits[i] = remaining % 256
			remaining = math.floor(remaining / 256)
		end
		digits[4] = remaining
		return string.char(digits[1], digits[2], digits[3], digits[4]) .. str
	end
	packvalue = function(id)
		local encoded = (id + 1) * 2
		assert(encoded >= 0 and encoded < 65536)
		return string.char(encoded % 256, math.floor(encoded / 256))
	end
end
-- Short local aliases for the LPeg constructors used by the grammar below.
local P, S, R, C = lpeg.P, lpeg.S, lpeg.R, lpeg.C
local Ct, Cg, Cc, V = lpeg.Ct, lpeg.Cg, lpeg.Cc, lpeg.V
-- Match-time callback for the newline pattern. It increments
-- parser_state.line the first time a newline ending at `pos` is reported.
-- Remembering the last counted position keeps re-matches of the same newline
-- from counting twice. Re-matches happen e.g. when `1 - newline` probes it, or
-- when the parser backtracks.
-- Always returns `pos` so the match continues from there.
local function count_lines(_, pos, parser_state)
	if pos <= parser_state.pos then
		return pos
	end
	parser_state.line = parser_state.line + 1
	parser_state.pos = pos
	return pos
end
-- Alternative tried when the full grammar does not match the input. It never
-- succeeds: it raises a syntax error naming the file and the line counter
-- value that count_lines had reached.
local exception = lpeg.Cmt(lpeg.Carg(1), function(_, _, parser_state)
	local where = parser_state.file or ""
	error(string.format("syntax error at [%s] line (%d)", where, parser_state.line))
end)
-- Lexical building blocks of the schema language.
local eof = P(-1)
-- A newline is "\n" or "\r\n". Each match also hands its end position to
-- count_lines, together with the parser state passed as lpeg.match's first
-- extra argument.
local newline = lpeg.Cmt((P"\n" + P"\r\n") * lpeg.Carg(1), count_lines)
-- '#' starts a comment that runs to the end of the line or of the input.
local line_comment = P"#" * (1 - newline)^0 * (newline + eof)
-- Insignificant input: spaces, tabs, newlines and comments.
local blank = S" \t" + newline + line_comment
local blank0 = blank^0
local blanks = blank^1
-- Identifiers: a letter or underscore, then letters, digits or underscores.
local alpha = R("az", "AZ") + P"_"
local alnum = alpha + R"09"
local word = alpha * alnum^0
local name = C(word)
-- Dotted type references such as Foo or Foo.Bar.
local typename = C(word * (P"." * word)^0)
-- Unsigned decimal literal, captured as a number.
local tag = R"09"^1 / tonumber
-- "(key)" annotation; the captured word may be empty, i.e. "()".
local mainkey = P"(" * blank0 * C(word^0) * blank0 * P")"
-- "(digits)" annotation. C(tag) yields two captures: the digit string, then
-- the numeric value produced by `tag`.
local decimal = P"(" * blank0 * C(tag) * blank0 * P")"
-- Zero or more `pat` items separated by blanks, with optional blanks around
-- the whole run. All captures are collected into a single table.
local function multipat(pat)
	local items = (pat * blanks)^0 * pat^0
	return Ct(blank0 * items * blank0)
end
-- Build a table capture tagged with t.type = name. The captures of `pat` are
-- stored positionally at t[1], t[2], ...
local function namedpat(name, pat)
	local kind = Cg(Cc(name), "type")
	return Ct(kind * Cg(pat))
end
-- One field declaration: `name tag : [*]type [(key)|(digits)]...`
local field_decl = name * blanks * tag * blank0 * ":" * blank0
	* (C"*")^-1 * typename * (mainkey + decimal)^0

-- Schema grammar. The top level is a sequence of `.Type { ... }` definitions
-- and `name tag { request ... response ... }` protocol definitions.
local typedef = P {
	"ALL",
	ALL = multipat(V"TYPE" + V"PROTOCOL"),
	TYPE = namedpat("type", P"." * name * blank0 * V"STRUCT"),
	STRUCT = P"{" * multipat(V"FIELD" + V"TYPE") * P"}",
	FIELD = namedpat("field", field_decl),
	PROTOCOL = namedpat("protocol", name * blanks * tag * blank0 * P"{" * multipat(V"SUBPROTO") * P"}"),
	SUBPROTO = Ct((C"request" + C"response") * blanks * (typename + V"STRUCT")),
}

-- Whole document, allowing leading and trailing blanks/comments.
local proto = blank0 * typedef * blank0
-- Converters from raw LPeg parse nodes into the intermediate schema form
-- { type = {...}, protocol = {...} } assembled by `adjust`.
local convert = {}

-- Convert a protocol node { name, tag, { subproto, ... } }.
-- Each subproto is { "request"|"response", typename-or-inline-struct }.
-- An inline struct is registered in all.type as "<protocol>.<request|response>".
-- `response nil` sets result.confirm instead of naming a response type.
-- Raises on a repeated request/response entry, or when an inline struct
-- would overwrite an already-registered type.
function convert.protocol(all, obj)
	local result = { tag = obj[2] }
	-- Tracked separately from `result`: `response nil` stores nothing in
	-- result[pt], so checking result[pt] alone would miss that redefinition.
	local seen = {}
	for _, p in ipairs(obj[3]) do
		local pt = p[1]
		if seen[pt] then
			error(string.format("redefine %s in protocol %s", pt, obj[1]))
		end
		seen[pt] = true
		local typename = p[2]
		if type(typename) == "table" then
			local struct = typename
			typename = obj[1] .. "." .. pt
			-- Same guard convert.type applies to nested types.
			assert(all.type[typename] == nil, "redefined " .. typename)
			all.type[typename] = convert.type(all, { typename, struct })
		end
		if typename == "nil" then
			if pt == "response" then
				result.confirm = true
			end
		else
			result[pt] = typename
		end
	end
	return result
end

-- Builtin types intended as valid map keys.
-- NOTE(review): not referenced anywhere in the visible source; packtype
-- currently accepts any builtin key field.
local map_keytypes = {
	integer = true,
	string = true,
}

-- Convert a type node { typename, { field-or-nested-type, ... } } into a list
-- of field descriptors sorted by tag. Nested `.Name { ... }` types are
-- registered in all.type under "<typename>.<Name>" as a side effect.
--
-- A field node is { type = "field", name, tag, ["*"], fieldtype, annotation... }.
-- The optional annotation comes from a "(word)" or "(digits)" suffix:
--   * on an integer field it is the decimal precision (field.decimal, a number);
--   * otherwise it is the key field name (field.key; "" means map), and it is
--     only allowed on arrays.
-- Raises on duplicate field names or tags, a non-numeric decimal, a key on a
-- non-array field, or a redefined nested type.
function convert.type(all, obj)
	local result = {}
	local typename = obj[1]
	local tags = {}
	local names = {}
	for _, f in ipairs(obj[2]) do
		if f.type == "field" then
			local name = f[1]
			if names[name] then
				error(string.format("redefine %s in type %s", name, typename))
			end
			names[name] = true
			local tag = f[2]
			if tags[tag] then
				error(string.format("redefine tag %d in type %s", tag, typename))
			end
			tags[tag] = true
			local field = { name = name, tag = tag }
			table.insert(result, field)
			-- Capture positions shift by one when the "*" array marker is present.
			local fieldtype = f[3]
			local annotation_index = 4
			if fieldtype == "*" then
				field.array = true
				fieldtype = f[4]
				annotation_index = 5
			end
			local mainkey = f[annotation_index]
			if mainkey then
				if fieldtype == "integer" then
					local precision = tonumber(mainkey)
					if precision == nil then
						error(string.format("Invalid decimal %s for field %s in type %s", mainkey, name, typename))
					end
					field.decimal = precision
				else
					if not field.array then
						error(string.format("Invalid key %s for field %s in type %s: only arrays can have a key", mainkey, name, typename))
					end
					field.key = mainkey
				end
			end
			field.typename = fieldtype
		else
			assert(f.type == "type") -- nested type definition
			local nesttypename = typename .. "." .. f[1]
			f[1] = nesttypename
			assert(all.type[nesttypename] == nil, "redefined " .. nesttypename)
			all.type[nesttypename] = convert.type(all, f)
		end
	end
	table.sort(result, function(a, b) return a.tag < b.tag end)
	return result
end
-- Group the top-level parse nodes into { type = {}, protocol = {} }.
-- Each node is converted by the converter named after its `type` tag.
-- A name defined twice within the same kind fails an assertion.
local function adjust(r)
	local result = { type = {}, protocol = {} }
	for i = 1, #r do
		local node = r[i]
		local kind = node.type
		local bucket = result[kind]
		local key = node[1]
		assert(bucket[key] == nil, "redefined " .. key)
		bucket[key] = convert[kind](result, node)
	end
	return result
end
-- Builtin scalar types and their wire ids. `binary` shares string's id; the
-- encoder distinguishes it with an extra flag.
local buildin_types = {
	integer = 0,
	boolean = 1,
	string = 2,
	binary = 2, -- same wire id as string
	double = 3,
}

-- Resolve type name `t` as referenced from inside type `ptype`.
-- Builtins resolve to themselves. Otherwise "<scope>.t" is tried for
-- ptype and each enclosing scope, innermost first, and finally the bare
-- name `t`. Returns nothing when the name cannot be resolved.
local function checktype(types, ptype, t)
	if buildin_types[t] then
		return t
	end
	local scope = ptype
	while scope do
		local candidate = scope .. "." .. t
		if types[candidate] then
			return candidate
		end
		scope = scope:match "(.+)%..+$"
	end
	if types[t] then
		return t
	end
end
-- Validate the protocol table. Tags must be unique, and every request or
-- response type a protocol names must exist in r.type. Returns `r` unchanged.
local function check_protocol(r)
	local types = r.type
	local by_tag = {}
	for pname, pdef in pairs(r.protocol) do
		local tag = pdef.tag
		if by_tag[tag] then
			error(string.format("redefined protocol tag %d at %s", tag, pname))
		end
		local req, resp = pdef.request, pdef.response
		if req and not types[req] then
			error(string.format("Undefined request type %s in protocol %s", req, pname))
		end
		if resp and not types[resp] then
			error(string.format("Undefined response type %s in protocol %s", resp, pname))
		end
		by_tag[tag] = pdef
	end
	return r
end
-- Rewrite each field's typename in place to the fully qualified name that
-- checktype resolves relative to the owning type. Raises on an unknown type.
-- Returns `r`.
local function flattypename(r)
	local types = r.type
	for owner, fields in pairs(types) do
		for _, field in pairs(fields) do
			local resolved = checktype(types, owner, field.typename)
			if resolved == nil then
				error(string.format("Undefined type %s in type %s", field.typename, owner))
			end
			field.typename = resolved
		end
	end
	return r
end
-- Parse schema `text` into the validated intermediate form
-- { type = ..., protocol = ... }. `filename` appears in syntax-error messages.
-- The state starts at line 1 with no newline counted yet.
local function parser(text, filename)
	local state = { file = filename, pos = 0, line = 1 }
	local grammar = proto * -1 + exception
	local tree = lpeg.match(grammar, text, 1, state)
	local grouped = adjust(tree)
	return flattypename(check_protocol(grouped))
end
- --[[
- -- The protocol of sproto
- .type {
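	.field {
		name 0 : string
		buildin 1 : integer
		type 2 : integer # index of the field's type in group.type; for buildin fields it carries extra info instead (decimal precision for integer, 1 for binary string)
		tag 3 : integer
		array 4 : boolean
		key 5 : integer # tag of the key field in the element type; if key exists, array must be true
		map 6 : boolean # interpret a two-field struct as a map when decoding
	}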
- name 0 : string
- fields 1 : *field
- }
	.protocol {
		name 0 : string
		tag 1 : integer
		request 2 : integer # index of the request type in group.type
		response 3 : integer # index of the response type in group.type
		confirm 4 : boolean # true means the response is explicitly nil (no response type)
	}
- .group {
- type 0 : *type
- protocol 1 : *protocol
- }
- ]]
-- Encode one field descriptor as a `.field` object.
-- Layout: a field count, one 2-byte slot per tag (values encoded by packvalue,
-- "\1\0" skips one tag, "\0\0" means "stored in the data part"), then the name
-- as a data part, all wrapped in a length prefix.
local function packfield(f)
	local count
	if not f.array then
		count = "\4\0"
	elseif not f.key then
		count = "\5\0"
	elseif f.map then
		count = "\7\0"
	else
		count = "\6\0"
	end
	local out = { count, "\0\0" } -- name (tag 0) is stored in the data part
	if f.buildin then
		out[#out + 1] = packvalue(f.buildin) -- buildin (tag 1)
		if f.extra then
			out[#out + 1] = packvalue(f.extra) -- extra info in the type slot (tag 2)
		else
			out[#out + 1] = "\1\0" -- skip tag 2
		end
	else
		out[#out + 1] = "\1\0" -- skip tag 1
		out[#out + 1] = packvalue(f.type) -- type index (tag 2)
	end
	out[#out + 1] = packvalue(f.tag) -- tag (tag 3)
	if f.array then
		out[#out + 1] = packvalue(1) -- array = true (tag 4)
		if f.key then
			out[#out + 1] = packvalue(f.key) -- key field tag (tag 5)
			if f.map then
				out[#out + 1] = packvalue(f.map) -- map (tag 6)
			end
		end
	end
	out[#out + 1] = packbytes(f.name)
	return packbytes(table.concat(out))
end
-- Encode a user type as a `.type` object.
-- `t` is the tag-sorted field list produced by convert.type.
-- `alltypes[name]` is { id = <index>, fields = { [fieldname] = { tag, buildin } } }.
-- Raises on an unknown field type, an invalid map key, or a `*T()` map whose
-- element type does not have exactly two fields.
local function packtype(name, t, alltypes)
	local fields = {}
	for _, f in ipairs(t) do
		-- Fresh descriptor per field; the input field `f` is never modified.
		local desc = {
			array = f.array,
			name = f.name,
			tag = f.tag,
			extra = f.decimal,
			buildin = buildin_types[f.typename],
		}
		if f.typename == "binary" then
			desc.extra = 1 -- binary is a sub type of string
		end
		local subtype
		if not desc.buildin then
			subtype = assert(alltypes[f.typename])
			desc.type = subtype.id
		end
		if f.key then
			assert(f.array)
			-- A key needs a struct element type to look the key field up in.
			if subtype == nil then
				error("Invalid map index :" .. f.key)
			end
			local key = f.key
			if key == "" then
				-- `*T()` map syntax: T must have exactly two fields; the one
				-- with the smaller tag becomes the key.
				desc.map = 1
				local count = 0
				-- math.huge, not math.maxinteger: the latter is nil before Lua 5.3,
				-- which this file explicitly supports.
				local min_tag = math.huge
				for fname, finfo in pairs(subtype.fields) do
					count = count + 1
					if finfo.tag < min_tag then
						min_tag = finfo.tag
						key = fname
					end
				end
				if count ~= 2 then
					error(string.format("Invalid map definition: %s, must only have two fields", desc.name))
				end
			end
			local stfield = subtype.fields[key]
			if not stfield or not stfield.buildin then
				error("Invalid map index :" .. key)
			end
			desc.key = stfield.tag
		end
		table.insert(fields, packfield(desc))
	end
	local data
	if #fields == 0 then
		data = {
			"\1\0", -- 1 field
			"\0\0", -- name (tag 0, in the data part)
			packbytes(name),
		}
	else
		data = {
			"\2\0", -- 2 fields
			"\0\0", -- name (tag 0, in the data part)
			"\0\0", -- fields[] (tag 1, in the data part)
			packbytes(name),
			packbytes(table.concat(fields)),
		}
	end
	return packbytes(table.concat(data))
end
-- Encode a protocol as a `.protocol` object.
-- Contents: name, tag, and optionally the request/response type indexes, or
-- the confirm flag when the response is explicitly nil.
-- Raises if a referenced request or response type is missing from `alltypes`.
local function packproto(name, p, alltypes)
	-- Look up a type index, with a descriptive error for request and response alike.
	local function typeid(kind, typename)
		local info = alltypes[typename]
		if info == nil then
			error(string.format("Protocol %s %s type %s not found", name, kind, typename))
		end
		return info.id
	end
	local request_id = p.request and typeid("request", p.request)
	local response_id = p.response and typeid("response", p.response)

	local tmp = {
		"\4\0", -- field count, adjusted below
		"\0\0", -- name (tag 0, in the data part)
		packvalue(p.tag), -- tag (tag 1)
	}
	if p.request == nil and p.response == nil and p.confirm == nil then
		tmp[1] = "\2\0" -- only name and tag
	else
		if request_id then
			table.insert(tmp, packvalue(request_id)) -- request type index (tag 2)
		else
			table.insert(tmp, "\1\0") -- skip tag 2 (no request)
		end
		if response_id then
			table.insert(tmp, packvalue(response_id)) -- response type index (tag 3)
		elseif p.confirm then
			tmp[1] = "\5\0" -- add the confirm field
			table.insert(tmp, "\1\0") -- skip tag 3 (no response)
			table.insert(tmp, packvalue(1)) -- confirm = true (tag 4)
		else
			tmp[1] = "\3\0" -- name, tag, request slot
		end
	end
	table.insert(tmp, packbytes(name))
	return packbytes(table.concat(tmp))
end
-- Encode the whole schema as a `.group` object.
-- Contents: the type list sorted by name, then, if any protocols exist, the
-- protocol list sorted by tag. A type's id is its 0-based position in the
-- sorted type list. An empty schema encodes as "\0\0".
local function packgroup(t, p)
	if next(t) == nil then
		assert(next(p) == nil)
		return "\0\0"
	end

	-- `alltypes` is both the sorted array of names and a name -> info map.
	local alltypes = {}
	for typename in pairs(t) do
		alltypes[#alltypes + 1] = typename
	end
	table.sort(alltypes) -- make result stable
	local ntypes = #alltypes
	for i = 1, ntypes do
		local typename = alltypes[i]
		local lookup = {}
		for _, fdef in ipairs(t[typename]) do
			lookup[fdef.name] = {
				tag = fdef.tag,
				buildin = buildin_types[fdef.typename],
			}
		end
		alltypes[typename] = { id = i - 1, fields = lookup }
	end

	local packed_types = {}
	for i = 1, ntypes do
		local typename = alltypes[i]
		packed_types[i] = packtype(typename, t[typename], alltypes)
	end
	local tt = packbytes(table.concat(packed_types))

	if next(p) == nil then
		return table.concat({
			"\1\0", -- 1 field
			"\0\0", -- type[] (tag 0, in the data part)
			tt,
		})
	end

	local protos = {}
	for pname, pdef in pairs(p) do
		protos[#protos + 1] = pdef
		pdef.name = pname
	end
	table.sort(protos, function(a, b) return a.tag < b.tag end)
	local packed_protos = {}
	for i, pdef in ipairs(protos) do
		packed_protos[i] = packproto(pdef.name, pdef, alltypes)
	end
	local tp = packbytes(table.concat(packed_protos))

	return table.concat({
		"\2\0", -- 2 fields
		"\0\0", -- type[] (tag 0, in the data part)
		"\0\0", -- protocol[] (tag 1, in the data part)
		tt,
		tp,
	})
end
-- Serialize a validated schema (types and protocols) to its binary form.
local function encodeall(r)
	local types, protocols = r.type, r.protocol
	return packgroup(types, protocols)
end
-- Public module table.
local sparser = {}

-- Print `str` as a hex dump: 16 bytes per line, with "- " after every 8th
-- byte. Debugging aid.
-- A trailing partial line is printed only if it has content, so inputs whose
-- length is a multiple of 16 (or empty input) produce no spurious blank line.
function sparser.dump(str)
	local line = {}
	for i = 1, #str do
		line[#line + 1] = string.format("%02X ", string.byte(str, i))
		if i % 16 == 0 then
			print(table.concat(line))
			line = {}
		elseif i % 8 == 0 then
			line[#line + 1] = "- "
		end
	end
	if #line > 0 then
		print(table.concat(line))
	end
end

-- Parse sproto schema `text` and return its binary encoding.
-- `name` labels syntax-error messages; it defaults to "=text".
function sparser.parse(text, name)
	local r = parser(text, name or "=text")
	local data = encodeall(r)
	return data
end
- return sparser
|