123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- local si = require "snax.interface"
--- Return the identity of `f`'s `_ENV` upvalue.
-- Scans the upvalues of `f` in slot order and stops at the first one named
-- "_ENV". collect_uv compares these ids to decide whether two closures see
-- the same environment upvalue.
-- @param f function to inspect
-- @return debug.upvalueid of the `_ENV` slot, or nothing when `f` has none
local function envid(f)
	local idx = 0
	repeat
		idx = idx + 1
		local uvname = debug.getupvalue(f, idx)
		if uvname == "_ENV" then
			return debug.upvalueid(f, idx)
		end
	until uvname == nil
end
--- Record every upvalue reachable from `f` into `uv`, keyed by upvalue name.
-- Each new entry is { func = <closure owning the slot>, index = <slot>,
-- id = debug.upvalueid(func, index) }. A function-valued upvalue is scanned
-- recursively when it is a Lua closure whose envid() equals `env`, so local
-- helpers reachable from `f` contribute their upvalues as well.
-- C functions are not descended into: all of their upvalues are named ""
-- and debug.upvaluejoin only accepts Lua closures, so recording them was
-- useless and made two reachable C closures fail the ambiguity assertion.
-- @param f closure to scan
-- @param uv accumulator table, mutated in place
-- @param env envid() of the root function (may be nil)
-- @raise "ambiguity local value <name>" when two distinct upvalues share a name
-- NOTE: keep a global access (`debug`) as the first free name referenced in
-- this body so that `_ENV` stays upvalue #1 of collect_uv; this module may
-- use that slot as a fallback `_ENV` join target.
local function collect_uv(f , uv, env)
	local i = 1
	while true do
		local name, value = debug.getupvalue(f, i)
		if name == nil then
			break
		end
		local id = debug.upvalueid(f, i)
		if uv[name] then
			-- The same name must always denote the same shared upvalue.
			assert(uv[name].id == id, string.format("ambiguity local value %s", name))
		else
			uv[name] = { func = f, index = i, id = id }
			if type(value) == "function"
				and debug.getinfo(value, "S").what ~= "C"
				and envid(value) == env then
				collect_uv(value, uv, env)
			end
		end
		i = i + 1
	end
end
-- Closure whose only upvalue (index 1) is this module's `_ENV`. It gives
-- collect_all_uv a stable join target for `_ENV` that does not depend on
-- the order in which some other function happens to reference its upvalues.
local function env_anchor()
	return _ENV
end

--- Build the name -> upvalue table used to rewire patch functions.
-- Collects the upvalues reachable from every implemented entry of `funcs`
-- (entries whose 4th element is set). If none of them exposes an `_ENV`
-- upvalue, this module's own `_ENV` upvalue is registered as the fallback.
-- @param funcs snax function list; entries look like { id, group, name, func }
-- @return table mapping upvalue name -> { func = closure, index = n, id = upvalue id }
local function collect_all_uv(funcs)
	local global = {}
	for _, desc in pairs(funcs) do
		local f = desc[4]
		if f then
			collect_uv(f, global, envid(f))
		end
	end
	if not global._ENV then
		global._ENV = { func = env_anchor, index = 1, id = debug.upvalueid(env_anchor, 1) }
	end
	return global
end
--- Build a snax.interface loader that compiles `source` instead of a file.
-- The path and service name the interface passes in are ignored; the chunk
-- is always named "=patch" and loaded with mode "bt" into `env`.
-- @param source patch code
-- @return function(path, name, env) returning the compiled chunk, or nil
--   plus an error message
local function loader(source)
	local function compile(_path, _name, env)
		return load(source, "=patch", "bt", env)
	end
	return compile
end
--- Look up the descriptor registered as `group`.`name` in a snax function list.
-- Descriptors are unpacked positionally; the 2nd element is the group and
-- the 3rd the name (the 1st is not used here).
-- @return the matching descriptor table, or nothing when absent
local function find_func(funcs, group , name)
	for _, entry in pairs(funcs) do
		local _, entry_group, entry_name = table.unpack(entry)
		local matches = (entry_group == group) and (entry_name == name)
		if matches then
			return entry
		end
	end
