Refactor API to use generic function for argument checking/handling

A lot of the code ended up being duplicated between the four api calls,
on top of that some type handling checks had to be made in both create_hook and get_orig_function.
This lead to a very messy state of things, so combine all of the handling into a single function,
then we can proceed with the rest of the code expecting one thing.

Having a single entry point made it very easy to add support for delaying hook cleanly.

Also add support for "Obj.Method" notation for backward-compatibility with old hooks.
I initially tried to use string.split provided by fatshark's code, but it creates a table, and requires a lot of supporting code that is all avoided by using a simple string.sub call.
This commit is contained in:
FireSiku 2018-06-02 07:10:45 -04:00
parent 78a1f5cbc0
commit 7a08f1c32b

View file

@ -20,6 +20,8 @@ local HOOK_ERR_NAME = { "hook", "before", "after", "rawhook", }
_registry.origs _registry.origs
]] ]]
local _delayed = {} -- dont need to attach this to registry.
-- This metatable will automatically create a table entry if one doesnt exist. -- This metatable will automatically create a table entry if one doesnt exist.
local auto_table_meta = {__index = function(t, k) t[k] = {} return t[k] end } local auto_table_meta = {__index = function(t, k) t[k] = {} return t[k] end }
@ -50,6 +52,24 @@ local function is_orig_hooked(obj, method)
return false return false
end end
-- Since we replace the original function, we need to keep its reference around.
-- This will grab the cached reference if we hooked it before, otherwise return the function.
local function get_orig_function(self, obj, method)
if obj then
if is_orig_hooked(obj, method) then
return _registry.origs[obj][method]
else
return obj[method]
end
else
if is_orig_hooked(obj, method) then
return _registry.origs[method]
else
return rawget(_G, method)
end
end
end
local function is_existing_hook(self, orig, hook_type) local function is_existing_hook(self, orig, hook_type)
local registry = _registry[self][hook_type] local registry = _registry[self][hook_type]
if registry and registry.handler and registry.handler[orig] then if registry and registry.handler and registry.handler[orig] then
@ -150,17 +170,6 @@ end
local function create_hook(self, orig, obj, method, handler, hook_type) local function create_hook(self, orig, obj, method, handler, hook_type)
local err_name = HOOK_ERR_NAME[hook_type] local err_name = HOOK_ERR_NAME[hook_type]
if type(handler) ~= "function" then
self:error("(%s): 'handler' - function expected, got %s", err_name, type(handler))
return
end
if type(obj) == "string" then
if not rawget(_G, obj) then
return
end
obj = _G[obj]
end
if not is_orig_hooked(obj, method) then if not is_orig_hooked(obj, method) then
create_internal_hook(orig, obj, method) create_internal_hook(orig, obj, method)
@ -170,7 +179,6 @@ local function create_hook(self, orig, obj, method, handler, hook_type)
if not is_existing_hook(self, orig, hook_type) then if not is_existing_hook(self, orig, hook_type) then
-- Also set up related info accessible to the hook object under self. -- Also set up related info accessible to the hook object under self.
if not _registry[self][hook_type] then if not _registry[self][hook_type] then
_registry[self][hook_type] = { _registry[self][hook_type] = {
active = {}, active = {},
handler = {}, handler = {},
@ -194,38 +202,60 @@ local function create_hook(self, orig, obj, method, handler, hook_type)
end end
-- Since we replace the original function, we need to keep its reference around. -- ####################################################################################################################
-- This will grab the cached reference if we hooked it before, otherwise return the function. -- ##### GENERIC API ##################################################################################################
local function get_orig_function(self, obj, method) -- ####################################################################################################################
-- Validate types -- Singular functions that works on a generic basis so the VMFMod API can be tailored for user simplicity.
if obj and not (type(obj) == "table" or type(obj) == "string") then
self:error("(hook): 'object' - table or string expected, got %s [args: %s, %s", type(obj), obj or "nil", method or "nil") -- Valid styles:
-- Giving a string pointing to a global object table and method string and hook function
-- self, string (obj), string (method), function (handler), hook_type(number)
-- Giving an object table and a method string and hook function
-- self, table (obj), string (method), function (handler), hook_type(number)
-- Giving a method string or a Obj.Method string (VT1 Style) and a hook function
-- self, string (method), function (handler), nil, hook_type(number)
local function generic_hook(self, obj, method, handler, hook_type)
local func_name = HOOK_ERR_NAME[hook_type]
if vmf.check_wrong_argument_type(self, func_name, "obj", obj, "string", "table") or
vmf.check_wrong_argument_type(self, func_name, "method", method, "string", "function") or
vmf.check_wrong_argument_type(self, func_name, "handler", handler, "function", "nil") then
return return
end end
if type(method) ~= "string" then
self:error("(hook): 'method' - string expected, got %s [args: %s, %s]", type(method), obj or "nil", method or "nil") -- Adjust the arguments.
return if type(method) == "function" then
handler = method
-- VT1 hooked everything using a "Obj.Method" string
-- Add backward compatibility for that format.
local find_position = string.find(obj, "%.")
if find_position then
method = string.sub(obj, find_position + 1)
obj = string.sub(obj, 1, find_position - 1)
end
end end
if obj then -- Check if hook should be delayed.
-- obj can be a string. We'll need to grab the actual object first. if type(obj) == "string" then
-- if we can't find object, we don't need to go any further. local obj_table = rawget(_G, obj)
if type(obj) == "string" then if obj_table then
if not rawget(_G, obj) then return end -- No delay required, grab object and move on
obj = _G[obj] obj = obj_table
end
if is_orig_hooked(obj, method) then
return _registry.origs[obj][method]
else else
return obj[method] -- Call this func at a later time, using upvalues.
end vmf:info("[%s.%s] needs to be delayed.", obj, method)
else table.insert(_delayed, function()
if is_orig_hooked(obj, method) then generic_hook(self, obj, method, handler, hook_type)
return _registry.origs[method] end)
else return
return _G[method]
end end
end end
-- obj can't be a string for these.
local orig = get_orig_function(self, obj, method)
return create_hook(self, orig, obj, method, handler, hook_type)
end end
-- #################################################################################################################### -- ####################################################################################################################
@ -243,12 +273,7 @@ end
-- These will always be executed before the hook chain. -- These will always be executed before the hook chain.
-- Due to discussion, handler may not receive any arguments, but will see what the use cases are with them first. -- Due to discussion, handler may not receive any arguments, but will see what the use cases are with them first.
function VMFMod:before(obj, method, handler) function VMFMod:before(obj, method, handler)
if type(method) == "function" then return generic_hook(self, obj, method, handler, HOOK_TYPE_BEFORE)
method, handler, obj = obj, method, nil
end
local orig = get_orig_function(self, obj, method)
create_hook(self, orig, obj, method, handler, HOOK_TYPE_BEFORE)
end end
-- :after() provides callback after a function is called. You have no control over the execution of the -- :after() provides callback after a function is called. You have no control over the execution of the
@ -256,12 +281,7 @@ end
-- These will always be executed after the hook chain. -- These will always be executed after the hook chain.
-- This is similar to :front() functionality in V1 modding. -- This is similar to :front() functionality in V1 modding.
function VMFMod:after(obj, method, handler) function VMFMod:after(obj, method, handler)
if type(method) == "function" then return generic_hook(self, obj, method, handler, HOOK_TYPE_AFTER)
method, handler, obj = obj, method, nil
end
local orig = get_orig_function(self, obj, method)
create_hook(self, orig, obj, method, handler, HOOK_TYPE_AFTER)
end end
-- :hook() will allow you to hook a function, allowing your handler to replace the function in the stack, -- :hook() will allow you to hook a function, allowing your handler to replace the function in the stack,
@ -269,12 +289,7 @@ end
-- original function at the end. Your handler has to call the next function in the chain manually. -- original function at the end. Your handler has to call the next function in the chain manually.
-- The chain of event is determined by mod load order. -- The chain of event is determined by mod load order.
function VMFMod:hook(obj, method, handler) function VMFMod:hook(obj, method, handler)
if type(method) == "function" then return generic_hook(self, obj, method, handler, HOOK_TYPE_NORMAL)
method, handler, obj = obj, method, nil
end
local orig = get_orig_function(self, obj, method)
create_hook(self, orig, obj, method, handler, HOOK_TYPE_NORMAL)
end end
-- :rawhook() allows you to directly hook a function, replacing it. The original function will bever be called. -- :rawhook() allows you to directly hook a function, replacing it. The original function will bever be called.
@ -283,12 +298,7 @@ end
-- This there is a limit of a single rawhook for any given function. -- This there is a limit of a single rawhook for any given function.
-- This should only be used as a last resort due to its limitation and its potential to break the game if not careful. -- This should only be used as a last resort due to its limitation and its potential to break the game if not careful.
function VMFMod:rawhook(obj, method, handler) function VMFMod:rawhook(obj, method, handler)
if type(method) == "function" then return generic_hook(self, obj, method, handler, HOOK_TYPE_RAW)
method, handler, obj = obj, method, nil
end
local orig = get_orig_function(self, obj, method)
create_hook(self, orig, obj, method, handler, HOOK_TYPE_RAW)
end end
function VMFMod:enable_all_hooks() function VMFMod:enable_all_hooks()
@ -315,4 +325,14 @@ end
-- -- removes all hooks when VMF is about to be reloaded -- -- removes all hooks when VMF is about to be reloaded
-- vmf.hooks_unload = function() -- vmf.hooks_unload = function()
-- end -- end
vmf.apply_delayed_hooks = function()
if #_delayed > 0 then
-- Go through the table in reverse so we don't get any issues removing entries inside the loop
for i = #_delayed, 1, -1 do
_delayed[i]()
table.remove(_delayed, i)
end
end
end