-- hotfix.lua (2.4 KB)

  1. local si = require "snax.interface"
  --- Identify the `_ENV` upvalue of function `f`.
  -- Walks f's upvalues in index order and, at the first one named "_ENV",
  -- returns the identifier `debug.upvalueid` gives it. This lets callers
  -- check whether two closures share the same environment upvalue.
  -- Returns no value at all when f has no upvalue named "_ENV".
  local function envid(f)
    local idx = 0
    repeat
      idx = idx + 1
      local name = debug.getupvalue(f, idx)
      if name == "_ENV" then
        return debug.upvalueid(f, idx)
      end
    until name == nil
  end
  --- Record every upvalue reachable from `f` into `uv`, keyed by upvalue name.
  -- Each new entry is { func = <closure owning the upvalue>,
  -- index = <upvalue index in that closure>, id = <debug.upvalueid> }.
  -- A function-valued upvalue is walked recursively only when its `_ENV`
  -- identity (see envid) equals `env`. Because a name is walked only the first
  -- time it is recorded, self-referencing local functions do not loop.
  -- Raises "ambiguity local value <name>" when one name maps to two distinct
  -- upvalues.
  -- NOTE: keep a global (here `debug`) as the first free name referenced in
  -- this function so `_ENV` remains its upvalue #1; the `_ENV` fallback in
  -- collect_all_uv points at this function's upvalues.
  local function collect_uv(f , uv, env)
    local idx = 1
    local name, value = debug.getupvalue(f, idx)
    while name ~= nil do
      local id = debug.upvalueid(f, idx)
      local seen = uv[name]
      if not seen then
        uv[name] = { func = f, index = idx, id = id }
        if type(value) == "function" and envid(value) == env then
          collect_uv(value, uv, env)
        end
      else
        assert(seen.id == id, string.format("ambiguity local value %s", name))
      end
      idx = idx + 1
      name, value = debug.getupvalue(f, idx)
    end
  end
  --- Build the name -> upvalue map for every currently registered function.
  -- `funcs` holds descriptors whose 4th slot is the implementation; entries
  -- whose 4th slot is nil/false are skipped. Each implementation is walked by
  -- collect_uv using its own `_ENV` identity as the recursion filter.
  -- If no walked function captured `_ENV`, the map falls back to the `_ENV`
  -- upvalue of collect_uv (a function of this module), so a patch's `_ENV`
  -- placeholder still has something to be joined to.
  -- @param funcs table of descriptors { _, group, name, impl }
  -- @return table mapping upvalue name -> { func = ..., index = ... [, id = ...] }
  local function collect_all_uv(funcs)
    local global = {}
    for _, desc in pairs(funcs) do
      local impl = desc[4]
      if impl then
        collect_uv(impl, global, envid(impl))
      end
    end
    if not global["_ENV"] then
      -- Locate `_ENV` by name rather than assuming it is upvalue #1: upvalue
      -- order depends on which free name collect_uv happens to reference first.
      local idx = 1
      while true do
        local uv_name = debug.getupvalue(collect_uv, idx)
        if uv_name == nil then
          break
        elseif uv_name == "_ENV" then
          global["_ENV"] = { func = collect_uv, index = idx }
          break
        end
        idx = idx + 1
      end
    end
    return global
  end
  --- Build a loader callback for a patch's source code.
  -- The returned function ignores its `path` and `name` arguments and
  -- compiles `source` (text or precompiled binary, mode "bt") under the chunk
  -- name "patch" with `G` as the chunk's environment. On failure it returns
  -- whatever `load` returns (nil plus an error message).
  -- @param source string  patch source code
  -- @return function(path, name, G)
  local function loader(source)
    local function load_patch(path, name, G)
      return load(source, "=patch", "bt", G)
    end
    return load_patch
  end
  --- Look up a function descriptor by (group, name).
  -- Descriptors are sequences whose 2nd and 3rd elements are the group and
  -- the name. Returns the first matching descriptor in `pairs` order, or
  -- nothing when no descriptor matches.
  -- @param funcs table of descriptors
  -- @param group value compared against each descriptor's 2nd element
  -- @param name  value compared against each descriptor's 3rd element
  local function find_func(funcs, group , name)
    for _, entry in pairs(funcs) do
      local entry_group, entry_name = select(2, table.unpack(entry))
      if group == entry_group and name == entry_name then
        return entry
      end
    end
  end
  -- Shallow snapshot of this module's global environment, taken once when the
  -- module is loaded. inject passes it to si() (presumably as the patch
  -- chunk's environment — confirm against snax.interface). _patch treats any
  -- upvalue holding this exact table as a placeholder to rebind.
  local dummy_env = {}
  for key, val in pairs(_ENV) do
    dummy_env[key] = val
  end
  --- Rebind a patch function's unresolved upvalues to the running service's.
  -- For each upvalue of `f`:
  --   * value nil or the dummy_env placeholder: if `global` has an entry of the
  --     same name, join the upvalue to that entry's (func, index) upvalue;
  --   * value is a function: recurse into it;
  --   * anything else: leave untouched.
  -- `visited` records functions already processed. Without it, a recursive
  -- local function (whose own upvalue is itself) or mutually recursive locals
  -- would recurse until the Lua stack overflows.
  -- @param global  map produced by collect_all_uv: name -> { func, index }
  -- @param f       function to patch in place
  -- @param visited optional set of already-processed functions (internal)
  local function _patch(global, f, visited)
    visited = visited or {}
    if visited[f] then
      return
    end
    visited[f] = true
    local i = 1
    while true do
      local name, value = debug.getupvalue(f, i)
      if name == nil then
        break
      elseif value == nil or value == dummy_env then
        local old_uv = global[name]
        if old_uv then
          debug.upvaluejoin(f, i, old_uv.func, old_uv.index)
        end
      elseif type(value) == "function" then
        _patch(global, value, visited)
      end
      i = i + 1
    end
  end
  --- Install patched implementation `f` for the descriptor (group, name).
  -- Raises "Patch mismatch <group>.<name>" (via assert) when no descriptor in
  -- `funcs` matches. Otherwise rebinds f's upvalues against `global` and then
  -- stores f in the descriptor's 4th slot.
  local function patch_func(funcs, global, group, name, f)
    local target = find_func(funcs, group, name)
    local mismatch = string.format("Patch mismatch %s.%s", group, name)
    assert(target, mismatch)
    _patch(global, f)
    target[4] = f
  end
  --- Apply a hot-fix `source` to the registered functions `funcs`.
  -- Loads the patch through si("patch", dummy_env, loader(source)), which
  -- presumably yields descriptors shaped { _, group, name, f } — confirm
  -- against snax.interface. It then snapshots the current upvalues of `funcs`
  -- and installs every patch descriptor that carries an implementation.
  -- If the patch defines system.hotfix, that function is tail-called with the
  -- extra arguments and its results are returned; otherwise nothing is
  -- returned.
  local function inject(funcs, source, ...)
    local patch = si("patch", dummy_env, loader(source))
    local global = collect_all_uv(funcs)
    for _, entry in pairs(patch) do
      local _, group, name, impl = table.unpack(entry)
      if impl then
        patch_func(funcs, global, group, name, impl)
      end
    end
    local hotfix_entry = find_func(patch, "system", "hotfix")
    if not (hotfix_entry and hotfix_entry[4]) then
      return
    end
    return hotfix_entry[4](...)
  end
  --- Module entry point: apply hot-fix `source` to `funcs` in protected mode.
  -- Returns exactly what pcall(inject, ...) returns: true followed by inject's
  -- results on success, or false and the error value on failure.
  local function hotfix(funcs, source, ...)
    return pcall(inject, funcs, source, ...)
  end
  return hotfix