lua-mysqlaux.c 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #define LUA_LIB
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <string.h>
  5. #include <lua.h>
  6. #include <lauxlib.h>
  /*
   * Count how many bytes of src[0..size) must be backslash-escaped when the
   * data is placed inside a quoted SQL string literal: NUL, \b, \n, \r, \t,
   * Ctrl-Z (26), backslash, single quote and double quote.  Each such byte
   * adds exactly one extra byte (the backslash) to the escaped output.
   *
   * dst is not used; it is kept only so existing call sites, which pass
   * NULL, continue to compile unchanged.
   *
   * Returns size_t rather than unsigned int: the count can be as large as
   * `size` itself, and truncating it would make the caller allocate an
   * output buffer that is too small for escape_sql_str().
   */
  static size_t
  num_escape_sql_str(unsigned char *dst, const unsigned char *src, size_t size)
  {
      size_t n = 0;

      (void) dst;

      for (; size > 0; src++, size--) {
          /* Every special byte is 7-bit ASCII, so bytes with the high bit
           * set (e.g. parts of UTF-8 sequences) fall through to default. */
          switch (*src) {
          case '\0':
          case '\b':
          case '\n':
          case '\r':
          case '\t':
          case 26: /* Ctrl-Z, written as \Z */
          case '\\':
          case '\'':
          case '"':
              n++;
              break;
          default:
              break;
          }
      }
      return n;
  }
  /*
   * Copy src[0..size) into dst, turning each byte that is special inside a
   * quoted SQL string literal into a two-byte backslash sequence:
   *   NUL -> \0   \b -> \b   \n -> \n   \r -> \r   \t -> \t
   *   Ctrl-Z (26) -> \Z   \ -> \\   ' -> \'   " -> \"
   * Every other byte, including any byte with the high bit set, is copied
   * unchanged.
   *
   * dst must have room for `size` bytes plus one extra byte per special
   * character (the count num_escape_sql_str() produces).  The output is not
   * NUL-terminated.  Returns a pointer one past the last byte written.
   */
  static unsigned char *
  escape_sql_str(unsigned char *dst, unsigned char *src, size_t size)
  {
      const unsigned char *stop = src + size;

      while (src < stop) {
          unsigned char ch = *src++;
          unsigned char letter;

          switch (ch) {
          case '\0': letter = '0';  break;
          case '\b': letter = 'b';  break;
          case '\n': letter = 'n';  break;
          case '\r': letter = 'r';  break;
          case '\t': letter = 't';  break;
          case 26:   letter = 'Z';  break;
          case '\\': letter = '\\'; break;
          case '\'': letter = '\''; break;
          case '"':  letter = '"';  break;
          default:
              /* Ordinary (or non-ASCII) byte: pass it through verbatim. */
              *dst++ = ch;
              continue;
          }

          *dst++ = '\\';
          *dst++ = letter;
      }
      return dst;
  }
  89. static int
  90. quote_sql_str(lua_State *L)
  91. {
  92. size_t len, dlen, escape;
  93. unsigned char *p;
  94. unsigned char *src, *dst;
  95. if (lua_gettop(L) != 1) {
  96. return luaL_error(L, "expecting one argument");
  97. }
  98. src = (unsigned char *) luaL_checklstring(L, 1, &len);
  99. if (len == 0) {
  100. dst = (unsigned char *) "''";
  101. dlen = sizeof("''") - 1;
  102. lua_pushlstring(L, (char *) dst, dlen);
  103. return 1;
  104. }
  105. escape = num_escape_sql_str(NULL, src, len);
  106. dlen = sizeof("''") - 1 + len + escape;
  107. p = lua_newuserdata(L, dlen);
  108. dst = p;
  109. *p++ = '\'';
  110. if (escape == 0) {
  111. memcpy(p, src, len);
  112. p+=len;
  113. } else {
  114. p = (unsigned char *) escape_sql_str(p, src, len);
  115. }
  116. *p++ = '\'';
  117. if (p != dst + dlen) {
  118. return luaL_error(L, "quote sql string error");
  119. }
  120. lua_pushlstring(L, (char *) dst, p - dst);
  121. return 1;
  122. }
  123. static struct luaL_Reg mysqlauxlib[] = {
  124. {"quote_sql_str",quote_sql_str},
  125. {NULL, NULL}
  126. };
  127. LUAMOD_API int luaopen_skynet_mysqlaux_c (lua_State *L) {
  128. lua_newtable(L);
  129. luaL_setfuncs(L, mysqlauxlib, 0);
  130. return 1;
  131. }