diff --git a/vmf/scripts/mods/vmf/modules/core/hooks.lua b/vmf/scripts/mods/vmf/modules/core/hooks.lua index 027709b..ee9c7f4 100644 --- a/vmf/scripts/mods/vmf/modules/core/hooks.lua +++ b/vmf/scripts/mods/vmf/modules/core/hooks.lua @@ -16,10 +16,10 @@ local HOOK_TYPE_NORMAL = 1 local HOOK_TYPE_SAFE = 2 local HOOK_TYPE_ORIGIN = 3 ---[[ Planned registry structure: - _registry[self][orig] = { active = true} - _registry.hooks[hook_type] - _registry.origs +--[[ Planned internal structure + _registry[mod][orig] = hook_data table + _hooks[hook_type][orig] = array of hook functions. (Single hook function for hook_origin) + _origs table holds all the original functions ]] -- dont need to attach this to registry. @@ -27,26 +27,26 @@ local _delayed = {} local _delaying_enabled = true -- This metatable will automatically create a table entry if one doesnt exist. --- This lets us easily do _registry[self] without having to worry about nil-checking it. +-- This lets us easily do _registry[mod] without having to worry about nil-checking it. local auto_table_meta = {__index = function(t, k) t[k] = {} return t[k] end } local _registry = setmetatable({}, auto_table_meta) --- This table will hold all of the hooks, in the format of _registry.hooks[hook_type] -_registry.hooks = { +-- This table will hold all of the hooks, in the format of _hooks[hook_type] +local _hooks = { -- Do the same thing with these tables to allow .hooks[hook_type][orig] without a ton of nil-checks. setmetatable({}, auto_table_meta), -- normal setmetatable({}, auto_table_meta), -- safe -- Since there can only be one origin per function, it doesnt need to generate a table. {}, -- origin } -_registry.origs = {} +local _origs = {} -- #################################################################################################################### -- ##### Util functions ############################################################################################### -- #################################################################################################################### local function is_orig_hooked(obj, method) - local orig_registry = _registry.origs + local orig_registry = _origs if obj and orig_registry[obj] and orig_registry[obj][method] then return true elseif orig_registry[method] then @@ -57,16 +57,16 @@ end -- Since we replace the original function, we need to keep its reference around. -- This will grab the cached reference if we hooked it before, otherwise return the function. -local function get_orig_function(self, obj, method) +local function get_orig_function(obj, method) if obj then if is_orig_hooked(obj, method) then - return _registry.origs[obj][method] + return _origs[obj][method] else return obj[method] end else if is_orig_hooked(obj, method) then - return _registry.origs[method] + return _origs[method] else return rawget(_G, method) end @@ -107,6 +107,12 @@ local function get_return_values(...) return num, { ... } end +local function can_rehook(mod, hook_data, obj, hook_type) + if mod:get_internal_data("allow_rehooking") and hook_data.obj == obj and hook_data.hook_type == hook_type then + return true + end +end + -- #################################################################################################################### -- ##### Hook Creation ################################################################################################ -- #################################################################################################################### @@ -115,7 +121,7 @@ end -- Since all hooks of the chain contain the call to the previous one, we don't need to do any manual loops. -- This continues until the end of the chain, where the original function is called. local function get_hook_chain(orig) - local hook_registry = _registry.hooks + local hook_registry = _hooks local hooks = hook_registry[HOOK_TYPE_NORMAL][orig] if hooks and #hooks > 0 then return hooks[#hooks] @@ -130,10 +136,21 @@ local function get_hook_chain(orig) end end +-- Returns a table containing hook data inside of it. +-- { active = mod:is_enabled() } +local function create_hook_data(mod, obj, handler, hook_type) + return { + active = mod:is_enabled(), + hook_type = hook_type, + handler = handler, + obj = obj, + } +end + -- Returns a function closure with all the information needed for a given hook to be handled correctly. -local function create_specialized_hook(self, orig, handler, hook_type) +local function create_specialized_hook(mod, orig, hook_type) local func - local hook_data = _registry[self][orig] + local hook_data = _registry[mod][orig] -- Determine the previous function in the hook stack -- Note: If a previous hook is removed from the table, these functions wouldn't be updated @@ -144,7 +161,7 @@ local function create_specialized_hook(self, orig, handler, hook_type) if hook_type == HOOK_TYPE_NORMAL then func = function(...) if hook_data.active then - return handler(previous_hook, ...) + return hook_data.handler(previous_hook, ...) else return previous_hook(...) end @@ -153,7 +170,7 @@ local function create_specialized_hook(self, orig, handler, hook_type) elseif hook_type == HOOK_TYPE_ORIGIN then func = function(...) if hook_data.active then - return handler(...) + return hook_data.handler(...) else return orig(...) end @@ -161,7 +178,7 @@ local function create_specialized_hook(self, orig, handler, hook_type) elseif hook_type == HOOK_TYPE_SAFE then func = function(...) if hook_data.active then - vmf.xpcall_no_return_values(self, "(safe_hook)", handler, ...) + vmf.xpcall_no_return_values(mod, "(safe_hook)", hook_data.handler, ...) end end end @@ -177,7 +194,7 @@ local function create_internal_hook(orig, obj, method) -- We need to keep return values in case another function depends on them local num_values, values = get_return_values( hook_chain(...) ) - local safe_hooks = _registry.hooks[HOOK_TYPE_SAFE][orig] + local safe_hooks = _hooks[HOOK_TYPE_SAFE][orig] if safe_hooks and #safe_hooks > 0 then for i = 1, #safe_hooks do safe_hooks[i](...) end end @@ -186,44 +203,53 @@ local function create_internal_hook(orig, obj, method) if obj then -- object cannot be a string at this point, so we don't need to check for that. - if not _registry.origs[obj] then _registry.origs[obj] = {} end - _registry.origs[obj][method] = orig + if not _origs[obj] then _origs[obj] = {} end + _origs[obj][method] = orig obj[method] = fn else - _registry.origs[method] = orig + _origs[method] = orig _G[method] = fn end end -local function create_hook(self, orig, obj, method, handler, func_name, hook_type) - self:info("(%s): Hooking '%s' from [%s] (Origin: %s)", func_name, method, obj or "_G", orig) +local function create_hook(mod, orig, obj, method, handler, func_name, hook_type) + mod:info("(%s): Hooking '%s' from [%s] (Origin: %s)", func_name, method, obj or "_G", orig) if not is_orig_hooked(obj, method) then create_internal_hook(orig, obj, method) end -- Check to make sure it wasn't hooked before - if not _registry[self][orig] then - _registry[self][orig] = { active = self:is_enabled() } + local hook_data = _registry[mod][orig] + if not hook_data then + _registry[mod][orig] = create_hook_data(mod, obj, handler, hook_type) - local hook_registry = _registry.hooks[hook_type] + local hook_registry = _hooks[hook_type] -- Add to the hook to registry. Origin hooks are unique, so we check for that too. if hook_type == HOOK_TYPE_ORIGIN then if hook_registry[orig] then - self:error("(%s): Attempting to hook origin of already hooked function %s", func_name, method) + mod:error("(%s): Attempting to hook origin of already hooked function %s", func_name, method) else - hook_registry[orig] = create_specialized_hook(self, orig, handler, hook_type) + hook_registry[orig] = create_specialized_hook(mod, orig, hook_type) end else - table.insert(hook_registry[orig], create_specialized_hook(self, orig, handler, hook_type) ) + table.insert(hook_registry[orig], create_specialized_hook(mod, orig, hook_type) ) end else - -- This should be a warning log, but currently there are no differences between warning and error. - -- Wouldn't want to scare users that mods are broken because this used to be acceptable. - if vmf:get("developer_mode") then - self:warning("(%s): Attempting to rehook already active hook %s.", func_name, method) + -- If hook_data already exists and it's the same hook_type, we can safely change the hook handler. + -- This should (in practice) only be used for debugging by modders who uses DoFile. + -- Revisit purpose when lua files are in plain text. + if can_rehook(mod, hook_data, obj, hook_type) then + hook_data.handler = handler else - self:info("(%s): Attempting to rehook already active hook %s.", func_name, method) + -- This should be a warning log, but currently there are no differences between warning and error. + -- Wouldn't want to scare users that mods are broken because this used to be acceptable. + -- This should be changed to permanently be a warning log after new_hooks deprecation period is over. + if vmf:get("developer_mode") then + mod:warning("(%s): Attempting to rehook active hook [%s] with different hook_type.", func_name, method) + else + mod:info("(%s): Attempting to rehook active hook [%s] with different hook_type.", func_name, method) + end end end @@ -237,16 +263,16 @@ end -- Valid styles: -- Giving a string pointing to a global object table and method string and hook function --- self, string (obj), string (method), function (handler), string (func_name) +-- mod, string (obj), string (method), function (handler), string (func_name) -- Giving an object table and a method string and hook function --- self, table (obj), string (method), function (handler), string (func_name) +-- mod, table (obj), string (method), function (handler), string (func_name) -- Giving a method string or a Obj.Method string (VT1 Style) and a hook function --- self, string (method), function (handler), nil, string (func_name) +-- mod, string (method), function (handler), nil, string (func_name) -local function generic_hook(self, obj, method, handler, func_name) - if vmf.check_wrong_argument_type(self, func_name, "obj", obj, "string", "table", "nil") or - vmf.check_wrong_argument_type(self, func_name, "method", method, "string", "function") or - vmf.check_wrong_argument_type(self, func_name, "handler", handler, "function", "nil") +local function generic_hook(mod, obj, method, handler, func_name) + if vmf.check_wrong_argument_type(mod, func_name, "obj", obj, "string", "table", "nil") or + vmf.check_wrong_argument_type(mod, func_name, "method", method, "string", "function") or + vmf.check_wrong_argument_type(mod, func_name, "handler", handler, "function", "nil") then return end @@ -256,7 +282,7 @@ local function generic_hook(self, obj, method, handler, func_name) handler = method method, obj = split_function_string(obj) if not method then - self:error("(%s): trying to create hook without giving a method name. %s", func_name) + mod:error("(%s): trying to create hook without giving a method name. %s", func_name) return end end @@ -269,41 +295,41 @@ local function generic_hook(self, obj, method, handler, func_name) if obj and not success then if _delaying_enabled and type(obj) == "string" then -- Call this func at a later time, using upvalues. - self:info("(%s): [%s.%s] needs to be delayed.", func_name, obj, method) + mod:info("(%s): [%s.%s] needs to be delayed.", func_name, obj, method) table.insert(_delayed, function() - generic_hook(self, obj, method, handler, func_name) + generic_hook(mod, obj, method, handler, func_name) end) return else - self:error("(%s): trying to hook object that doesn't exist: %s", func_name, obj) + mod:error("(%s): trying to hook object that doesn't exist: %s", func_name, obj) return end end -- Quick check to make sure the target exists if obj and not obj[method] then - self:error("(%s): trying to hook method that doesn't exist: [%s.%s]", func_name, obj, method) + mod:error("(%s): trying to hook method that doesn't exist: [%s.%s]", func_name, obj, method) return elseif not obj and not rawget(_G, method) then - self:error("(%s): trying to hook function that doesn't exist: [%s]", func_name, method) + mod:error("(%s): trying to hook function that doesn't exist: [%s]", func_name, method) return end -- obj can't be a string for these now. - local orig = get_orig_function(self, obj, method) + local orig = get_orig_function(mod, obj, method) if type(orig) ~= "function" then - self:error("(%s): trying to hook %s (a %s), not a function.", func_name, method, type(orig)) + mod:error("(%s): trying to hook %s (a %s), not a function.", func_name, method, type(orig)) return end - return create_hook(self, orig, obj, method, handler, func_name, hook_type) + return create_hook(mod, orig, obj, method, handler, func_name, hook_type) end -local function generic_hook_toggle(self, obj, method, enabled_state) - local func_name = (enabled_state) and "hook_enable" or "hook_disable" +local function generic_hook_toggle(mod, obj, method, enabled_state) + local func_name = (enabled_state and "hook_enable") or "hook_disable" - if vmf.check_wrong_argument_type(self, func_name, "obj", obj, "string", "table") or - vmf.check_wrong_argument_type(self, func_name, "method", method, "string", "nil") then + if vmf.check_wrong_argument_type(mod, func_name, "obj", obj, "string", "table") or + vmf.check_wrong_argument_type(mod, func_name, "method", method, "string", "nil") then return end @@ -312,7 +338,7 @@ local function generic_hook_toggle(self, obj, method, enabled_state) if type(obj) == "string" then method, obj = split_function_string(obj) else - self:error("(%s): trying to toggle hook without giving a method name. %s", func_name) + mod:error("(%s): trying to toggle hook without giving a method name. %s", func_name) end end @@ -320,24 +346,32 @@ local function generic_hook_toggle(self, obj, method, enabled_state) if obj and not success then if _delaying_enabled and type(obj) == "string" then -- Call this func at a later time, using upvalues. - self:info("(%s): [%s.%s] needs to be delayed.", func_name, obj, method) + mod:info("(%s): [%s.%s] needs to be delayed.", func_name, obj, method) table.insert(_delayed, function() - generic_hook_toggle(self, obj, method, enabled_state) + generic_hook_toggle(mod, obj, method, enabled_state) end) return else - self:error("(%s): trying to toggle hook on object that doesn't exist: %s", func_name, obj) + mod:error("(%s): trying to toggle hook on object that doesn't exist: %s", func_name, obj) return end end - local orig = get_orig_function(self, obj, method) + local orig = get_orig_function(mod, obj, method) - if _registry[self][orig] then - _registry[self][orig].active = enabled_state + if _registry[mod][orig] then + _registry[mod][orig].active = enabled_state else -- This has the potential for mod-breaking behavior, but not guaranteed - self:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name, method) + mod:warning("(%s): trying to toggle hook that doesn't exist: %s", func_name, method) + end +end + +local function toggle_all_hooks_for_mod(mod, enabled_state) + local toggle_status = (enabled_state and "Enabling") or "Disabling" + mod:info("(hooks): %s all hooks for mod: %s", toggle_status, mod:get_name()) + for _, hook_data in pairs(_registry[mod]) do + hook_data.active = enabled_state end end @@ -372,21 +406,20 @@ function VMFMod:hook_origin(obj, method, handler) end -- Enable/disable functions for all hook types: -function VMFMod:hook_enable(obj, method) generic_hook_toggle(self, obj, method, true) end -function VMFMod:hook_disable(obj, method) generic_hook_toggle(self, obj, method, false) end +function VMFMod:hook_enable(obj, method) + generic_hook_toggle(self, obj, method, true) +end + +function VMFMod:hook_disable(obj, method) + generic_hook_toggle(self, obj, method, false) +end function VMFMod:enable_all_hooks() - self:info("(hooks): Enabling all hooks for mod: %s", self:get_name()) - for _, hook_data in pairs(_registry[self]) do - hook_data.active = true - end + toggle_all_hooks_for_mod(self, true) end function VMFMod:disable_all_hooks() - self:info("(hooks): Disabling all hooks for mod: %s", self:get_name()) - for _, hook_data in pairs(_registry[self]) do - hook_data.active = false - end + toggle_all_hooks_for_mod(self, false) end -- #################################################################################################################### @@ -395,7 +428,7 @@ end -- Remove all hooks when VMF is about to be reloaded vmf.hooks_unload = function() - for key, value in pairs(_registry.origs) do + for key, value in pairs(_origs) do -- origs[method] = orig if type(value) == "function" then _G[key] = value