123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497 |
- #define LUA_LIB
- #include "skynet_malloc.h"
- #include "skynet_socket.h"
- #include <lua.h>
- #include <lauxlib.h>
- #include <assert.h>
- #include <stdint.h>
- #include <stdlib.h>
- #include <string.h>
/* Sizing constants. */
enum {
	QUEUESIZE = 1024,    /* netpack slots in a new queue; also the growth step of expand_queue */
	HASHSIZE = 4096,     /* buckets in the per-connection partial-packet hash */
	SMALLSTRING = 2048,  /* not referenced by any code shown here */
};

/*
 * Event type tags. Each value is also the upvalue index of the matching
 * type-name string captured by the `filter` closure.
 */
enum {
	TYPE_DATA = 1,
	TYPE_MORE = 2,
	TYPE_ERROR = 3,
	TYPE_OPEN = 4,
	TYPE_CLOSE = 5,
	TYPE_WARNING = 6,
	TYPE_INIT = 7,
};
/*
 * Wire format: each packet is a 2-byte length header followed by that many
 * bytes of data. The header is an unsigned 16-bit integer in big-endian byte
 * order, so one packet carries at most 65535 bytes of data.
 */

/* One complete packet: the connection id and a heap buffer of `size` bytes. */
struct netpack {
	int id;
	int size;
	void * buffer;
};

/*
 * A packet still being reassembled for one connection. Records sharing a
 * hash bucket are chained through `next`.
 */
struct uncomplete {
	struct netpack pack;
	struct uncomplete * next;
	int read;    // payload bytes received so far; -1 while only the first header byte has arrived
	int header;  // the first (high) header byte, meaningful only when read == -1
};

/*
 * Packet FIFO kept inside a Lua userdata, plus the hash of in-progress
 * packets. `queue` is used as a ring buffer of `cap` slots indexed by
 * head (next to pop) and tail (next free slot).
 */
