util.lua 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. local json = require "cjson"
  2. -- Various common routines used by the Lua CJSON package
  3. --
  4. -- Mark Pulford <mark@kyne.com.au>
  -- Determine whether a Lua table can be treated as an array.
  -- Only positive integer keys count as array indices; any other key
  -- (string, boolean, table, or a zero, negative or fractional number)
  -- makes the table "not an array". Very sparse arrays (highest index more
  -- than twice the number of entries) are also reported as "not an array".
  -- Returns:
  --   -1  Not an array
  --    0  Empty table
  --   >0  Highest index in the array
  local function is_array(tbl)
      local max = 0
      local count = 0
      for k in pairs(tbl) do
          -- Zero, negative and fractional keys are never visited by the
          -- 1..max loop used to render arrays, so accepting them here would
          -- silently drop or mis-render those entries.
          if type(k) ~= "number" or k <= 0 or math.floor(k) ~= k then
              return -1
          end
          if k > max then
              max = k
          end
          count = count + 1
      end
      -- Reject very sparse arrays (this also catches a math.huge key).
      if max > count * 2 then
          return -1
      end
      return max
  end
  -- Forward declaration: serialise_value (assigned further down) and
  -- serialise_table call each other recursively.
  local serialise_value

  -- Render a table as Lua-constructor-style text.
  -- indent: prefix string for multi-line output, or false for a single line.
  -- depth:  nesting level of the enclosing value. Once it would exceed 50,
  --         a plain (unquoted) message is returned instead of table text,
  --         which bounds recursion on deeply nested or self-referencing data.
  -- Tables that is_array() accepts are rendered as a positional list
  -- (indices 1..max). Other non-empty tables are rendered as
  -- "[key] = value" pairs in pairs() order.
  local function serialise_table(value, indent, depth)
      local close_sep, item_sep, child_indent
      if not indent then
          close_sep, item_sep, child_indent = " ", " ", false
      else
          close_sep = "\n" .. indent
          item_sep = close_sep .. " "
          child_indent = indent .. " "
      end

      local level = depth + 1
      if level > 50 then
          return "Cannot serialise any further: too many nested tables"
      end

      local entries = {}
      local max = is_array(value)
      if max > 0 then
          for i = 1, max do
              entries[#entries + 1] =
                  serialise_value(value[i], child_indent, level)
          end
      elseif max < 0 then
          for k, v in pairs(value) do
              entries[#entries + 1] = ("[%s] = %s"):format(
                  serialise_value(k, child_indent, level),
                  serialise_value(v, child_indent, level))
          end
      end

      -- Entries are joined with "," plus the same separator that follows
      -- the opening brace.
      return "{" .. item_sep .. table.concat(entries, "," .. item_sep)
          .. close_sep .. "}"
  end
  --- Render any Lua value as readable, Lua-like text.
  -- @param value value to render; json.null is rendered as "json.null".
  -- @param indent prefix for multi-line table output. Defaults to "" (pretty
  --   printing); pass false for single-line output.
  -- @param depth current nesting level, defaults to 0.
  -- Strings are %q-quoted, nil/number/boolean use tostring, tables are handed
  -- to serialise_table, and any other type becomes "<typename>" in quotes.
  function serialise_value(value, indent, depth)
      if depth == nil then
          depth = 0
      end
      if indent == nil then
          indent = ""
      end

      if value == json.null then
          return "json.null"
      end

      local kind = type(value)
      if kind == "string" then
          return ("%q"):format(value)
      end
      if kind == "table" then
          return serialise_table(value, indent, depth)
      end
      if kind == "nil" or kind == "number" or kind == "boolean" then
          return tostring(value)
      end
      return "\"<" .. kind .. ">\""
  end
  --- Read an entire file, opened in binary mode, into a string.
  -- @param filename path to read; nil reads all of standard input instead.
  -- @return the complete contents as a string.
  -- Raises an error if the file cannot be opened or its contents read.
  local function file_load(filename)
      local file
      if filename == nil then
          file = io.stdin
      else
          local err
          file, err = io.open(filename, "rb")
          if file == nil then
              error(("Unable to read '%s': %s"):format(filename, err))
          end
      end

      local data, read_err = file:read("*a")
      -- Only close handles opened here; io.stdin must stay usable.
      if filename ~= nil then
          file:close()
      end

      if data == nil then
          -- filename is nil when reading stdin; concatenating it directly
          -- would raise "attempt to concatenate a nil value" instead.
          local msg = "Failed to read " .. (filename or "stdin")
          if read_err then
              msg = msg .. ": " .. tostring(read_err)
          end
          error(msg)
      end
      return data
  end
  --- Write a string to a file, opened in binary mode (truncating it).
  -- @param filename path to write; nil writes to standard output instead.
  -- @param data string to write.
  -- Raises an error if the file cannot be opened, written or closed.
  local function file_save(filename, data)
      local file
      if filename == nil then
          file = io.stdout
      else
          local err
          file, err = io.open(filename, "wb")
          if file == nil then
              error(("Unable to write '%s': %s"):format(filename, err))
          end
      end

      -- file:write/file:close report failure as nil plus a message rather
      -- than raising, so check their results explicitly.
      local write_ok, write_err = file:write(data)
      local close_ok, close_err = true, nil
      -- Only close handles opened here; io.stdout must stay usable.
      if filename ~= nil then
          close_ok, close_err = file:close()
      end

      if not write_ok or not close_ok then
          error(("Unable to write '%s': %s"):format(
              filename or "stdout", tostring(write_err or close_err)))
      end
  end
  --- Deep structural equality for Lua values.
  -- Values of different types are never equal, and two NaNs compare equal
  -- (unlike with ==). Other non-table values are compared with ==.
  -- Tables are equal when they have exactly the same set of keys and
  -- compare_values holds for each key's pair of values. Keys are matched by
  -- ordinary table lookup, so table-valued keys match only by identity.
  local function compare_values(val1, val2)
      local kind = type(val1)
      if kind ~= type(val2) then
          return false
      end

      if kind ~= "table" then
          -- NaN is the only value not equal to itself.
          if kind == "number" and val1 ~= val1 and val2 ~= val2 then
              return true
          end
          return val1 == val2
      end

      -- Keys of val1 that have not yet been matched by a key of val2.
      local pending = {}
      for key in pairs(val1) do
          pending[key] = true
      end
      for key in pairs(val2) do
          if not pending[key] or not compare_values(val1[key], val2[key]) then
              return false
          end
          pending[key] = nil
      end
      -- Any leftover key exists in val1 only.
      return next(pending) == nil
  end
  -- Pass/total counters, incremented by run_test and read by run_test_summary.
  local test_count_pass, test_count_total = 0, 0

  --- Report the test counters accumulated so far.
  -- @return number of tests that passed, total number of tests run.
  local function run_test_summary()
      local passed = test_count_pass
      local total = test_count_total
      return passed, total
  end
  --- Run one test case in protected mode and print a report of its outcome.
  -- @param testname label printed in the report header.
  -- @param func function under test, called via pcall.
  -- @param input array of arguments passed to func.
  -- @param should_work true if func is expected to return normally, false if
  --   it is expected to raise an error.
  -- @param output expected array of the values following pcall's status
  --   (func's return values on success, or the error value on failure),
  --   compared with compare_values.
  -- @return correct (boolean), result (array of values returned/raised).
  -- Updates test_count_pass / test_count_total.
  local function run_test(testname, func, input, should_work, output)
      -- Lua 5.1/LuaJIT provide a global unpack; Lua 5.2+ has table.unpack.
      local unpack = unpack or table.unpack

      local function status_line(name, status, value)
          local statusmap = { [true] = ":success", [false] = ":error" }
          if status ~= nil then
              name = name .. statusmap[status]
          end
          print(("[%s] %s"):format(name, serialise_value(value, false)))
      end

      -- Split pcall's status flag from the remaining values without
      -- table.remove, whose reliance on # is unspecified when func returns
      -- nil values.
      local function capture(ok, ...)
          return ok, { ... }
      end
      local success, result = capture(pcall(func, unpack(input)))

      local correct = false
      if success == should_work and compare_values(result, output) then
          correct = true
          test_count_pass = test_count_pass + 1
      end
      test_count_total = test_count_total + 1

      local teststatus = { [true] = "PASS", [false] = "FAIL" }
      print(("==> Test [%d] %s: %s"):format(test_count_total, testname,
                                            teststatus[correct]))
      status_line("Input", nil, input)
      if not correct then
          status_line("Expected", should_work, output)
      end
      status_line("Received", success, result)
      print()
      return correct, result
  end
  --- Run a list of test entries in order.
  -- Each entry is an array { name, func, input, should_work, output }.
  -- Entries whose 4th element (should_work) is nil are "helpers": func is
  -- called unprotected with the unpacked input and nothing is checked.
  -- All other entries are forwarded to run_test.
  local function run_test_group(tests)
      -- Lua 5.1/LuaJIT provide a global unpack; Lua 5.2+ has table.unpack.
      local unpack = unpack or table.unpack

      local function run_helper(name, func, input)
          if type(name) == "string" and #name > 0 then
              print("==> " .. name)
          end
          -- Not a protected call, these functions should never generate errors.
          func(unpack(input or {}))
          print()
      end

      for _, v in ipairs(tests) do
          -- Explicit bounds keep unpack well-defined when an entry contains
          -- nil fields (# is unspecified for tables with holes).
          if v[4] == nil then
              run_helper(unpack(v, 1, 3))
          else
              run_test(unpack(v, 1, 5))
          end
      end
  end
  --- Run a Lua script in a separate environment.
  -- @param script Lua source code (string).
  -- @param env optional table used as the script's global environment; a
  --   fresh empty table is used when omitted.
  -- @return the environment table, including any globals the script set.
  -- Raises "Invalid syntax. <loader message>" if the script fails to load;
  -- errors raised while running the script propagate unchanged.
  local function run_script(script, env)
      env = env or {}
      local func, load_err
      -- Use setfenv() if it exists, otherwise assume Lua 5.2 load() exists
      if _G.setfenv then
          func, load_err = loadstring(script)
          if func then
              setfenv(func, env)
          end
      else
          func, load_err = load(script, nil, nil, env)
      end
      if func == nil then
          -- Keep the loader's message so the failing line is reported.
          error("Invalid syntax. " .. tostring(load_err))
      end
      func()
      return env
  end
  -- Export functions: the public interface of this helper module.
  return {
      compare_values = compare_values,
      file_load = file_load,
      file_save = file_save,
      run_script = run_script,
      run_test = run_test,
      run_test_group = run_test_group,
      run_test_summary = run_test_summary,
      serialise_value = serialise_value,
  }
  233. -- vi:ai et sw=4 ts=4: