Subclass hook fix (#40)

This commit is contained in:
Azumgi 2019-01-17 13:12:26 +03:00 committed by GitHub
commit 92b12fb5b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -26,6 +26,7 @@ local auto_table_meta = {__index = function(t, k) t[k] = {} return t[k] end }
-- This table will hold all mod-specific data. -- This table will hold all mod-specific data.
local _registry = setmetatable({}, auto_table_meta) local _registry = setmetatable({}, auto_table_meta)
_registry.uids = {}
-- This table will hold all of the hooks, in the format of _hooks[hook_type] -- This table will hold all of the hooks, in the format of _hooks[hook_type]
-- Do the same thing with these tables to allow _hooks[hook_type][orig] without a ton of nil-checks. -- Do the same thing with these tables to allow _hooks[hook_type][orig] without a ton of nil-checks.
@ -43,12 +44,7 @@ local _origs = {}
-- This will tell us if we already have the given function in our registry. -- This will tell us if we already have the given function in our registry.
local function is_orig_hooked(obj, method) local function is_orig_hooked(obj, method)
local orig_registry = _origs if _origs[obj] and _origs[obj][method] then
if obj then
if orig_registry[obj] and orig_registry[obj][method] then
return true
end
elseif not obj and orig_registry[method] then
return true return true
end end
return false return false
@ -57,18 +53,10 @@ end
-- Since we replace the original function, we need to keep its reference around. -- Since we replace the original function, we need to keep its reference around.
-- This will grab the cached reference if we hooked it before, otherwise return the function. -- This will grab the cached reference if we hooked it before, otherwise return the function.
local function get_orig_function(obj, method) local function get_orig_function(obj, method)
if obj then if is_orig_hooked(obj, method) then
if is_orig_hooked(obj, method) then return _origs[obj][method]
return _origs[obj][method]
else
return obj[method]
end
else else
if is_orig_hooked(obj, method) then return obj[method]
return _origs[method]
else
return rawget(_G, method)
end
end end
end end
@ -105,16 +93,16 @@ end
-- For any given original function, return the newest entry of the hook_chain. -- For any given original function, return the newest entry of the hook_chain.
-- Since all hooks of the chain contain the call to the previous one, we don't need to do any manual loops. -- Since all hooks of the chain contain the call to the previous one, we don't need to do any manual loops.
-- This continues until the end of the chain, where the original function is called. -- This continues until the end of the chain, where the original function is called.
local function get_hook_chain(orig) local function get_hook_chain(orig, unique_id)
local hook_registry = _hooks local hook_registry = _hooks
local hooks = hook_registry[HOOK_TYPE_NORMAL][orig] local hooks = hook_registry[HOOK_TYPE_NORMAL][unique_id]
if hooks and #hooks > 0 then if hooks and #hooks > 0 then
return hooks[#hooks] return hooks[#hooks]
end end
-- We can't simply return orig here, or it would cause origins to depend on load order. -- We can't simply return orig here, or it would cause origins to depend on load order.
return function(...) return function(...)
if hook_registry[HOOK_TYPE_ORIGIN][orig] then if hook_registry[HOOK_TYPE_ORIGIN][unique_id] then
return hook_registry[HOOK_TYPE_ORIGIN][orig](...) return hook_registry[HOOK_TYPE_ORIGIN][unique_id](...)
else else
return orig(...) return orig(...)
end end
@ -122,11 +110,12 @@ local function get_hook_chain(orig)
end end
-- Returns a table containing hook data inside of it. -- Returns a table containing hook data inside of it.
local function create_hook_data(mod, obj, handler, hook_type) local function create_hook_data(mod, obj, orig, handler, hook_type)
return { return {
active = mod:is_enabled(), active = mod:is_enabled(),
hook_type = hook_type, hook_type = hook_type,
handler = handler, handler = handler,
orig = orig,
obj = obj, obj = obj,
} }
end end
@ -135,12 +124,13 @@ end
-- Note: If a previous hook is removed from the table, these functions wouldn't be updated -- Note: If a previous hook is removed from the table, these functions wouldn't be updated
-- This would break the chain, solution is to not remove the hooks, simply make them inactive -- This would break the chain, solution is to not remove the hooks, simply make them inactive
-- Make sure inactive hooks that rely on the chain still call the next function seamlessly. -- Make sure inactive hooks that rely on the chain still call the next function seamlessly.
local function create_specialized_hook(mod, orig, hook_type) local function create_specialized_hook(mod, unique_id, hook_type)
local func local func
local hook_data = _registry[mod][orig] local hook_data = _registry[mod][unique_id]
local orig = hook_data.orig
-- Determine the previous function in the hook stack -- Determine the previous function in the hook stack
local previous_hook = get_hook_chain(orig) local previous_hook = get_hook_chain(orig, unique_id)
if hook_type == HOOK_TYPE_NORMAL then if hook_type == HOOK_TYPE_NORMAL then
func = function(...) func = function(...)
@ -172,52 +162,53 @@ end
-- The hook system makes internal functions that replace the original function and handles all the hooks. -- The hook system makes internal functions that replace the original function and handles all the hooks.
-- Once all hooks that are part of the chain have been executed, we can go over the safe hooks. -- Once all hooks that are part of the chain have been executed, we can go over the safe hooks.
-- Note: We need to keep the return values in mind in case another function depends on them. -- Note: We need to keep the return values in mind in case another function depends on them.
-- At this point in the execution, Obj has already been type-checked and cannot be a string anymore. -- We then use this internal hook as a unique identifier for the registry.
local function create_internal_hook(orig, obj, method) local function create_internal_hook(orig, obj, method)
local fn = function(...) local fn -- Needs to be over two line to be usable within the closure.
local hook_chain = get_hook_chain(orig) fn = function(...)
local hook_chain = get_hook_chain(orig, fn)
local num_values, values = get_return_values( hook_chain(...) ) local num_values, values = get_return_values( hook_chain(...) )
local safe_hooks = _hooks[HOOK_TYPE_SAFE][orig] local safe_hooks = _hooks[HOOK_TYPE_SAFE][fn]
if safe_hooks and #safe_hooks > 0 then if safe_hooks and #safe_hooks > 0 then
for i = 1, #safe_hooks do safe_hooks[i](...) end for i = 1, #safe_hooks do safe_hooks[i](...) end
end end
return unpack(values, 1, num_values) return unpack(values, 1, num_values)
end end
if obj then if not _origs[obj] then _origs[obj] = {} end
if not _origs[obj] then _origs[obj] = {} end _origs[obj][method] = orig
_origs[obj][method] = orig obj[method] = fn
obj[method] = fn return fn
else
_origs[method] = orig
_G[method] = fn
end
end end
-- This function handles the handling the hook data and adding them to the registry. -- This function handles the handling the hook data and adding them to the registry.
-- Origin Hooks have to be unique by nature so we have to make sure we don't allow multiple mods to do it. -- Origin Hooks have to be unique by nature so we have to make sure we don't allow multiple mods to do it.
local function create_hook(mod, orig, obj, method, handler, func_name, hook_type) local function create_hook(mod, orig, obj, method, handler, func_name, hook_type)
mod:info("(%s): Hooking '%s' from [%s] (Origin: %s)", func_name, method, obj or "_G", orig) local unique_id
if not is_orig_hooked(obj, method) then if not is_orig_hooked(obj, method) then
create_internal_hook(orig, obj, method) unique_id = create_internal_hook(orig, obj, method)
_registry.uids[unique_id] = orig
else
unique_id = obj[method]
end end
mod:info("(%s): Hooking '%s' from [%s] (Origin: %s) (UniqueID: %s)", func_name, method, obj, orig, unique_id)
-- Check to make sure it wasn't hooked before -- Check to make sure this mod hasn't hooked it before
local hook_data = _registry[mod][orig] local hook_data = _registry[mod][unique_id]
if not hook_data then if not hook_data then
_registry[mod][orig] = create_hook_data(mod, obj, handler, hook_type) _registry[mod][unique_id] = create_hook_data(mod, obj, orig, handler, hook_type)
local hook_registry = _hooks[hook_type] local hook_registry = _hooks[hook_type]
if hook_type == HOOK_TYPE_ORIGIN then if hook_type == HOOK_TYPE_ORIGIN then
if hook_registry[orig] then if hook_registry[unique_id] then
mod:error("(%s): Attempting to hook origin of already hooked function %s", func_name, method) mod:error("(%s): Attempting to hook origin of already hooked function %s", func_name, method)
else else
hook_registry[orig] = create_specialized_hook(mod, orig, hook_type) hook_registry[unique_id] = create_specialized_hook(mod, unique_id, hook_type)
end end
else else
table.insert(hook_registry[orig], create_specialized_hook(mod, orig, hook_type) ) table.insert(hook_registry[unique_id], create_specialized_hook(mod, unique_id, hook_type) )
end end
else else
-- If hook_data already exists and it's the same hook_type, we can safely change the hook handler. -- If hook_data already exists and it's the same hook_type, we can safely change the hook handler.
@ -259,9 +250,9 @@ local function generic_hook(mod, obj, method, handler, func_name)
return return
end end
-- Shift the arguments if needed -- Shift the arguments if no obj is provided. obj becomes the global table.
if not handler then if not handler then
obj, method, handler = nil, obj, method obj, method, handler = _G, obj, method
if not method then if not method then
mod:error("(%s): trying to create hook without giving a method name.", func_name) mod:error("(%s): trying to create hook without giving a method name.", func_name)
return return
@ -286,14 +277,11 @@ local function generic_hook(mod, obj, method, handler, func_name)
return return
end end
end end
-- obj is a either nil or a table reference at this point, it cannot be a string anymore. -- obj should always be a table reference at this point --
-- Quick check to make sure the target exists -- Quick check to make sure the target exists
if obj and not obj[method] then if not obj[method] then
mod:error("(%s): trying to hook method that doesn't exist: [%s.%s]", func_name, obj, method) mod:error("(%s): trying to hook function or method that doesn't exist: [%s.%s]", func_name, obj, method)
return
elseif not obj and not rawget(_G, method) then
mod:error("(%s): trying to hook function that doesn't exist: [%s]", func_name, method)
return return
end end
@ -303,6 +291,11 @@ local function generic_hook(mod, obj, method, handler, func_name)
return return
end end
-- Edge Case: If someone hooks a copy of a function after its been hooked, point it back in the right direction
if _registry.uids[orig] then
orig = _registry.uids[orig]
end
return create_hook(mod, orig, obj, method, handler, func_name, hook_type) return create_hook(mod, orig, obj, method, handler, func_name, hook_type)
end end
@ -316,7 +309,7 @@ local function generic_hook_toggle(mod, obj, method, enabled_state)
-- Shift the arguments if needed -- Shift the arguments if needed
if not method then if not method then
obj, method = nil, obj obj, method = _G, obj
if not method then if not method then
mod:error("(%s): trying to toggle hook without giving a method name.", func_name) mod:error("(%s): trying to toggle hook without giving a method name.", func_name)
return return
@ -338,10 +331,8 @@ local function generic_hook_toggle(mod, obj, method, enabled_state)
end end
end end
local orig = get_orig_function(obj, method) if _registry[mod][obj[method]] then
_registry[mod][obj[method]].active = enabled_state
if _registry[mod][orig] then
_registry[mod][orig].active = enabled_state
else else
-- This has the potential for mod-breaking behavior, but not guaranteed -- This has the potential for mod-breaking behavior, but not guaranteed
mod:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name, method) mod:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name, method)