struct queue {
	int cap;
	int head;
	int tail;
	struct uncomplete * hash[HASHSIZE];
	// Must remain the LAST member: expand_queue allocates extra netpack
	// slots past the end of this array so the ring can hold `cap` entries
	// where cap grows beyond QUEUESIZE.
	struct netpack queue[QUEUESIZE];
};
- static void
- clear_list(struct uncomplete * uc) {
- while (uc) {
- skynet_free(uc->pack.buffer);
- void * tmp = uc;
- uc = uc->next;
- skynet_free(tmp);
- }
- }
- static int
- lclear(lua_State *L) {
- struct queue * q = lua_touserdata(L, 1);
- if (q == NULL) {
- return 0;
- }
- int i;
- for (i=0;i<HASHSIZE;i++) {
- clear_list(q->hash[i]);
- q->hash[i] = NULL;
- }
- if (q->head > q->tail) {
- q->tail += q->cap;
- }
- for (i=q->head;i<q->tail;i++) {
- struct netpack *np = &q->queue[i % q->cap];
- skynet_free(np->buffer);
- }
- q->head = q->tail = 0;
- return 0;
- }
- static inline int
- hash_fd(int fd) {
- int a = fd >> 24;
- int b = fd >> 12;
- int c = fd;
- return (int)(((uint32_t)(a + b + c)) % HASHSIZE);
- }
- static struct uncomplete *
- find_uncomplete(struct queue *q, int fd) {
- if (q == NULL)
- return NULL;
- int h = hash_fd(fd);
- struct uncomplete * uc = q->hash[h];
- if (uc == NULL)
- return NULL;
- if (uc->pack.id == fd) {
- q->hash[h] = uc->next;
- return uc;
- }
- struct uncomplete * last = uc;
- while (last->next) {
- uc = last->next;
- if (uc->pack.id == fd) {
- last->next = uc->next;
- return uc;
- }
- last = uc;
- }
- return NULL;
- }
- static struct queue *
- get_queue(lua_State *L) {
- struct queue *q = lua_touserdata(L,1);
- if (q == NULL) {
- q = lua_newuserdatauv(L, sizeof(struct queue), 0);
- q->cap = QUEUESIZE;
- q->head = 0;
- q->tail = 0;
- int i;
- for (i=0;i<HASHSIZE;i++) {
- q->hash[i] = NULL;
- }
- lua_replace(L, 1);
- }
- return q;
- }
- static void
- expand_queue(lua_State *L, struct queue *q) {
- struct queue *nq = lua_newuserdatauv(L, sizeof(struct queue) + q->cap * sizeof(struct netpack), 0);
- nq->cap = q->cap + QUEUESIZE;
- nq->head = 0;
- nq->tail = q->cap;
- memcpy(nq->hash, q->hash, sizeof(nq->hash));
- memset(q->hash, 0, sizeof(q->hash));
- int i;
- for (i=0;i<q->cap;i++) {
- int idx = (q->head + i) % q->cap;
- nq->queue[i] = q->queue[idx];
- }
- q->head = q->tail = 0;
- lua_replace(L,1);
- }
- static void
- push_data(lua_State *L, int fd, void *buffer, int size, int clone) {
- if (clone) {
- void * tmp = skynet_malloc(size);
- memcpy(tmp, buffer, size);
- buffer = tmp;
- }
- struct queue *q = get_queue(L);
- struct netpack *np = &q->queue[q->tail];
- if (++q->tail >= q->cap)
- q->tail -= q->cap;
- np->id = fd;
- np->buffer = buffer;
- np->size = size;
- if (q->head == q->tail) {
- expand_queue(L, q);
- }
- }
- static struct uncomplete *
- save_uncomplete(lua_State *L, int fd) {
- struct queue *q = get_queue(L);
- int h = hash_fd(fd);
- struct uncomplete * uc = skynet_malloc(sizeof(struct uncomplete));
- memset(uc, 0, sizeof(*uc));
- uc->next = q->hash[h];
- uc->pack.id = fd;
- q->hash[h] = uc;
- return uc;
- }
/* Decode the 2-byte big-endian length header at `buffer` (0..65535). */
static inline int
read_size(uint8_t * buffer) {
	return buffer[0] * 256 + buffer[1];
}
- static void
- push_more(lua_State *L, int fd, uint8_t *buffer, int size) {
- if (size == 1) {
- struct uncomplete * uc = save_uncomplete(L, fd);
- uc->read = -1;
- uc->header = *buffer;
- return;
- }
- int pack_size = read_size(buffer);
- buffer += 2;
- size -= 2;
- if (size < pack_size) {
- struct uncomplete * uc = save_uncomplete(L, fd);
- uc->read = size;
- uc->pack.size = pack_size;
- uc->pack.buffer = skynet_malloc(pack_size);
- memcpy(uc->pack.buffer, buffer, size);
- return;
- }
- push_data(L, fd, buffer, pack_size, 1);
- buffer += pack_size;
- size -= pack_size;
- if (size > 0) {
- push_more(L, fd, buffer, size);
- }
- }
- static void
- close_uncomplete(lua_State *L, int fd) {
- struct queue *q = lua_touserdata(L,1);
- struct uncomplete * uc = find_uncomplete(q, fd);
- if (uc) {
- skynet_free(uc->pack.buffer);
- skynet_free(uc);
- }
- }
- static int
- filter_data_(lua_State *L, int fd, uint8_t * buffer, int size) {
- struct queue *q = lua_touserdata(L,1);
- struct uncomplete * uc = find_uncomplete(q, fd);
- if (uc) {
- // fill uncomplete
- if (uc->read < 0) {
- // read size
- assert(uc->read == -1);
- int pack_size = *buffer;
- pack_size |= uc->header << 8 ;
- ++buffer;
- --size;
- uc->pack.size = pack_size;
- uc->pack.buffer = skynet_malloc(pack_size);
- uc->read = 0;
- }
- int need = uc->pack.size - uc->read;
- if (size < need) {
- memcpy(uc->pack.buffer + uc->read, buffer, size);
- uc->read += size;
- int h = hash_fd(fd);
- uc->next = q->hash[h];
- q->hash[h] = uc;
- return 1;
- }
- memcpy(uc->pack.buffer + uc->read, buffer, need);
- buffer += need;
- size -= need;
- if (size == 0) {
- lua_pushvalue(L, lua_upvalueindex(TYPE_DATA));
- lua_pushinteger(L, fd);
- lua_pushlightuserdata(L, uc->pack.buffer);
- lua_pushinteger(L, uc->pack.size);
- skynet_free(uc);
- return 5;
- }
- // more data
- push_data(L, fd, uc->pack.buffer, uc->pack.size, 0);
- skynet_free(uc);
- push_more(L, fd, buffer, size);
- lua_pushvalue(L, lua_upvalueindex(TYPE_MORE));
- return 2;
- } else {
- if (size == 1) {
- struct uncomplete * uc = save_uncomplete(L, fd);
- uc->read = -1;
- uc->header = *buffer;
- return 1;
- }
- int pack_size = read_size(buffer);
- buffer+=2;
- size-=2;
- if (size < pack_size) {
- struct uncomplete * uc = save_uncomplete(L, fd);
- uc->read = size;
- uc->pack.size = pack_size;
- uc->pack.buffer = skynet_malloc(pack_size);
- memcpy(uc->pack.buffer, buffer, size);
- return 1;
- }
- if (size == pack_size) {
- // just one package
- lua_pushvalue(L, lua_upvalueindex(TYPE_DATA));
- lua_pushinteger(L, fd);
- void * result = skynet_malloc(pack_size);
- memcpy(result, buffer, size);
- lua_pushlightuserdata(L, result);
- lua_pushinteger(L, size);
- return 5;
- }
- // more data
- push_data(L, fd, buffer, pack_size, 1);
- buffer += pack_size;
- size -= pack_size;
- push_more(L, fd, buffer, size);
- lua_pushvalue(L, lua_upvalueindex(TYPE_MORE));
- return 2;
- }
- }
- static inline int
- filter_data(lua_State *L, int fd, uint8_t * buffer, int size) {
- int ret = filter_data_(L, fd, buffer, size);
- // buffer is the data of socket message, it malloc at socket_server.c : function forward_message .
- // it should be free before return,
- skynet_free(buffer);
- return ret;
- }
- static void
- pushstring(lua_State *L, const char * msg, int size) {
- if (msg) {
- lua_pushlstring(L, msg, size);
- } else {
- lua_pushliteral(L, "");
- }
- }
- /*
- userdata queue
- lightuserdata msg
- integer size
- return
- userdata queue
- integer type
- integer fd
- string msg | lightuserdata/integer
- */
- static int
- lfilter(lua_State *L) {
- struct skynet_socket_message *message = lua_touserdata(L,2);
- int size = luaL_checkinteger(L,3);
- char * buffer = message->buffer;
- if (buffer == NULL) {
- buffer = (char *)(message+1);
- size -= sizeof(*message);
- } else {
- size = -1;
- }
- lua_settop(L, 1);
- switch(message->type) {
- case SKYNET_SOCKET_TYPE_DATA:
- // ignore listen id (message->id)
- assert(size == -1); // never padding string
- return filter_data(L, message->id, (uint8_t *)buffer, message->ud);
- case SKYNET_SOCKET_TYPE_CONNECT:
- lua_pushvalue(L, lua_upvalueindex(TYPE_INIT));
- lua_pushinteger(L, message->id);
- lua_pushlstring(L, buffer, size);
- lua_pushinteger(L, message->ud);
- return 5;
- case SKYNET_SOCKET_TYPE_CLOSE:
- // no more data in fd (message->id)
- close_uncomplete(L, message->id);
- lua_pushvalue(L, lua_upvalueindex(TYPE_CLOSE));
- lua_pushinteger(L, message->id);
- return 3;
- case SKYNET_SOCKET_TYPE_ACCEPT:
- lua_pushvalue(L, lua_upvalueindex(TYPE_OPEN));
- // ignore listen id (message->id);
- lua_pushinteger(L, message->ud);
- pushstring(L, buffer, size);
- return 4;
- case SKYNET_SOCKET_TYPE_ERROR:
- // no more data in fd (message->id)
- close_uncomplete(L, message->id);
- lua_pushvalue(L, lua_upvalueindex(TYPE_ERROR));
- lua_pushinteger(L, message->id);
- pushstring(L, buffer, size);
- return 4;
- case SKYNET_SOCKET_TYPE_WARNING:
- lua_pushvalue(L, lua_upvalueindex(TYPE_WARNING));
- lua_pushinteger(L, message->id);
- lua_pushinteger(L, message->ud);
- return 4;
- default:
- // never get here
- return 1;
- }
- }
- /*
- userdata queue
- return
- integer fd
- lightuserdata msg
- integer size
- */
- static int
- lpop(lua_State *L) {
- struct queue * q = lua_touserdata(L, 1);
- if (q == NULL || q->head == q->tail)
- return 0;
- struct netpack *np = &q->queue[q->head];
- if (++q->head >= q->cap) {
- q->head = 0;
- }
- lua_pushinteger(L, np->id);
- lua_pushlightuserdata(L, np->buffer);
- lua_pushinteger(L, np->size);
- return 3;
- }
- /*
- string msg | lightuserdata/integer
- lightuserdata/integer
- */
- static const char *
- tolstring(lua_State *L, size_t *sz, int index) {
- const char * ptr;
- if (lua_isuserdata(L,index)) {
- ptr = (const char *)lua_touserdata(L,index);
- *sz = (size_t)luaL_checkinteger(L, index+1);
- } else {
- ptr = luaL_checklstring(L, index, sz);
- }
- return ptr;
- }
/* Encode the low 16 bits of `len` as a 2-byte big-endian length header. */
static inline void
write_size(uint8_t * buffer, int len) {
	unsigned int v = (unsigned int)len & 0xffffu;
	buffer[1] = (uint8_t)(v & 0xff);
	buffer[0] = (uint8_t)(v >> 8);
}
- static int
- lpack(lua_State *L) {
- size_t len;
- const char * ptr = tolstring(L, &len, 1);
- if (len >= 0x10000) {
- return luaL_error(L, "Invalid size (too long) of data : %d", (int)len);
- }
- uint8_t * buffer = skynet_malloc(len + 2);
- write_size(buffer, len);
- memcpy(buffer+2, ptr, len);
- lua_pushlightuserdata(L, buffer);
- lua_pushinteger(L, len + 2);
- return 2;
- }
- static int
- ltostring(lua_State *L) {
- void * ptr = lua_touserdata(L, 1);
- int size = luaL_checkinteger(L, 2);
- if (ptr == NULL) {
- lua_pushliteral(L, "");
- } else {
- lua_pushlstring(L, (const char *)ptr, size);
- skynet_free(ptr);
- }
- return 1;
- }
- LUAMOD_API int
- luaopen_skynet_netpack(lua_State *L) {
- luaL_checkversion(L);
- luaL_Reg l[] = {
- { "pop", lpop },
- { "pack", lpack },
- { "clear", lclear },
- { "tostring", ltostring },
- { NULL, NULL },
- };
- luaL_newlib(L,l);
- // the order is same with macros : TYPE_* (defined top)
- lua_pushliteral(L, "data");
- lua_pushliteral(L, "more");
- lua_pushliteral(L, "error");
- lua_pushliteral(L, "open");
- lua_pushliteral(L, "close");
- lua_pushliteral(L, "warning");
- lua_pushliteral(L, "init");
- lua_pushcclosure(L, lfilter, 7);
- lua_setfield(L, -2, "filter");
- return 1;
- }
|