lua-netpack.c 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. #define LUA_LIB
  2. #include "skynet_malloc.h"
  3. #include "skynet_socket.h"
  4. #include <lua.h>
  5. #include <lauxlib.h>
  6. #include <assert.h>
  7. #include <stdint.h>
  8. #include <stdlib.h>
  9. #include <string.h>
  10. #define QUEUESIZE 1024
  11. #define HASHSIZE 4096
  12. #define SMALLSTRING 2048
  13. #define TYPE_DATA 1
  14. #define TYPE_MORE 2
  15. #define TYPE_ERROR 3
  16. #define TYPE_OPEN 4
  17. #define TYPE_CLOSE 5
  18. #define TYPE_WARNING 6
  19. #define TYPE_INIT 7
  20. /*
  21. Each package is uint16 + data , uint16 (serialized in big-endian) is the number of bytes comprising the data .
  22. */
  23. struct netpack {
  24. int id;
  25. int size;
  26. void * buffer;
  27. };
  28. struct uncomplete {
  29. struct netpack pack;
  30. struct uncomplete * next;
  31. int read;
  32. int header;
  33. };
  34. struct queue {
  35. int cap;
  36. int head;
  37. int tail;
  38. struct uncomplete * hash[HASHSIZE];
  39. struct netpack queue[QUEUESIZE];
  40. };
  41. static void
  42. clear_list(struct uncomplete * uc) {
  43. while (uc) {
  44. skynet_free(uc->pack.buffer);
  45. void * tmp = uc;
  46. uc = uc->next;
  47. skynet_free(tmp);
  48. }
  49. }
  50. static int
  51. lclear(lua_State *L) {
  52. struct queue * q = lua_touserdata(L, 1);
  53. if (q == NULL) {
  54. return 0;
  55. }
  56. int i;
  57. for (i=0;i<HASHSIZE;i++) {
  58. clear_list(q->hash[i]);
  59. q->hash[i] = NULL;
  60. }
  61. if (q->head > q->tail) {
  62. q->tail += q->cap;
  63. }
  64. for (i=q->head;i<q->tail;i++) {
  65. struct netpack *np = &q->queue[i % q->cap];
  66. skynet_free(np->buffer);
  67. }
  68. q->head = q->tail = 0;
  69. return 0;
  70. }
  71. static inline int
  72. hash_fd(int fd) {
  73. int a = fd >> 24;
  74. int b = fd >> 12;
  75. int c = fd;
  76. return (int)(((uint32_t)(a + b + c)) % HASHSIZE);
  77. }
  78. static struct uncomplete *
  79. find_uncomplete(struct queue *q, int fd) {
  80. if (q == NULL)
  81. return NULL;
  82. int h = hash_fd(fd);
  83. struct uncomplete * uc = q->hash[h];
  84. if (uc == NULL)
  85. return NULL;
  86. if (uc->pack.id == fd) {
  87. q->hash[h] = uc->next;
  88. return uc;
  89. }
  90. struct uncomplete * last = uc;
  91. while (last->next) {
  92. uc = last->next;
  93. if (uc->pack.id == fd) {
  94. last->next = uc->next;
  95. return uc;
  96. }
  97. last = uc;
  98. }
  99. return NULL;
  100. }
  101. static struct queue *
  102. get_queue(lua_State *L) {
  103. struct queue *q = lua_touserdata(L,1);
  104. if (q == NULL) {
  105. q = lua_newuserdatauv(L, sizeof(struct queue), 0);
  106. q->cap = QUEUESIZE;
  107. q->head = 0;
  108. q->tail = 0;
  109. int i;
  110. for (i=0;i<HASHSIZE;i++) {
  111. q->hash[i] = NULL;
  112. }
  113. lua_replace(L, 1);
  114. }
  115. return q;
  116. }
  117. static void
  118. expand_queue(lua_State *L, struct queue *q) {
  119. struct queue *nq = lua_newuserdatauv(L, sizeof(struct queue) + q->cap * sizeof(struct netpack), 0);
  120. nq->cap = q->cap + QUEUESIZE;
  121. nq->head = 0;
  122. nq->tail = q->cap;
  123. memcpy(nq->hash, q->hash, sizeof(nq->hash));
  124. memset(q->hash, 0, sizeof(q->hash));
  125. int i;
  126. for (i=0;i<q->cap;i++) {
  127. int idx = (q->head + i) % q->cap;
  128. nq->queue[i] = q->queue[idx];
  129. }
  130. q->head = q->tail = 0;
  131. lua_replace(L,1);
  132. }
  133. static void
  134. push_data(lua_State *L, int fd, void *buffer, int size, int clone) {
  135. if (clone) {
  136. void * tmp = skynet_malloc(size);
  137. memcpy(tmp, buffer, size);
  138. buffer = tmp;
  139. }
  140. struct queue *q = get_queue(L);
  141. struct netpack *np = &q->queue[q->tail];
  142. if (++q->tail >= q->cap)
  143. q->tail -= q->cap;
  144. np->id = fd;
  145. np->buffer = buffer;
  146. np->size = size;
  147. if (q->head == q->tail) {
  148. expand_queue(L, q);
  149. }
  150. }
  151. static struct uncomplete *
  152. save_uncomplete(lua_State *L, int fd) {
  153. struct queue *q = get_queue(L);
  154. int h = hash_fd(fd);
  155. struct uncomplete * uc = skynet_malloc(sizeof(struct uncomplete));
  156. memset(uc, 0, sizeof(*uc));
  157. uc->next = q->hash[h];
  158. uc->pack.id = fd;
  159. q->hash[h] = uc;
  160. return uc;
  161. }
  162. static inline int
  163. read_size(uint8_t * buffer) {
  164. int r = (int)buffer[0] << 8 | (int)buffer[1];
  165. return r;
  166. }
  167. static void
  168. push_more(lua_State *L, int fd, uint8_t *buffer, int size) {
  169. if (size == 1) {
  170. struct uncomplete * uc = save_uncomplete(L, fd);
  171. uc->read = -1;
  172. uc->header = *buffer;
  173. return;
  174. }
  175. int pack_size = read_size(buffer);
  176. buffer += 2;
  177. size -= 2;
  178. if (size < pack_size) {
  179. struct uncomplete * uc = save_uncomplete(L, fd);
  180. uc->read = size;
  181. uc->pack.size = pack_size;
  182. uc->pack.buffer = skynet_malloc(pack_size);
  183. memcpy(uc->pack.buffer, buffer, size);
  184. return;
  185. }
  186. push_data(L, fd, buffer, pack_size, 1);
  187. buffer += pack_size;
  188. size -= pack_size;
  189. if (size > 0) {
  190. push_more(L, fd, buffer, size);
  191. }
  192. }
  193. static void
  194. close_uncomplete(lua_State *L, int fd) {
  195. struct queue *q = lua_touserdata(L,1);
  196. struct uncomplete * uc = find_uncomplete(q, fd);
  197. if (uc) {
  198. skynet_free(uc->pack.buffer);
  199. skynet_free(uc);
  200. }
  201. }
  202. static int
  203. filter_data_(lua_State *L, int fd, uint8_t * buffer, int size) {
  204. struct queue *q = lua_touserdata(L,1);
  205. struct uncomplete * uc = find_uncomplete(q, fd);
  206. if (uc) {
  207. // fill uncomplete
  208. if (uc->read < 0) {
  209. // read size
  210. assert(uc->read == -1);
  211. int pack_size = *buffer;
  212. pack_size |= uc->header << 8 ;
  213. ++buffer;
  214. --size;
  215. uc->pack.size = pack_size;
  216. uc->pack.buffer = skynet_malloc(pack_size);
  217. uc->read = 0;
  218. }
  219. int need = uc->pack.size - uc->read;
  220. if (size < need) {
  221. memcpy(uc->pack.buffer + uc->read, buffer, size);
  222. uc->read += size;
  223. int h = hash_fd(fd);
  224. uc->next = q->hash[h];
  225. q->hash[h] = uc;
  226. return 1;
  227. }
  228. memcpy(uc->pack.buffer + uc->read, buffer, need);
  229. buffer += need;
  230. size -= need;
  231. if (size == 0) {
  232. lua_pushvalue(L, lua_upvalueindex(TYPE_DATA));
  233. lua_pushinteger(L, fd);
  234. lua_pushlightuserdata(L, uc->pack.buffer);
  235. lua_pushinteger(L, uc->pack.size);
  236. skynet_free(uc);
  237. return 5;
  238. }
  239. // more data
  240. push_data(L, fd, uc->pack.buffer, uc->pack.size, 0);
  241. skynet_free(uc);
  242. push_more(L, fd, buffer, size);
  243. lua_pushvalue(L, lua_upvalueindex(TYPE_MORE));
  244. return 2;
  245. } else {
  246. if (size == 1) {
  247. struct uncomplete * uc = save_uncomplete(L, fd);
  248. uc->read = -1;
  249. uc->header = *buffer;
  250. return 1;
  251. }
  252. int pack_size = read_size(buffer);
  253. buffer+=2;
  254. size-=2;
  255. if (size < pack_size) {
  256. struct uncomplete * uc = save_uncomplete(L, fd);
  257. uc->read = size;
  258. uc->pack.size = pack_size;
  259. uc->pack.buffer = skynet_malloc(pack_size);
  260. memcpy(uc->pack.buffer, buffer, size);
  261. return 1;
  262. }
  263. if (size == pack_size) {
  264. // just one package
  265. lua_pushvalue(L, lua_upvalueindex(TYPE_DATA));
  266. lua_pushinteger(L, fd);
  267. void * result = skynet_malloc(pack_size);
  268. memcpy(result, buffer, size);
  269. lua_pushlightuserdata(L, result);
  270. lua_pushinteger(L, size);
  271. return 5;
  272. }
  273. // more data
  274. push_data(L, fd, buffer, pack_size, 1);
  275. buffer += pack_size;
  276. size -= pack_size;
  277. push_more(L, fd, buffer, size);
  278. lua_pushvalue(L, lua_upvalueindex(TYPE_MORE));
  279. return 2;
  280. }
  281. }
  282. static inline int
  283. filter_data(lua_State *L, int fd, uint8_t * buffer, int size) {
  284. int ret = filter_data_(L, fd, buffer, size);
  285. // buffer is the data of socket message, it malloc at socket_server.c : function forward_message .
  286. // it should be free before return,
  287. skynet_free(buffer);
  288. return ret;
  289. }
  290. static void
  291. pushstring(lua_State *L, const char * msg, int size) {
  292. if (msg) {
  293. lua_pushlstring(L, msg, size);
  294. } else {
  295. lua_pushliteral(L, "");
  296. }
  297. }
  298. /*
  299. userdata queue
  300. lightuserdata msg
  301. integer size
  302. return
  303. userdata queue
  304. integer type
  305. integer fd
  306. string msg | lightuserdata/integer
  307. */
  308. static int
  309. lfilter(lua_State *L) {
  310. struct skynet_socket_message *message = lua_touserdata(L,2);
  311. int size = luaL_checkinteger(L,3);
  312. char * buffer = message->buffer;
  313. if (buffer == NULL) {
  314. buffer = (char *)(message+1);
  315. size -= sizeof(*message);
  316. } else {
  317. size = -1;
  318. }
  319. lua_settop(L, 1);
  320. switch(message->type) {
  321. case SKYNET_SOCKET_TYPE_DATA:
  322. // ignore listen id (message->id)
  323. assert(size == -1); // never padding string
  324. return filter_data(L, message->id, (uint8_t *)buffer, message->ud);
  325. case SKYNET_SOCKET_TYPE_CONNECT:
  326. lua_pushvalue(L, lua_upvalueindex(TYPE_INIT));
  327. lua_pushinteger(L, message->id);
  328. lua_pushlstring(L, buffer, size);
  329. lua_pushinteger(L, message->ud);
  330. return 5;
  331. case SKYNET_SOCKET_TYPE_CLOSE:
  332. // no more data in fd (message->id)
  333. close_uncomplete(L, message->id);
  334. lua_pushvalue(L, lua_upvalueindex(TYPE_CLOSE));
  335. lua_pushinteger(L, message->id);
  336. return 3;
  337. case SKYNET_SOCKET_TYPE_ACCEPT:
  338. lua_pushvalue(L, lua_upvalueindex(TYPE_OPEN));
  339. // ignore listen id (message->id);
  340. lua_pushinteger(L, message->ud);
  341. pushstring(L, buffer, size);
  342. return 4;
  343. case SKYNET_SOCKET_TYPE_ERROR:
  344. // no more data in fd (message->id)
  345. close_uncomplete(L, message->id);
  346. lua_pushvalue(L, lua_upvalueindex(TYPE_ERROR));
  347. lua_pushinteger(L, message->id);
  348. pushstring(L, buffer, size);
  349. return 4;
  350. case SKYNET_SOCKET_TYPE_WARNING:
  351. lua_pushvalue(L, lua_upvalueindex(TYPE_WARNING));
  352. lua_pushinteger(L, message->id);
  353. lua_pushinteger(L, message->ud);
  354. return 4;
  355. default:
  356. // never get here
  357. return 1;
  358. }
  359. }
  360. /*
  361. userdata queue
  362. return
  363. integer fd
  364. lightuserdata msg
  365. integer size
  366. */
  367. static int
  368. lpop(lua_State *L) {
  369. struct queue * q = lua_touserdata(L, 1);
  370. if (q == NULL || q->head == q->tail)
  371. return 0;
  372. struct netpack *np = &q->queue[q->head];
  373. if (++q->head >= q->cap) {
  374. q->head = 0;
  375. }
  376. lua_pushinteger(L, np->id);
  377. lua_pushlightuserdata(L, np->buffer);
  378. lua_pushinteger(L, np->size);
  379. return 3;
  380. }
  381. /*
  382. string msg | lightuserdata/integer
  383. lightuserdata/integer
  384. */
  385. static const char *
  386. tolstring(lua_State *L, size_t *sz, int index) {
  387. const char * ptr;
  388. if (lua_isuserdata(L,index)) {
  389. ptr = (const char *)lua_touserdata(L,index);
  390. *sz = (size_t)luaL_checkinteger(L, index+1);
  391. } else {
  392. ptr = luaL_checklstring(L, index, sz);
  393. }
  394. return ptr;
  395. }
  396. static inline void
  397. write_size(uint8_t * buffer, int len) {
  398. buffer[0] = (len >> 8) & 0xff;
  399. buffer[1] = len & 0xff;
  400. }
  401. static int
  402. lpack(lua_State *L) {
  403. size_t len;
  404. const char * ptr = tolstring(L, &len, 1);
  405. if (len >= 0x10000) {
  406. return luaL_error(L, "Invalid size (too long) of data : %d", (int)len);
  407. }
  408. uint8_t * buffer = skynet_malloc(len + 2);
  409. write_size(buffer, len);
  410. memcpy(buffer+2, ptr, len);
  411. lua_pushlightuserdata(L, buffer);
  412. lua_pushinteger(L, len + 2);
  413. return 2;
  414. }
  415. static int
  416. ltostring(lua_State *L) {
  417. void * ptr = lua_touserdata(L, 1);
  418. int size = luaL_checkinteger(L, 2);
  419. if (ptr == NULL) {
  420. lua_pushliteral(L, "");
  421. } else {
  422. lua_pushlstring(L, (const char *)ptr, size);
  423. skynet_free(ptr);
  424. }
  425. return 1;
  426. }
  427. LUAMOD_API int
  428. luaopen_skynet_netpack(lua_State *L) {
  429. luaL_checkversion(L);
  430. luaL_Reg l[] = {
  431. { "pop", lpop },
  432. { "pack", lpack },
  433. { "clear", lclear },
  434. { "tostring", ltostring },
  435. { NULL, NULL },
  436. };
  437. luaL_newlib(L,l);
  438. // the order is same with macros : TYPE_* (defined top)
  439. lua_pushliteral(L, "data");
  440. lua_pushliteral(L, "more");
  441. lua_pushliteral(L, "error");
  442. lua_pushliteral(L, "open");
  443. lua_pushliteral(L, "close");
  444. lua_pushliteral(L, "warning");
  445. lua_pushliteral(L, "init");
  446. lua_pushcclosure(L, lfilter, 7);
  447. lua_setfield(L, -2, "filter");
  448. return 1;
  449. }