end
-- Environment table handed to snax.interface when a patch is loaded: a
-- shallow copy of this module's globals, so code run while loading the
-- patch sees the same global names as this module. _patch recognises
-- upvalues holding this exact table and joins them to the same-named
-- upvalue collected from the service.
-- NOTE(review): the table is built once when the module loads and is reused
-- by every hotfix; anything snax.interface stores into it would persist
-- across hotfixes — confirm that is intended.
local dummy_env = {}
for key, val in pairs(_ENV) do
	dummy_env[key] = val
end
--- Rewire patch function `f` onto the running service's upvalues.
-- Walks the upvalues of `f`. A slot whose value is nil or `dummy_env` (the
-- table the patch was loaded with) is joined to the same-named entry of
-- `global`, when one exists. A slot holding a Lua function is processed
-- recursively so nested patch helpers are rewired too.
-- @param global name -> { func = closure, index = n } map (see collect_all_uv)
-- @param f patch closure, modified in place via debug.upvaluejoin
-- @param visited optional set of closures already processed. It stops the
--   walk from recursing forever through self- or mutually-recursive local
--   functions (`local function g() ... g() ... end` holds itself as an
--   upvalue). C functions are never entered: upvaluejoin only accepts Lua
--   closures.
local function _patch(global, f, visited)
	visited = visited or {}
	if visited[f] then
		return
	end
	visited[f] = true

	local i = 1
	while true do
		local name, value = debug.getupvalue(f, i)
		if name == nil then
			break
		elseif value == nil or value == dummy_env then
			-- Placeholder slot: share the service's upvalue of the same name.
			local old_uv = global[name]
			if old_uv then
				debug.upvaluejoin(f, i, old_uv.func, old_uv.index)
			end
		elseif type(value) == "function" and debug.getinfo(value, "S").what ~= "C" then
			_patch(global, value, visited)
		end
		i = i + 1
	end
end
--- Install patch function `f` as the implementation of `group`.`name`.
-- The descriptor must already exist in `funcs`. `f`'s upvalues are rewired
-- with _patch before it is stored in the descriptor's 4th slot.
-- @raise "Patch mismatch <group>.<name>" when no such descriptor exists
local function patch_func(funcs, global, group, name, f)
	local target = find_func(funcs, group, name)
	assert(target, string.format("Patch mismatch %s.%s", group, name))
	_patch(global, f)
	target[4] = f
end
--- Load `source` as a snax patch and splice its functions into `funcs`.
-- The patch is loaded through snax.interface with `dummy_env` as its
-- environment. Upvalues of the service are collected, then every function
-- the patch implements replaces the matching descriptor in `funcs`. If the
-- patch defines `system.hotfix`, it is called with the extra arguments.
-- @param funcs the running service's function list
-- @param source patch code
-- @param ... arguments forwarded to the patch's hotfix function
-- @return the hotfix function's results, or nothing when it is absent
local function inject(funcs, source, ...)
	local patch = si("patch", dummy_env, loader(source))
	local service_uv = collect_all_uv(funcs)

	for _, entry in pairs(patch) do
		local _, group, name, impl = table.unpack(entry)
		if impl then
			patch_func(funcs, service_uv, group, name, impl)
		end
	end

	local hook = find_func(patch, "system", "hotfix")
	local hook_fn = hook and hook[4]
	if hook_fn then
		return hook_fn(...)
	end
end
--- Module entry point: apply a hotfix in protected mode.
-- @param funcs the running service's function list
-- @param source patch code
-- @param ... forwarded to the patch's `system.hotfix` function, if any
-- @return true followed by the hotfix function's results on success, or
--   false plus the error value when loading or patching fails
local function hotfix(funcs, source, ...)
	return pcall(inject, funcs, source, ...)
end

return hotfix
|