Subclass hook fix (#40)
This commit is contained in:
commit
92b12fb5b7
1 changed files with 50 additions and 59 deletions
|
@ -26,6 +26,7 @@ local auto_table_meta = {__index = function(t, k) t[k] = {} return t[k] end }
|
||||||
|
|
||||||
-- This table will hold all mod-specific data.
|
-- This table will hold all mod-specific data.
|
||||||
local _registry = setmetatable({}, auto_table_meta)
|
local _registry = setmetatable({}, auto_table_meta)
|
||||||
|
_registry.uids = {}
|
||||||
|
|
||||||
-- This table will hold all of the hooks, in the format of _hooks[hook_type]
|
-- This table will hold all of the hooks, in the format of _hooks[hook_type]
|
||||||
-- Do the same thing with these tables to allow _hooks[hook_type][orig] without a ton of nil-checks.
|
-- Do the same thing with these tables to allow _hooks[hook_type][orig] without a ton of nil-checks.
|
||||||
|
@ -43,12 +44,7 @@ local _origs = {}
|
||||||
|
|
||||||
-- This will tell us if we already have the given function in our registry.
|
-- This will tell us if we already have the given function in our registry.
|
||||||
local function is_orig_hooked(obj, method)
|
local function is_orig_hooked(obj, method)
|
||||||
local orig_registry = _origs
|
if _origs[obj] and _origs[obj][method] then
|
||||||
if obj then
|
|
||||||
if orig_registry[obj] and orig_registry[obj][method] then
|
|
||||||
return true
|
|
||||||
end
|
|
||||||
elseif not obj and orig_registry[method] then
|
|
||||||
return true
|
return true
|
||||||
end
|
end
|
||||||
return false
|
return false
|
||||||
|
@ -57,18 +53,10 @@ end
|
||||||
-- Since we replace the original function, we need to keep its reference around.
|
-- Since we replace the original function, we need to keep its reference around.
|
||||||
-- This will grab the cached reference if we hooked it before, otherwise return the function.
|
-- This will grab the cached reference if we hooked it before, otherwise return the function.
|
||||||
local function get_orig_function(obj, method)
|
local function get_orig_function(obj, method)
|
||||||
if obj then
|
if is_orig_hooked(obj, method) then
|
||||||
if is_orig_hooked(obj, method) then
|
return _origs[obj][method]
|
||||||
return _origs[obj][method]
|
|
||||||
else
|
|
||||||
return obj[method]
|
|
||||||
end
|
|
||||||
else
|
else
|
||||||
if is_orig_hooked(obj, method) then
|
return obj[method]
|
||||||
return _origs[method]
|
|
||||||
else
|
|
||||||
return rawget(_G, method)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -105,16 +93,16 @@ end
|
||||||
-- For any given original function, return the newest entry of the hook_chain.
|
-- For any given original function, return the newest entry of the hook_chain.
|
||||||
-- Since all hooks of the chain contain the call to the previous one, we don't need to do any manual loops.
|
-- Since all hooks of the chain contain the call to the previous one, we don't need to do any manual loops.
|
||||||
-- This continues until the end of the chain, where the original function is called.
|
-- This continues until the end of the chain, where the original function is called.
|
||||||
local function get_hook_chain(orig)
|
local function get_hook_chain(orig, unique_id)
|
||||||
local hook_registry = _hooks
|
local hook_registry = _hooks
|
||||||
local hooks = hook_registry[HOOK_TYPE_NORMAL][orig]
|
local hooks = hook_registry[HOOK_TYPE_NORMAL][unique_id]
|
||||||
if hooks and #hooks > 0 then
|
if hooks and #hooks > 0 then
|
||||||
return hooks[#hooks]
|
return hooks[#hooks]
|
||||||
end
|
end
|
||||||
-- We can't simply return orig here, or it would cause origins to depend on load order.
|
-- We can't simply return orig here, or it would cause origins to depend on load order.
|
||||||
return function(...)
|
return function(...)
|
||||||
if hook_registry[HOOK_TYPE_ORIGIN][orig] then
|
if hook_registry[HOOK_TYPE_ORIGIN][unique_id] then
|
||||||
return hook_registry[HOOK_TYPE_ORIGIN][orig](...)
|
return hook_registry[HOOK_TYPE_ORIGIN][unique_id](...)
|
||||||
else
|
else
|
||||||
return orig(...)
|
return orig(...)
|
||||||
end
|
end
|
||||||
|
@ -122,11 +110,12 @@ local function get_hook_chain(orig)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Returns a table containing hook data inside of it.
|
-- Returns a table containing hook data inside of it.
|
||||||
local function create_hook_data(mod, obj, handler, hook_type)
|
local function create_hook_data(mod, obj, orig, handler, hook_type)
|
||||||
return {
|
return {
|
||||||
active = mod:is_enabled(),
|
active = mod:is_enabled(),
|
||||||
hook_type = hook_type,
|
hook_type = hook_type,
|
||||||
handler = handler,
|
handler = handler,
|
||||||
|
orig = orig,
|
||||||
obj = obj,
|
obj = obj,
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
@ -135,12 +124,13 @@ end
|
||||||
-- Note: If a previous hook is removed from the table, these functions wouldn't be updated
|
-- Note: If a previous hook is removed from the table, these functions wouldn't be updated
|
||||||
-- This would break the chain, solution is to not remove the hooks, simply make them inactive
|
-- This would break the chain, solution is to not remove the hooks, simply make them inactive
|
||||||
-- Make sure inactive hooks that rely on the chain still call the next function seamlessly.
|
-- Make sure inactive hooks that rely on the chain still call the next function seamlessly.
|
||||||
local function create_specialized_hook(mod, orig, hook_type)
|
local function create_specialized_hook(mod, unique_id, hook_type)
|
||||||
local func
|
local func
|
||||||
local hook_data = _registry[mod][orig]
|
local hook_data = _registry[mod][unique_id]
|
||||||
|
local orig = hook_data.orig
|
||||||
|
|
||||||
-- Determine the previous function in the hook stack
|
-- Determine the previous function in the hook stack
|
||||||
local previous_hook = get_hook_chain(orig)
|
local previous_hook = get_hook_chain(orig, unique_id)
|
||||||
|
|
||||||
if hook_type == HOOK_TYPE_NORMAL then
|
if hook_type == HOOK_TYPE_NORMAL then
|
||||||
func = function(...)
|
func = function(...)
|
||||||
|
@ -172,52 +162,53 @@ end
|
||||||
-- The hook system makes internal functions that replace the original function and handles all the hooks.
|
-- The hook system makes internal functions that replace the original function and handles all the hooks.
|
||||||
-- Once all hooks that are part of the chain have been executed, we can go over the safe hooks.
|
-- Once all hooks that are part of the chain have been executed, we can go over the safe hooks.
|
||||||
-- Note: We need to keep the return values in mind in case another function depends on them.
|
-- Note: We need to keep the return values in mind in case another function depends on them.
|
||||||
-- At this point in the execution, Obj has already been type-checked and cannot be a string anymore.
|
-- We then use this internal hook as a unique identifier for the registry.
|
||||||
local function create_internal_hook(orig, obj, method)
|
local function create_internal_hook(orig, obj, method)
|
||||||
local fn = function(...)
|
local fn -- Needs to be over two line to be usable within the closure.
|
||||||
local hook_chain = get_hook_chain(orig)
|
fn = function(...)
|
||||||
|
local hook_chain = get_hook_chain(orig, fn)
|
||||||
local num_values, values = get_return_values( hook_chain(...) )
|
local num_values, values = get_return_values( hook_chain(...) )
|
||||||
|
|
||||||
local safe_hooks = _hooks[HOOK_TYPE_SAFE][orig]
|
local safe_hooks = _hooks[HOOK_TYPE_SAFE][fn]
|
||||||
if safe_hooks and #safe_hooks > 0 then
|
if safe_hooks and #safe_hooks > 0 then
|
||||||
for i = 1, #safe_hooks do safe_hooks[i](...) end
|
for i = 1, #safe_hooks do safe_hooks[i](...) end
|
||||||
end
|
end
|
||||||
return unpack(values, 1, num_values)
|
return unpack(values, 1, num_values)
|
||||||
end
|
end
|
||||||
|
|
||||||
if obj then
|
if not _origs[obj] then _origs[obj] = {} end
|
||||||
if not _origs[obj] then _origs[obj] = {} end
|
_origs[obj][method] = orig
|
||||||
_origs[obj][method] = orig
|
obj[method] = fn
|
||||||
obj[method] = fn
|
return fn
|
||||||
else
|
|
||||||
_origs[method] = orig
|
|
||||||
_G[method] = fn
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
-- This function handles the handling the hook data and adding them to the registry.
|
-- This function handles the handling the hook data and adding them to the registry.
|
||||||
-- Origin Hooks have to be unique by nature so we have to make sure we don't allow multiple mods to do it.
|
-- Origin Hooks have to be unique by nature so we have to make sure we don't allow multiple mods to do it.
|
||||||
local function create_hook(mod, orig, obj, method, handler, func_name, hook_type)
|
local function create_hook(mod, orig, obj, method, handler, func_name, hook_type)
|
||||||
mod:info("(%s): Hooking '%s' from [%s] (Origin: %s)", func_name, method, obj or "_G", orig)
|
local unique_id
|
||||||
|
|
||||||
if not is_orig_hooked(obj, method) then
|
if not is_orig_hooked(obj, method) then
|
||||||
create_internal_hook(orig, obj, method)
|
unique_id = create_internal_hook(orig, obj, method)
|
||||||
|
_registry.uids[unique_id] = orig
|
||||||
|
else
|
||||||
|
unique_id = obj[method]
|
||||||
end
|
end
|
||||||
|
mod:info("(%s): Hooking '%s' from [%s] (Origin: %s) (UniqueID: %s)", func_name, method, obj, orig, unique_id)
|
||||||
|
|
||||||
-- Check to make sure it wasn't hooked before
|
-- Check to make sure this mod hasn't hooked it before
|
||||||
local hook_data = _registry[mod][orig]
|
local hook_data = _registry[mod][unique_id]
|
||||||
if not hook_data then
|
if not hook_data then
|
||||||
_registry[mod][orig] = create_hook_data(mod, obj, handler, hook_type)
|
_registry[mod][unique_id] = create_hook_data(mod, obj, orig, handler, hook_type)
|
||||||
|
|
||||||
local hook_registry = _hooks[hook_type]
|
local hook_registry = _hooks[hook_type]
|
||||||
if hook_type == HOOK_TYPE_ORIGIN then
|
if hook_type == HOOK_TYPE_ORIGIN then
|
||||||
if hook_registry[orig] then
|
if hook_registry[unique_id] then
|
||||||
mod:error("(%s): Attempting to hook origin of already hooked function %s", func_name, method)
|
mod:error("(%s): Attempting to hook origin of already hooked function %s", func_name, method)
|
||||||
else
|
else
|
||||||
hook_registry[orig] = create_specialized_hook(mod, orig, hook_type)
|
hook_registry[unique_id] = create_specialized_hook(mod, unique_id, hook_type)
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
table.insert(hook_registry[orig], create_specialized_hook(mod, orig, hook_type) )
|
table.insert(hook_registry[unique_id], create_specialized_hook(mod, unique_id, hook_type) )
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
-- If hook_data already exists and it's the same hook_type, we can safely change the hook handler.
|
-- If hook_data already exists and it's the same hook_type, we can safely change the hook handler.
|
||||||
|
@ -259,9 +250,9 @@ local function generic_hook(mod, obj, method, handler, func_name)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Shift the arguments if needed
|
-- Shift the arguments if no obj is provided. obj becomes the global table.
|
||||||
if not handler then
|
if not handler then
|
||||||
obj, method, handler = nil, obj, method
|
obj, method, handler = _G, obj, method
|
||||||
if not method then
|
if not method then
|
||||||
mod:error("(%s): trying to create hook without giving a method name.", func_name)
|
mod:error("(%s): trying to create hook without giving a method name.", func_name)
|
||||||
return
|
return
|
||||||
|
@ -286,14 +277,11 @@ local function generic_hook(mod, obj, method, handler, func_name)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
-- obj is a either nil or a table reference at this point, it cannot be a string anymore.
|
-- obj should always be a table reference at this point --
|
||||||
|
|
||||||
-- Quick check to make sure the target exists
|
-- Quick check to make sure the target exists
|
||||||
if obj and not obj[method] then
|
if not obj[method] then
|
||||||
mod:error("(%s): trying to hook method that doesn't exist: [%s.%s]", func_name, obj, method)
|
mod:error("(%s): trying to hook function or method that doesn't exist: [%s.%s]", func_name, obj, method)
|
||||||
return
|
|
||||||
elseif not obj and not rawget(_G, method) then
|
|
||||||
mod:error("(%s): trying to hook function that doesn't exist: [%s]", func_name, method)
|
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -303,6 +291,11 @@ local function generic_hook(mod, obj, method, handler, func_name)
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Edge Case: If someone hooks a copy of a function after its been hooked, point it back in the right direction
|
||||||
|
if _registry.uids[orig] then
|
||||||
|
orig = _registry.uids[orig]
|
||||||
|
end
|
||||||
|
|
||||||
return create_hook(mod, orig, obj, method, handler, func_name, hook_type)
|
return create_hook(mod, orig, obj, method, handler, func_name, hook_type)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -316,7 +309,7 @@ local function generic_hook_toggle(mod, obj, method, enabled_state)
|
||||||
|
|
||||||
-- Shift the arguments if needed
|
-- Shift the arguments if needed
|
||||||
if not method then
|
if not method then
|
||||||
obj, method = nil, obj
|
obj, method = _G, obj
|
||||||
if not method then
|
if not method then
|
||||||
mod:error("(%s): trying to toggle hook without giving a method name.", func_name)
|
mod:error("(%s): trying to toggle hook without giving a method name.", func_name)
|
||||||
return
|
return
|
||||||
|
@ -338,10 +331,8 @@ local function generic_hook_toggle(mod, obj, method, enabled_state)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
local orig = get_orig_function(obj, method)
|
if _registry[mod][obj[method]] then
|
||||||
|
_registry[mod][obj[method]].active = enabled_state
|
||||||
if _registry[mod][orig] then
|
|
||||||
_registry[mod][orig].active = enabled_state
|
|
||||||
else
|
else
|
||||||
-- This has the potential for mod-breaking behavior, but not guaranteed
|
-- This has the potential for mod-breaking behavior, but not guaranteed
|
||||||
mod:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name, method)
|
mod:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name, method)
|
||||||
|
|
Loading…
Add table
Reference in a new issue