diff --git a/vmf/scripts/mods/vmf/modules/core/hooks.lua b/vmf/scripts/mods/vmf/modules/core/hooks.lua index 027709b..3bb1bac 100644 --- a/vmf/scripts/mods/vmf/modules/core/hooks.lua +++ b/vmf/scripts/mods/vmf/modules/core/hooks.lua @@ -130,8 +130,19 @@ local function get_hook_chain(orig) end end +-- Returns a table containing hook data inside of it. +-- { active = self:is_enabled() } +local function create_hook_data(self, obj, handler, hook_type) + return { + active = self:is_enabled(), + hook_type = hook_type, + handler = handler, + obj = obj, + } +end + -- Returns a function closure with all the information needed for a given hook to be handled correctly. -local function create_specialized_hook(self, orig, handler, hook_type) +local function create_specialized_hook(self, orig, hook_type) local func local hook_data = _registry[self][orig] @@ -144,7 +155,7 @@ local function create_specialized_hook(self, orig, handler, hook_type) if hook_type == HOOK_TYPE_NORMAL then func = function(...) if hook_data.active then - return handler(previous_hook, ...) + return hook_data.handler(previous_hook, ...) else return previous_hook(...) end @@ -153,7 +164,7 @@ local function create_specialized_hook(self, orig, handler, hook_type) elseif hook_type == HOOK_TYPE_ORIGIN then func = function(...) if hook_data.active then - return handler(...) + return hook_data.handler(...) else return orig(...) end @@ -161,7 +172,7 @@ local function create_specialized_hook(self, orig, handler, hook_type) elseif hook_type == HOOK_TYPE_SAFE then func = function(...) if hook_data.active then - vmf.xpcall_no_return_values(self, "(safe_hook)", handler, ...) + vmf.xpcall_no_return_values(self, "(safe_hook)", hook_data.handler, ...) end end end @@ -203,8 +214,9 @@ local function create_hook(self, orig, obj, method, handler, func_name, hook_typ end -- Check to make sure it wasn't hooked before - if not _registry[self][orig] then - _registry[self][orig] = { active = self:is_enabled() } + local hook_data = _registry[self][orig] + if not hook_data then + _registry[self][orig] = create_hook_data(self, obj, handler, hook_type) local hook_registry = _registry.hooks[hook_type] -- Add to the hook to registry. Origin hooks are unique, so we check for that too. @@ -212,18 +224,26 @@ local function create_hook(self, orig, obj, method, handler, func_name, hook_typ if hook_registry[orig] then self:error("(%s): Attempting to hook origin of already hooked function %s", func_name, method) else - hook_registry[orig] = create_specialized_hook(self, orig, handler, hook_type) + hook_registry[orig] = create_specialized_hook(self, orig, hook_type) end else - table.insert(hook_registry[orig], create_specialized_hook(self, orig, handler, hook_type) ) + table.insert(hook_registry[orig], create_specialized_hook(self, orig, hook_type) ) end else - -- This should be a warning log, but currently there are no differences between warning and error. - -- Wouldn't want to scare users that mods are broken because this used to be acceptable. - if vmf:get("developer_mode") then - self:warning("(%s): Attempting to rehook already active hook %s.", func_name, method) + -- If hook_data already exists and it's the same hook_type, we can safely change the hook handler. + -- This should (in practice) only be used for debugging by modders who uses DoFile. + -- Revisit purpose when lua files are in plain text. + if hook_data.obj == obj and hook_data.hook_type == hook_type then + hook_data.handler = handler else - self:info("(%s): Attempting to rehook already active hook %s.", func_name, method) + -- This should be a warning log, but currently there are no differences between warning and error. + -- Wouldn't want to scare users that mods are broken because this used to be acceptable. + -- This should be changed to permanently be a warning log after new_hooks deprecation period is over. + if vmf:get("developer_mode") then + self:warning("(%s): Attempting to rehook active hook [%s] with different hook_type.", func_name, method) + else + self:info("(%s): Attempting to rehook active hook [%s] with different hook_type.", func_name, method) + end end end