From fa393c4f78acf6b9f369a5d29f066c8a7ff5dfe4 Mon Sep 17 00:00:00 2001 From: FireSiku Date: Tue, 5 Jun 2018 21:38:27 -0400 Subject: [PATCH] Prevent creating multiple hooks on the same function using different API --- .../mods/vmf/modules/core/newhooks.lua | 126 ++++++------------ 1 file changed, 43 insertions(+), 83 deletions(-) diff --git a/vmf/scripts/mods/vmf/modules/core/newhooks.lua b/vmf/scripts/mods/vmf/modules/core/newhooks.lua index 580966c..01d8d89 100644 --- a/vmf/scripts/mods/vmf/modules/core/newhooks.lua +++ b/vmf/scripts/mods/vmf/modules/core/newhooks.lua @@ -7,9 +7,9 @@ local vmf = get_mod("VMF") -- Constants for hook_type local HOOK_TYPES = { hook = 1, - before = 2, - after = 3, - rawhook = 4, + before = 1, + after = 2, + rawhook = 3, } -- Upvalued constants to ease on table lookups when not needed @@ -19,10 +19,7 @@ local HOOK_TYPE_AFTER = HOOK_TYPES.after local HOOK_TYPE_RAW = HOOK_TYPES.rawhook --[[ Planned registry structure: - _registry[self][hook_type] = { - active[orig] = true, - handler[orig] = func, - } + _registry[self][orig] = { active = true} _registry.hooks[hook_type] _registry.origs ]] @@ -37,9 +34,8 @@ local _registry = setmetatable({}, auto_table_meta) -- This table will hold all of the hooks, in the format of _registry.hooks[hook_type] _registry.hooks = { -- Do the same thing with these tables to allow .hooks[hook_type][orig] without a ton of nil-checks. - setmetatable({}, auto_table_meta), -- before - setmetatable({}, auto_table_meta), -- after setmetatable({}, auto_table_meta), -- normal + setmetatable({}, auto_table_meta), -- after -- Since there can only be one rawhook per function, it doesnt need to generate a table. {}, -- raw } @@ -77,13 +73,6 @@ local function get_orig_function(self, obj, method) end end -local function is_existing_hook(self, orig, hook_type) - local registry = _registry[self][hook_type] - if registry and registry.handler and registry.handler[orig] then - return true - end -end - -- Return an object from the global table. Second return value is if it was sucessful. local function get_object_from_string(obj) if type(obj) == "table" then @@ -137,35 +126,39 @@ end -- Returns a function closure with all the information needed for a given hook to be handled correctly. local function create_specialized_hook(self, orig, handler, hook_type) local func - local active = _registry[self][hook_type].active + local hook_data = _registry[self][orig] + + -- Determine the previous function in the hook stack + -- Note: If a previous hook is removed from the table, these functions wouldnt be updated + -- This would break the chain, solution is to not remove the hooks, simply make them inactive + -- Make sure inactive hooks that rely on the chain still call the next function seamlessly. + local previous_hook = get_hook_chain(orig) + if hook_type == HOOK_TYPE_NORMAL then - -- Determine the previous function in the hook stack - local previous_hook = get_hook_chain(orig) - -- Note: If a previous hook is removed from the table, this function wouldnt be updated - -- This would break the chain, solution would be not to remove the hook, but make it inactive - -- Make sure inactive hooks just seamlessly call the next function on the list without disruption. func = function(...) - if active[orig] then + if hook_data.active then return handler(previous_hook, ...) else return previous_hook(...) end end - -- Need to make sure a disabled Rawhook will correctly call the original. + -- Rawhooks need to directly call the original function is inactive. elseif hook_type == HOOK_TYPE_RAW then func = function(...) - if active[orig] then + if hook_data.active then return handler(...) else return orig(...) end end - else + elseif hook_type == HOOK_TYPE_AFTER then func = function(...) - if active[orig] then + if hook_data.active then return handler(...) end end + else + self:error("(create_specialized_hook): Invalid hook_type given. You should never this see.") end return func end @@ -174,16 +167,12 @@ end -- The hook system makes internal functions that replace the original function and handles all the hooks. local function create_internal_hook(orig, obj, method) local fn = function(...) - local before_hooks = _registry.hooks[HOOK_TYPE_BEFORE][orig] - local after_hooks = _registry.hooks[HOOK_TYPE_AFTER][orig] - if before_hooks and #before_hooks > 0 then - for i = 1, #before_hooks do before_hooks[i](...) end - end -- Execute the hook chain. Note that we need to keep the return values -- in case another function depends on them. local hook_chain = get_hook_chain(orig) -- We need to keep return values in case another function depends on them local values = { hook_chain(...) } + local after_hooks = _registry.hooks[HOOK_TYPE_AFTER][orig] if after_hooks and #after_hooks > 0 then for i = 1, #after_hooks do after_hooks[i](...) end end @@ -209,25 +198,19 @@ local function create_hook(self, orig, obj, method, handler, func_name, hook_typ end -- Check to make sure it wasn't hooked before - if not is_existing_hook(self, orig, hook_type) then - -- Also set up related info accessible to the hook object under self. - if not _registry[self][hook_type] then - _registry[self][hook_type] = { - active = {}, - handler = {}, - } - end - _registry[self][hook_type].active[orig] = true - _registry[self][hook_type].handler[orig] = handler + if not _registry[self][orig] then + _registry[self][orig] = { active = true } + + local hook_registry = _registry.hooks[hook_type] -- Add to the hook to registry. Raw hooks are unique, so we check for that too. if hook_type == HOOK_TYPE_RAW then - if _registry.hooks[hook_type][orig] then + if hook_registry[orig] then self:error("(%s): Attempting to rawhook already hooked function %s", func_name, method) else - _registry.hooks[hook_type][orig] = create_specialized_hook(self, orig, handler, hook_type) + hook_registry[orig] = create_specialized_hook(self, orig, handler, hook_type) end else - table.insert(_registry.hooks[hook_type][orig], create_specialized_hook(self, orig, handler, hook_type)) + table.insert(hook_registry[orig], create_specialized_hook(self, orig, handler, hook_type)) end else local hook_type_name = func_name @@ -285,7 +268,9 @@ local function generic_hook(self, obj, method, handler, func_name) return create_hook(self, orig, obj, method, handler, func_name, hook_type) end -local function generic_hook_toggle(self, obj, method, func_name) +local function generic_hook_toggle(self, obj, method, enabled_state) + local func_name = (enabled_state) and "hook_enable" or "hook_disable" + if vmf.check_wrong_argument_type(self, func_name, "obj", obj, "string", "table") or vmf.check_wrong_argument_type(self, func_name, "method", method, "string", "nil") then return @@ -298,28 +283,16 @@ local function generic_hook_toggle(self, obj, method, func_name) local obj, sucess = get_object_from_string(obj) --luacheck: ignore if not sucess then - self:error("(%s): object doesn't exist: %s", func_name, obj) + self:error("(%s): object doesn't exist.", func_name) return end - - -- get hook_type and enabled_state from function name - local underscore_position = string.find(func_name, "_") - local hook_type = HOOK_TYPES[string.sub(func_name, 1, underscore_position - 1)] - local enabled_state = (string.sub(func_name, underscore_position + 1) == "enable") - - local registry = _registry[self][hook_type] - if registry then - local orig = get_orig_function(self, obj, method) - -- Check if handler exists, because active[orig] would fail if disabled (false) - if registry.handler[orig] then - registry.active[orig] = enabled_state - else - self:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name) - return - end + local orig = get_orig_function(self, obj, method) + + if _registry[self][orig] then + _registry[self][orig].active = enabled_state else - self:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name) + self:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name, method) return end end @@ -336,8 +309,6 @@ end -- :before() provides a callback before a function is called. You have no control over the execution of the -- original function, nor can you change its return values. -- This type of hook is typically used if you need to know a function was called, but dont want to modify it. --- These will always be executed before the hook chain. --- Due to discussion, handler may not receive any arguments, but will see what the use cases are with them first. function VMFMod:before(obj, method, handler) return generic_hook(self, obj, method, handler, "before") end @@ -368,31 +339,20 @@ function VMFMod:rawhook(obj, method, handler) end -- Enable/disable functions for all hook types: -function VMFMod:hook_enable(obj, method) generic_hook_toggle(self, obj, method, "hook_enable") end -function VMFMod:before_enable(obj, method) generic_hook_toggle(self, obj, method, "before_enable") end -function VMFMod:after_enable(obj, method) generic_hook_toggle(self, obj, method, "after_enable") end -function VMFMod:rawhook_enable(obj, method) generic_hook_toggle(self, obj, method, "rawhook_enable") end +function VMFMod:hook_enable(obj, method) generic_hook_toggle(self, obj, method, true) end +function VMFMod:hook_disable(obj, method) generic_hook_toggle(self, obj, method, false) end -function VMFMod:hook_disable(obj, method) generic_hook_toggle(self, obj, method, "hook_disable") end -function VMFMod:before_disable(obj, method) generic_hook_toggle(self, obj, method, "before_disable") end -function VMFMod:after_disable(obj, method) generic_hook_toggle(self, obj, method, "after_disable") end -function VMFMod:rawhook_disable(obj, method) generic_hook_toggle(self, obj, method, "rawhook_disable") end - function VMFMod:enable_all_hooks() -- Using pairs because the self table may contain nils, and order isnt important. - for _, hooks in pairs(_registry[self]) do - for orig, _ in pairs(hooks.active) do - hooks.active[orig] = true - end + for _, hook in pairs(_registry[self]) do + hook.active = true end end function VMFMod:disable_all_hooks() -- Using pairs because the self table may contain nils, and order isnt important. - for _, hooks in pairs(_registry[self]) do - for orig, _ in pairs(hooks.active) do - hooks.active[orig] = false - end + for _, hook in pairs(_registry[self]) do + hook.active = false end end