lua-mongo.c 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. #define LUA_LIB
  2. #include "skynet_malloc.h"
  3. #include <lua.h>
  4. #include <lauxlib.h>
  5. #include <stdint.h>
  6. #include <stdlib.h>
  7. #include <stdio.h>
  8. #include <string.h>
  9. #define OP_COMPRESSED 2012
  10. #define OP_MSG 2013
  11. typedef enum {
  12. MSG_CHECKSUM_PRESENT = 1 << 0,
  13. MSG_MORE_TO_COME = 1 << 1,
  14. MSG_EXHAUST_ALLOWED = 1 << 16,
  15. } msg_flags_t;
  16. #define DEFAULT_CAP 128
  17. struct connection {
  18. int sock;
  19. int id;
  20. };
  21. struct response {
  22. int flags;
  23. int32_t cursor_id[2];
  24. int starting_from;
  25. int number;
  26. };
  27. struct buffer {
  28. int size;
  29. int cap;
  30. uint8_t * ptr;
  31. uint8_t buffer[DEFAULT_CAP];
  32. };
  33. static inline uint32_t
  34. little_endian(uint32_t v) {
  35. union {
  36. uint32_t v;
  37. uint8_t b[4];
  38. } u;
  39. u.v = v;
  40. return u.b[0] | u.b[1] << 8 | u.b[2] << 16 | u.b[3] << 24;
  41. }
  42. typedef void * document;
  43. static inline uint32_t
  44. get_length(document buffer) {
  45. union {
  46. uint32_t v;
  47. uint8_t b[4];
  48. } u;
  49. memcpy(&u.v, buffer, 4);
  50. return u.b[0] | u.b[1] << 8 | u.b[2] << 16 | u.b[3] << 24;
  51. }
  52. static inline void
  53. buffer_destroy(struct buffer *b) {
  54. if (b->ptr != b->buffer) {
  55. skynet_free(b->ptr);
  56. }
  57. }
  58. static inline void
  59. buffer_create(struct buffer *b) {
  60. b->size = 0;
  61. b->cap = DEFAULT_CAP;
  62. b->ptr = b->buffer;
  63. }
  64. static inline void
  65. buffer_reserve(struct buffer *b, int sz) {
  66. if (b->size + sz <= b->cap)
  67. return;
  68. do {
  69. b->cap *= 2;
  70. } while (b->cap <= b->size + sz);
  71. if (b->ptr == b->buffer) {
  72. b->ptr = (uint8_t*)malloc(b->cap);
  73. memcpy(b->ptr, b->buffer, b->size);
  74. } else {
  75. b->ptr = (uint8_t*)realloc(b->ptr, b->cap);
  76. }
  77. }
  78. static inline void
  79. write_int32(struct buffer *b, int32_t v) {
  80. uint32_t uv = (uint32_t)v;
  81. buffer_reserve(b,4);
  82. b->ptr[b->size++] = uv & 0xff;
  83. b->ptr[b->size++] = (uv >> 8)&0xff;
  84. b->ptr[b->size++] = (uv >> 16)&0xff;
  85. b->ptr[b->size++] = (uv >> 24)&0xff;
  86. }
  87. static inline void
  88. write_int8(struct buffer *b, int8_t v) {
  89. uint8_t uv = (uint8_t)v;
  90. buffer_reserve(b, 1);
  91. b->ptr[b->size++] = uv;
  92. }
  93. /*
  94. static inline void
  95. write_bytes(struct buffer *b, const void * buf, int sz) {
  96. buffer_reserve(b,sz);
  97. memcpy(b->ptr + b->size, buf, sz);
  98. b->size += sz;
  99. }
  100. static void
  101. write_string(struct buffer *b, const char *key, size_t sz) {
  102. buffer_reserve(b,sz+1);
  103. memcpy(b->ptr + b->size, key, sz);
  104. b->ptr[b->size+sz] = '\0';
  105. b->size+=sz+1;
  106. }
  107. */
  108. static inline int
  109. reserve_length(struct buffer *b) {
  110. int sz = b->size;
  111. buffer_reserve(b,4);
  112. b->size +=4;
  113. return sz;
  114. }
  115. static inline void
  116. write_length(struct buffer *b, int32_t v, int off) {
  117. uint32_t uv = (uint32_t)v;
  118. b->ptr[off++] = uv & 0xff;
  119. b->ptr[off++] = (uv >> 8)&0xff;
  120. b->ptr[off++] = (uv >> 16)&0xff;
  121. b->ptr[off++] = (uv >> 24)&0xff;
  122. }
  123. struct header_t {
  124. //int32_t message_length; // total message size, include this
  125. int32_t request_id; // identifier for this message
  126. int32_t response_to; // requestID from the original request(used in responses from the database)
  127. int32_t opcode; // message type
  128. int32_t flags;
  129. };
  130. // 1 string data
  131. // 2 result document table
  132. // return boolean succ (false -> request id, error document)
  133. // number request_id
  134. // document first
  135. static int
  136. unpack_reply(lua_State *L) {
  137. size_t data_len = 0;
  138. const char * data = luaL_checklstring(L,1,&data_len);
  139. const struct header_t* h = (const struct header_t*)data;
  140. if (data_len < sizeof(*h)) {
  141. lua_pushboolean(L, 0);
  142. return 1;
  143. }
  144. int opcode = little_endian(h->opcode);
  145. if (opcode != OP_MSG) {
  146. return luaL_error(L, "Unsupported opcode:%d", opcode);
  147. }
  148. int id = little_endian(h->response_to);
  149. int flags = little_endian(h->flags);
  150. if (flags != 0) {
  151. if ((flags & MSG_CHECKSUM_PRESENT) != 0) {
  152. return luaL_error(L, "Unsupported OP_MSG flag checksumPresent");
  153. }
  154. if ((flags ^ MSG_MORE_TO_COME) != 0) {
  155. return luaL_error(L, "Unsupported OP_MSG flag:%d", flags);
  156. }
  157. }
  158. int sz = (int)data_len - sizeof(*h);
  159. const uint8_t * section = (const uint8_t *)(h+1);
  160. uint8_t payload_type = *section;
  161. const uint8_t * doc = section+1;
  162. if (payload_type != 0) {
  163. return luaL_error(L, "Unsupported OP_MSG payload type: %d", payload_type);
  164. }
  165. int32_t doc_sz = get_length((document)(doc));
  166. if ((sz - 1) != doc_sz) {
  167. return luaL_error(L, "Unsupported OP_MSG reply: >1 section");
  168. }
  169. lua_pushboolean(L, 1);
  170. lua_pushinteger(L, id);
  171. lua_pushlightuserdata(L, (void *)(doc));
  172. return 3;
  173. }
  174. // string 4 bytes length
  175. // return integer
  176. static int
  177. reply_length(lua_State *L) {
  178. const char * rawlen_str = luaL_checkstring(L, 1);
  179. int rawlen = 0;
  180. memcpy(&rawlen, rawlen_str, sizeof(int));
  181. int length = little_endian(rawlen);
  182. lua_pushinteger(L, length - 4);
  183. return 1;
  184. }
  185. // @param 1 request_id int
  186. // @param 2 flags int
  187. // @param 3 command bson document
  188. // @return
  189. static int
  190. op_msg(lua_State *L) {
  191. int id = luaL_checkinteger(L, 1);
  192. int flags = luaL_checkinteger(L, 2);
  193. document cmd = lua_touserdata(L, 3);
  194. if (cmd == NULL) {
  195. return luaL_error(L, "opmsg require cmd document");
  196. }
  197. luaL_Buffer b;
  198. luaL_buffinit(L, &b);
  199. struct buffer buf;
  200. buffer_create(&buf);
  201. int len = reserve_length(&buf);
  202. write_int32(&buf, id);
  203. write_int32(&buf, 0);
  204. write_int32(&buf, OP_MSG);
  205. write_int32(&buf, flags);
  206. write_int8(&buf, 0);
  207. int32_t cmd_len = get_length(cmd);
  208. int total = buf.size + cmd_len;
  209. write_length(&buf, total, len);
  210. luaL_addlstring(&b, (const char *)buf.ptr, buf.size);
  211. buffer_destroy(&buf);
  212. luaL_addlstring(&b, (const char *)cmd, cmd_len);
  213. luaL_pushresult(&b);
  214. return 1;
  215. }
  216. LUAMOD_API int
  217. luaopen_skynet_mongo_driver(lua_State *L) {
  218. luaL_checkversion(L);
  219. luaL_Reg l[] ={
  220. { "reply", unpack_reply }, // 接收响应
  221. { "length", reply_length },
  222. { "op_msg", op_msg},
  223. { NULL, NULL },
  224. };
  225. luaL_newlib(L,l);
  226. return 1;
  227. }