diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index afc69df..00f61e6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,12 +5,14 @@ on: branches: [ master ] pull_request: branches: [ master ] + workflow_dispatch: jobs: build: strategy: matrix: os: [ubuntu-latest, windows-latest] + arch: [x86, x86_64] include: - os: ubuntu-latest cc: clang @@ -19,7 +21,7 @@ jobs: cc: msvc fail-fast: false - name: ${{ matrix.os }} - ${{ matrix.cc }} + name: ${{ matrix.os }} - ${{ matrix.cc }} (${{ matrix.arch }}) runs-on: ${{ matrix.os }} steps: @@ -42,18 +44,18 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 name: Repository checkout with: fetch-depth: 0 submodules: recursive path: repository - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 name: AMBuild checkout with: repository: alliedmodders/ambuild - ref: fe8c99ccd24fa926bf6ac6b9935a5fc03df06a72 + ref: 25a23ac92307eb1e181fd3e7d9385412d4034532 fetch-depth: 0 submodules: recursive path: ambuild @@ -61,7 +63,6 @@ jobs: - name: Setup AMBuild shell: bash run: | - cd "${{ env.CACHE_PATH }}" python -m pip install ./ambuild - name: Select clang compiler @@ -78,5 +79,11 @@ jobs: run: | mkdir build cd build - python ../configure.py - ambuild \ No newline at end of file + python ../configure.py --targets ${{ matrix.arch }} --enable-tests + ambuild + + - name: Test + shell: bash + working-directory: repository + run: | + python ./test.py --target ${{ matrix.arch }} --skip-build diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5e230ba --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +build/ +.vscode/ diff --git a/.gitmodules b/.gitmodules index 4ce19d0..32d5075 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,9 @@ [submodule "third_party/safetyhook"] path = third_party/safetyhook url = https://github.com/alliedmodders/safetyhook +[submodule "third_party/googletest"] + path = third_party/googletest + url = https://github.com/google/googletest.git +[submodule "third_party/gtest-parallel"] + path = third_party/gtest-parallel + url = https://github.com/google/gtest-parallel.git diff --git a/AMBuildScript b/AMBuildScript index f2972dd..dba145b 100644 --- a/AMBuildScript +++ b/AMBuildScript @@ -3,6 +3,9 @@ class KHook(object): self.all_targets = [] self.libsafetyhook = {} self.libkhook = {} + self.libgtest = {} + self.test_binaries = [] + def configure(self): target_archs = [] if builder.options.targets: @@ -29,9 +32,17 @@ class KHook(object): cxx.cxxflags += [ '/std:c++17' ] + if getattr(builder.options, 'enable_tests', False): + cxx.defines += ['KHOOK_TESTS'] KH = KHook() KH.configure() builder.Build('third_party/safetyhook/AMBuilder', {'SafetyHook': KH}) -builder.Build('AMBuilder', {'KHook': KH}) \ No newline at end of file +builder.Build('AMBuilder', {'KHook': KH}) + +if getattr(builder.options, 'enable_tests', False): + builder.Build('test/AMBuilder.gtest', {'GTest': KH}) + builder.Build('test/AMBuilder', {'TestRunner': KH, 'KHook': KH, 'SafetyHook': KH}) + +builder.Build('PackageScript', {'KHook': KH}) \ No newline at end of file diff --git a/AMBuilder b/AMBuilder index d3981df..f63e786 100644 --- a/AMBuilder +++ b/AMBuilder @@ -8,7 +8,8 @@ def AddSourceFilesFromDir(path, files): libkhook = builder.StaticLibraryProject('khook') libkhook.sources = AddSourceFilesFromDir(os.path.join(builder.currentSourcePath, 'src'),[ - "detour.cpp" + "detour.cpp", + "ranges.cpp" ]) for compiler in KHook.all_targets: diff --git a/PackageScript b/PackageScript new file mode 100644 index 0000000..265602d --- /dev/null +++ b/PackageScript @@ -0,0 +1,19 @@ +import os + +builder.SetBuildFolder('package') + +folder_list = [ + 'x86', + 'x86_64' +] + +folder_map = {} +for folder in folder_list: + norm_folder = os.path.normpath(folder) + folder_map[folder] = builder.AddFolder(norm_folder) + +for cxx_task in KHook.test_binaries: + if cxx_task.target.arch == 'x86_64': + builder.AddCopy(cxx_task.binary, folder_map['x86_64']) + else: + builder.AddCopy(cxx_task.binary, folder_map['x86']) \ No newline at end of file diff --git a/README.md b/README.md index 14f7ca8..526f532 100644 --- a/README.md +++ b/README.md @@ -90,4 +90,8 @@ All hooks can be configured through the use of the function `Configure`. ## Testing -There is currently no test suite. +The test suite uses the [GoogleTest](https://github.com/google/googletest) framework and can be built by configuring AMBuild with the `--enable-tests` flag. Tests must be executed in parallel with [gtest-parallel](https://github.com/google/gtest-parallel) in their own processes to avoid conflicts with each other. + +To execute the test suite, run the `test.py` script with the `--target` argument specifying the target architecture to test, `x86` for 32-bit or `x86_64` for 64-bit. For example, to test the 64-bit target, run `python test.py --target x86_64`. The script will handle building the tests and executing them in parallel. + +For more options, run `python test.py --help`. \ No newline at end of file diff --git a/configure.py b/configure.py index 9972f24..2eb25f8 100644 --- a/configure.py +++ b/configure.py @@ -4,4 +4,6 @@ parser = run.BuildParser(sourcePath = sys.path[0], api='2.2') parser.options.add_argument('--targets', type=str, dest='targets', default=None, help="Override the target architecture (use commas to separate multiple targets).") +parser.options.add_argument('--enable-tests', default=False, dest='enable_tests', action='store_true', + help='Build tests.') parser.Configure() \ No newline at end of file diff --git a/include/khook.hpp b/include/khook.hpp index 6ca0eee..0b9b07d 100644 --- a/include/khook.hpp +++ b/include/khook.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include #include @@ -61,6 +62,8 @@ struct Return { }; class __Hook { +public: + virtual ~__Hook() = default; }; template @@ -74,6 +77,30 @@ class Hook : public __Hook { } } } + + template + static constexpr std::uint32_t _copy_stack_size() { +#ifdef _WIN64 + std::uint32_t num_args = sizeof...(ARGS); + if constexpr(!std::is_void_v) { + // Return value can be inserted as first arg as pointer to caller-alloc'd buffer + // (i.e., size != 1,2,4,8 or not C++03 POD). + // Add +1 to cover this case even if the return value doesn't end up being pushed + // as an arg, better to copy more than not enough. + num_args += 1; + } + + std::uint32_t num_stack_args = (num_args <= 4) ? 0 : (num_args - 4); + return num_stack_args * 8 + 32; +#else + std::uint32_t return_size = 0; + if constexpr(!std::is_void_v) { + return_size = std::max(sizeof(void*), sizeof(RETURN)); + } + + return return_size + (std::max(sizeof(void*), sizeof(ARGS)) + ... + 0); +#endif + } protected: RETURN* _fake_return = nullptr; }; @@ -93,48 +120,103 @@ using __mfp_const__ = RETURN (CLASS::*)(ARGS...) const; template using __mfp__ = RETURN (CLASS::*)(ARGS...); -template -inline __mfp__ BuildMFP(void* addr) { +template +inline FUNC BuildMFP(const void* addr) { + static_assert(std::is_member_function_pointer::value, "Error: FUNC is not a member function pointer!"); union { - R (C::*mfp)(A...); + FUNC mfp; struct { - void* addr; -#ifdef _WIN32 -#else - intptr_t adjustor; -#endif + const void* addr; + intptr_t chunks[3]; } details; } open; open.details.addr = addr; -#ifdef _WIN32 -#else - open.details.adjustor = 0; -#endif + + + if constexpr(sizeof(FUNC) >= 2 * sizeof(void*)) { + open.details.chunks[0] = 0; + } + + if constexpr(sizeof(FUNC) >= 3 * sizeof(void*)) { + open.details.chunks[1] = 0; + } + + if constexpr(sizeof(FUNC) >= 4 * sizeof(void*)) { + open.details.chunks[2] = 0; + } return open.mfp; } -template -inline __mfp_const__ BuildMFP(const void* addr) { +template +inline void FillMFP(FUNC* mfp, void* addr) { + static_assert(std::is_member_function_pointer::value, "Error: FUNC is not a member function pointer!"); + union open { + FUNC mfp; + struct { + void* addr; + intptr_t chunks[3]; + } details; + }; + + ((open*)mfp)->details.addr = addr; + + if constexpr(sizeof(FUNC) >= 2 * sizeof(void*)) { + ((open*)mfp)->details.chunks[0] = 0; + } + + if constexpr(sizeof(FUNC) >= 3 * sizeof(void*)) { + ((open*)mfp)->details.chunks[1] = 0; + } + + if constexpr(sizeof(FUNC) >= 4 * sizeof(void*)) { + ((open*)mfp)->details.chunks[2] = 0; + } +} + +template +inline void* ExtractMFP(FUNC mfp) { + static_assert(std::is_member_function_pointer::value, "Error: FUNC is not a member function pointer!"); union { - R (C::*mfp)(A...) const; + FUNC mfp; struct { - const void* addr; -#ifdef _WIN32 -#else - intptr_t adjustor; -#endif + void* addr; + intptr_t chunks[4]; } details; } open; - open.details.addr = addr; -#ifdef _WIN32 -#else - open.details.adjustor = 0; -#endif - return open.mfp; + open.mfp = mfp; + return open.details.addr; } +template +struct __internal__member_function_class; + +#define __INTERNAL__KHOOK_MAKE_TRAIT(...) \ +template \ +struct __internal__member_function_class \ +{ \ + using type = C; \ +}; + +__INTERNAL__KHOOK_MAKE_TRAIT() +__INTERNAL__KHOOK_MAKE_TRAIT(const) +__INTERNAL__KHOOK_MAKE_TRAIT(volatile) +__INTERNAL__KHOOK_MAKE_TRAIT(const volatile) + +__INTERNAL__KHOOK_MAKE_TRAIT(&) +__INTERNAL__KHOOK_MAKE_TRAIT(const &) +__INTERNAL__KHOOK_MAKE_TRAIT(volatile &) +__INTERNAL__KHOOK_MAKE_TRAIT(const volatile &) + +__INTERNAL__KHOOK_MAKE_TRAIT(&&) +__INTERNAL__KHOOK_MAKE_TRAIT(const &&) +__INTERNAL__KHOOK_MAKE_TRAIT(volatile &&) +__INTERNAL__KHOOK_MAKE_TRAIT(const volatile &&) + +__INTERNAL__KHOOK_MAKE_TRAIT(noexcept) +__INTERNAL__KHOOK_MAKE_TRAIT(const noexcept) + /** * Creates a hook around the given function address. * @@ -145,10 +227,11 @@ inline __mfp_const__ BuildMFP(const void* addr) { * @param post Function to call with the original this ptr (if any), after the hooked function is called. * @param make_return Function to call with the original this ptr (if any), to make the final return value. * @param make_call_original Function to call with the original this ptr (if any), to call the original function and store the return value if needed. + * @param stack_size Size of the stack in bytes for the detoured function. * @param async By default set to false. If set to true, the hook will be added synchronously. Beware if performed while the hooked function is processing this could deadlock. * @return The created hook id on success, INVALID_HOOK otherwise. */ -KHOOK_API HookID_t SetupHook(void* function, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, bool async = false); +KHOOK_API HookID_t SetupHook(void* function, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, unsigned int stack_size, bool async = false); /** * Creates a hook around the given function retrieved from a vtable. @@ -161,10 +244,11 @@ KHOOK_API HookID_t SetupHook(void* function, void* context, void* removed_functi * @param post Function to call with the original this ptr (if any), after the hooked function is called. * @param make_return Function to call with the original this ptr (if any), to make the final return value. * @param make_call_original Function to call with the original this ptr (if any), to call the original function and store the return value if needed. + * @param stack_size Size of the stack in bytes for the detoured function. * @param async By default set to false. If set to true, the hook will be added synchronously. Beware if performed while the hooked function is processing this could deadlock. * @return The created hook id on success, INVALID_HOOK otherwise. */ -KHOOK_API HookID_t SetupVirtualHook(void** vtable, int index, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, bool async = false); +KHOOK_API HookID_t SetupVirtualHook(void** vtable, int index, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, unsigned int stack_size, bool async = false); /** * Removes a given hook. Beware if this is performed synchronously under a hook callback this could deadlock or crash. @@ -195,6 +279,18 @@ KHOOK_API void* DoRecall(KHook::Action action, void* ptr_to_return, std::size_t */ KHOOK_API void SaveReturnValue(KHook::Action action, void* ptr_to_return, std::size_t return_size, void* init_op, void* deinit_op, bool original); +/** + * Lookup a bytes sequence (signature) in a given address block. This takes into account any detour that might have been created by the framework, + * and ensure that the compared bytes are the original bytes. + * + * @param start Start address of the memory block to search. + * @param size Size in bytes of the memory block to search. + * @param signature Byte sequence in the format of "0F A2 28 ?? EA" (IDA bytes sequence format), where ?? signifies a wildcard byte. + * + * @return An address if lookup succeeded. nullptr otherwise. + */ +KHOOK_API void* LookupSignature(void* start, std::size_t size, const char* signature); + template void init_operator(TYPE* assignee, TYPE* value) { new (assignee) TYPE(*value); @@ -237,45 +333,28 @@ inline void __internal__savereturnvalue(const ::KHook::Return &ret, bool ::KHook::SaveReturnValue(ret.action, return_ptr, size, init_op, deinit_op, original); } -template -inline ::KHook::Return Recall(RETURN (*)(ARGS...), const ::KHook::Return &ret, ARGS... args) { - RETURN (*function)(ARGS...) = (decltype(function))::KHook::__internal__dorecall(ret); - (*function)(args...); - return ret; -} - -template -inline ::KHook::Return Recall(RETURN (CLASS::*)(ARGS...), const ::KHook::Return &ret, CLASS* ptr, ARGS... args) { - auto mfp = ::KHook::BuildMFP(::KHook::__internal__dorecall(ret)); - (ptr->*mfp)(args...); - return ret; -} - -template -inline ::KHook::Return Recall(RETURN (CLASS::*)(ARGS...), const ::KHook::Return &ret, const CLASS* ptr, ARGS... args) { - auto mfp = ::KHook::BuildMFP((const void*)::KHook::__internal__dorecall(ret)); - (ptr->*mfp)(args...); - return ret; -} - -template -inline ::KHook::Return Recall(const ::KHook::Return &ret, ARGS... args) { - RETURN (*function)(ARGS...) = (decltype(function))::KHook::__internal__dorecall(ret); - (*function)(args...); - return ret; +template +inline RETURN ManualReturn(const ::KHook::Return &ret, bool original = false) { + ::KHook::__internal__savereturnvalue(ret, original); + return ret.ret; } -template -inline ::KHook::Return Recall(const ::KHook::Return &ret, CLASS* ptr, ARGS... args) { - auto mfp = ::KHook::BuildMFP(::KHook::__internal__dorecall(ret)); - (ptr->*mfp)(args...); - return ret; +template +inline void __MFP__Recall(void* addr, F f, CLASS&& this_ptr, ARGS&&... args) { + F dummy_func = nullptr; + ::KHook::FillMFP(&dummy_func, addr); + (this_ptr->*dummy_func)(std::forward(args)...); } -template -inline ::KHook::Return Recall(const ::KHook::Return &ret, const CLASS* ptr, ARGS... args) { - auto mfp = ::KHook::BuildMFP((const void*)::KHook::__internal__dorecall(ret)); - (ptr->*mfp)(args...); +template +inline ::KHook::Return> Recall(F f, const ::KHook::Return> &ret, ARGS&&... args) { + auto addr = ::KHook::__internal__dorecall(ret); + if constexpr (std::is_member_function_pointer::value) { + ::KHook::__MFP__Recall(addr, f, std::forward(args)...); + } else { + F function = (decltype(f))addr; + (*function)(std::forward(args)...); + } return ret; } @@ -287,7 +366,7 @@ inline ::KHook::Return Recall(const ::KHook::Return &ret, const KHOOK_API void* GetOriginalFunction(); /** - * Thread local function, only to be called under KHook callbacks. It returns a pointer containing the original return value (if not superceded). + * Thread local function, only to be called under KHook callbacks. It returns a pointer containing the original return value (if not superseded). * * @return The original value pointer. Behaviour is undefined if called outside POST callbacks. */ @@ -307,6 +386,13 @@ KHOOK_API void* GetOverrideValuePtr(); */ KHOOK_API void* GetCurrentValuePtr(bool pop = false); +/** + * Thread local function, only to be called under KHook callbacks. It informs whether or not the original function was skipped. + * + * @return True if skipped, false otherwise. Behaviour is undefined if called outside POST callbacks. + */ +KHOOK_API bool WasOriginalFunctionSkipped(); + /** * Thread local function, only to be called when the hook callbacks loop is over, any earlier will cause undefined behaviour. * @@ -340,40 +426,6 @@ KHOOK_API void* FindOriginalVirtual(void** vtable, int index); */ KHOOK_API void Shutdown(); -template -inline void* ExtractMFP(R (C::*mfp)(A...)) { - union { - R (C::*mfp)(A...); - struct { - void* addr; -#ifdef _WIN32 -#else - intptr_t adjustor; -#endif - } details; - } open; - - open.mfp = mfp; - return open.details.addr; -} - -template -inline const void* ExtractMFP(R (C::*mfp)(A...) const) { - union { - R (C::*mfp)(A...) const; - struct { - const void* addr; -#ifdef _WIN32 -#else - intptr_t adjustor; -#endif - } details; - } open; - - open.mfp = mfp; - return open.details.addr; -} - template class Function : public Hook { class EmptyClass {}; @@ -386,9 +438,6 @@ class Function : public Hook { Function(fnCallback pre, fnCallback post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -397,9 +446,6 @@ class Function : public Hook { Function(RETURN (*function)(ARGS...), fnCallback pre, fnCallback post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -409,9 +455,6 @@ class Function : public Hook { Function(RETURN (*function)(ARGS...), fnCallback pre, std::nullptr_t) : _pre_callback(pre), _post_callback(nullptr), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -421,9 +464,6 @@ class Function : public Hook { Function(RETURN (*function)(ARGS...), std::nullptr_t, fnCallback post) : _pre_callback(nullptr), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -434,48 +474,52 @@ class Function : public Hook { Function(CONTEXT* context, fnContextCallback pre, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } template Function(CONTEXT* context, fnContextCallback pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; } template Function(CONTEXT* context, std::nullptr_t, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } template Function(RETURN (*function)(ARGS...), CONTEXT* context, fnContextCallback pre, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } @@ -483,12 +527,13 @@ class Function : public Hook { Function(RETURN (*function)(ARGS...), CONTEXT* context, fnContextCallback pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; Configure(function); } @@ -496,12 +541,13 @@ class Function : public Hook { Function(RETURN (*function)(ARGS...), CONTEXT* context, std::nullptr_t, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } @@ -518,7 +564,53 @@ class Function : public Hook { } } - void Configure(const void* address) { + template + void AddContext(CONTEXT* context, fnContextCallback pre, std::nullptr_t) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; + } + + template + void AddContext(CONTEXT* context, fnContextCallback pre, fnContextCallback post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void AddContext(CONTEXT* context, std::nullptr_t, fnContextCallback post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void RemoveContext(CONTEXT* context) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs.erase((EmptyClass*)context); + } + + inline void Configure(RETURN (*function)(ARGS...)) { + return _Configure(reinterpret_cast(function)); + } + + RETURN CallOriginal(ARGS... args) { + RETURN (*function)(ARGS...) = (decltype(function))::KHook::FindOriginal((void*)_hooked_addr); + return (*function)(args...); + } +protected: + inline void _Configure(void* address) { + return _Configure(reinterpret_cast(address)); + } + + void _Configure(const void* address) { if (address == nullptr || _in_deletion) { return; } @@ -536,11 +628,12 @@ class Function : public Hook { _associated_hook_id = ::KHook::SetupHook( (void*)address, this, - ExtractMFP(&Self::_KHook_RemovedHook), + (void*)Self::_KHook_RemovedHook, (void*)Self::_KHook_Callback_PRE, // preMFP (void*)Self::_KHook_Callback_POST, // postMFP (void*)Self::_KHook_MakeReturn, // returnMFP, (void*)Self::_KHook_MakeOriginalCall, // callOriginalMFP + Self::template _copy_stack_size(), true // For safety reasons we are adding hooks asynchronously. If performance is required, reimplement this class ); if (_associated_hook_id != INVALID_HOOK) { @@ -550,25 +643,15 @@ class Function : public Hook { } } - inline void Configure(void* address) { - return Configure(reinterpret_cast(address)); - } - - inline void Configure(RETURN (*function)(ARGS...)) { - return Configure(reinterpret_cast(function)); - } - - RETURN CallOriginal(ARGS... args) { - RETURN (*function)(ARGS...) = (decltype(function))::KHook::FindOriginal((void*)_hooked_addr); - return (*function)(args...); - } -protected: - // Various filters to make MemberHook class useful fnCallback _pre_callback; fnCallback _post_callback; - void* _context; - void* _context_pre_callback; - void* _context_post_callback; + + struct __context_details { + __mfp__, ARGS...> pre; + __mfp__, ARGS...> post; + }; + std::mutex _m_context_ptrs; + std::unordered_map _context_ptrs; bool _in_deletion; std::mutex _hooks_stored; @@ -577,24 +660,50 @@ class Function : public Hook { HookID_t _associated_hook_id; const void* _hooked_addr; // Called by KHook - void _KHook_RemovedHook(HookID_t id) { - std::lock_guard guard(_hooks_stored); - _hook_ids.erase(id); - if (id == _associated_hook_id) { - _associated_hook_id = INVALID_HOOK; + static void _KHook_RemovedHook(HookID_t id) { + auto ctx = reinterpret_cast(KHook::GetContext()); + + std::lock_guard guard(ctx->_hooks_stored); + ctx->_hook_ids.erase(id); + if (id == ctx->_associated_hook_id) { + ctx->_associated_hook_id = INVALID_HOOK; } } // Fixed KHook callback void _KHook_Callback_Fixed(bool post, ARGS... args) { - auto context_callback = (post) ? this->_context_post_callback : this->_context_pre_callback; - auto callback = (post) ? this->_post_callback : this->_pre_callback; + KHook::Return action; + action.action = KHook::Action::Ignore; - if (callback == nullptr && context_callback == nullptr) { - return; + if (post && this->_post_callback) { + action = (*this->_post_callback)(args...); + } else if (!post && this->_pre_callback) { + action = (*this->_pre_callback)(args...); } - Return action = (_context) ? (((EmptyClass*)_context)->*BuildMFP, ARGS...>(context_callback))(args...) : (*callback)(args...); + { + decltype(this->_context_ptrs) copied_ctxs; + { + // This can deadlock (in case of recalls) + // so make a deep-copy + std::lock_guard guard(this->_m_context_ptrs); + copied_ctxs = this->_context_ptrs; + } + for (const auto& context : copied_ctxs) { + auto context_ptr = context.first; + if (post && context.second.post) { + auto new_action = (context_ptr->*(context.second.post))(args...); + if (new_action.action > action.action) { + action = new_action; + } + } else if (!post && context.second.pre) { + auto new_action = (context_ptr->*(context.second.pre))(args...); + if (new_action.action > action.action) { + action = new_action; + } + } + } + } ::KHook::__internal__savereturnvalue(action, false); } @@ -656,13 +765,18 @@ class Member : public Hook { using fnCallbackConst = ::KHook::Return (*)(const CLASS*, ARGS...); using Self = ::KHook::Member; + Member() : + _pre_callback(nullptr), + _post_callback(nullptr), + _in_deletion(false), + _associated_hook_id(INVALID_HOOK), + _hooked_addr(nullptr) { + } + // CTOR - No function Member(fnCallback pre, fnCallback post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -672,9 +786,6 @@ class Member : public Hook { Member(fnCallbackConst pre, fnCallbackConst post) : _pre_callback(reinterpret_cast(pre)), _post_callback(reinterpret_cast(post)), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -684,9 +795,6 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...), fnCallback pre, fnCallback post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -695,9 +803,6 @@ class Member : public Hook { Member(void* function, fnCallback pre, fnCallback post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -708,9 +813,6 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...), std::nullptr_t, fnCallback post) : _pre_callback(nullptr), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -719,9 +821,6 @@ class Member : public Hook { Member(void* function, std::nullptr_t, fnCallback post) : _pre_callback(nullptr), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -732,9 +831,6 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...), fnCallback pre, std::nullptr_t) : _pre_callback(pre), _post_callback(nullptr), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -743,9 +839,6 @@ class Member : public Hook { Member(void* function, fnCallback pre, std::nullptr_t) : _pre_callback(pre), _post_callback(nullptr), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -756,9 +849,6 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...) const, fnCallbackConst pre, fnCallbackConst post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -767,9 +857,6 @@ class Member : public Hook { Member(const void* function, fnCallbackConst pre, fnCallbackConst post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -780,9 +867,6 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...) const, std::nullptr_t, fnCallbackConst post) : _pre_callback(nullptr), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -791,9 +875,6 @@ class Member : public Hook { Member(const void* function, std::nullptr_t, fnCallbackConst post) : _pre_callback(nullptr), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -804,9 +885,6 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...) const, fnCallbackConst pre, std::nullptr_t) : _pre_callback(pre), _post_callback(nullptr), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -815,9 +893,6 @@ class Member : public Hook { Member(const void* function, fnCallbackConst pre, std::nullptr_t) : _pre_callback(pre), _post_callback(nullptr), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { @@ -829,12 +904,13 @@ class Member : public Hook { Member(CONTEXT* context, fnContextCallback pre, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } // CTOR - CONST - No function - Context @@ -842,12 +918,13 @@ class Member : public Hook { Member(CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } // CTOR - No function - Context - NULL POST @@ -855,12 +932,13 @@ class Member : public Hook { Member(CONTEXT* context, fnContextCallback pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; } // CTOR - CONST - No function - Context - NULL POST @@ -868,12 +946,13 @@ class Member : public Hook { Member(CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; } // CTOR - No function - Context - NULL PRE @@ -881,12 +960,13 @@ class Member : public Hook { Member(CONTEXT* context, std::nullptr_t, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } // CTOR - CONST - No function - Context - NULL PRE @@ -894,12 +974,13 @@ class Member : public Hook { Member(CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } // CTOR - Function - Context @@ -907,24 +988,26 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, fnContextCallback pre, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } template Member(void* function, CONTEXT* context, fnContextCallback pre, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } @@ -933,24 +1016,26 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } template Member(const void* function, CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } @@ -959,24 +1044,26 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, fnContextCallback pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; Configure(function); } template Member(void* function, CONTEXT* context, fnContextCallback pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; Configure(function); } @@ -985,24 +1072,26 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; Configure(function); } template Member(const void* function, CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; Configure(function); } @@ -1011,24 +1100,26 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, std::nullptr_t, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } template Member(void* function, CONTEXT* context, std::nullptr_t, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } @@ -1037,24 +1128,26 @@ class Member : public Hook { Member(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } template Member(const void* function, CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), _in_deletion(false), _associated_hook_id(INVALID_HOOK), _hooked_addr(nullptr) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; Configure(function); } @@ -1071,7 +1164,88 @@ class Member : public Hook { } } - void Configure(const void* address) { + template + void AddContext(CONTEXT* context, fnContextCallback pre, std::nullptr_t) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; + } + + template + void AddContext(CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; + } + + template + void AddContext(CONTEXT* context, fnContextCallback pre, fnContextCallback post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void AddContext(CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void AddContext(CONTEXT* context, std::nullptr_t, fnContextCallback post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void AddContext(CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void RemoveContext(CONTEXT* context) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs.erase((EmptyClass*)context); + } + + inline void Configure(const void* addr) { + return _Configure(addr); + } + + inline void Configure(RETURN (CLASS::*function)(ARGS...)) { + return _Configure(::KHook::ExtractMFP(function)); + } + + inline void Configure(RETURN (CLASS::*function)(ARGS...) const) { + return _Configure(::KHook::ExtractMFP(function)); + } + + RETURN CallOriginal(CLASS* this_ptr, ARGS... args) { + auto original_func = KHook::FindOriginal((void*)_hooked_addr); + auto mfp = KHook::BuildMFP(original_func); + return (this_ptr->*mfp)(args...); + } +protected: + inline void _Configure(void* address) { + return _Configure(reinterpret_cast(address)); + } + void _Configure(const void* address) { if (address == nullptr || _in_deletion) { return; } @@ -1086,14 +1260,15 @@ class Member : public Hook { ::KHook::RemoveHook(_associated_hook_id, true); } - _associated_hook_id = SetupHook( + _associated_hook_id = ::KHook::SetupHook( (void*)address, this, - ExtractMFP(&Self::_KHook_RemovedHook), + (void*)&Self::_KHook_RemovedHook, ExtractMFP(&Self::_KHook_Callback_PRE), // preMFP ExtractMFP(&Self::_KHook_Callback_POST), // postMFP ExtractMFP(&Self::_KHook_MakeReturn), // returnMFP, ExtractMFP(&Self::_KHook_MakeOriginalCall), // callOriginalMFP + Self::template _copy_stack_size(), true // For safety reasons we are adding hooks asynchronously. If performance is required, reimplement this class ); if (_associated_hook_id != INVALID_HOOK) { @@ -1103,30 +1278,15 @@ class Member : public Hook { } } - inline void Configure(void* address) { - return Configure(reinterpret_cast(address)); - } - - inline void Configure(RETURN (CLASS::*function)(ARGS...)) { - return Configure(ExtractMFP(function)); - } - - inline void Configure(RETURN (CLASS::*function)(ARGS...) const) { - return Configure(ExtractMFP(function)); - } - - RETURN CallOriginal(CLASS* this_ptr, ARGS... args) { - auto original_func = KHook::FindOriginal((void*)_hooked_addr); - auto mfp = KHook::BuildMFP(original_func); - return (this_ptr->*mfp)(args...); - } -protected: - // Various filters to make MemberHook class useful fnCallback _pre_callback; fnCallback _post_callback; - void* _context; - void* _context_pre_callback; - void* _context_post_callback; + + struct __context_details { + __mfp__, CLASS*, ARGS...> pre; + __mfp__, CLASS*, ARGS...> post; + }; + std::mutex _m_context_ptrs; + std::unordered_map _context_ptrs; bool _in_deletion; std::mutex _hooks_stored; @@ -1136,24 +1296,50 @@ class Member : public Hook { const void* _hooked_addr; // Called by KHook - void _KHook_RemovedHook(HookID_t id) { - std::lock_guard guard(_hooks_stored); - _hook_ids.erase(id); - if (id == _associated_hook_id) { - _associated_hook_id = INVALID_HOOK; + static void _KHook_RemovedHook(HookID_t id) { + auto ctx = reinterpret_cast(KHook::GetContext()); + + std::lock_guard guard(ctx->_hooks_stored); + ctx->_hook_ids.erase(id); + if (id == ctx->_associated_hook_id) { + ctx->_associated_hook_id = INVALID_HOOK; } } // Fixed KHook callback void _KHook_Callback_Fixed(bool post, CLASS* hooked_this, ARGS... args) { - fnContextCallback context_callback = KHook::BuildMFP, CLASS*, ARGS...>((post) ? this->_context_post_callback : this->_context_pre_callback); - auto callback = (post) ? this->_post_callback : this->_pre_callback; + KHook::Return action; + action.action = KHook::Action::Ignore; - if (callback == nullptr && context_callback == nullptr) { - return; + if (post && this->_post_callback) { + action = (*this->_post_callback)(hooked_this, args...); + } else if (!post && this->_pre_callback) { + action = (*this->_pre_callback)(hooked_this, args...); } - Return action = (_context) ? (((EmptyClass*)_context)->*context_callback)(hooked_this, args...) : (*callback)(hooked_this, args...); + { + decltype(this->_context_ptrs) copied_ctxs; + { + // This can deadlock (in case of recalls) + // so make a deep-copy + std::lock_guard guard(this->_m_context_ptrs); + copied_ctxs = this->_context_ptrs; + } + for (const auto& context : copied_ctxs) { + auto context_ptr = context.first; + if (post && context.second.post) { + auto new_action = (context_ptr->*(context.second.post))(hooked_this, args...); + if (new_action.action > action.action) { + action = new_action; + } + } else if (!post && context.second.pre) { + auto new_action = (context_ptr->*(context.second.pre))(hooked_this, args...); + if (new_action.action > action.action) { + action = new_action; + } + } + } + } ::KHook::__internal__savereturnvalue(action, false); } @@ -1193,7 +1379,7 @@ class Member : public Hook { // Called if the hook wasn't superceded RETURN _KHook_MakeOriginalCall(ARGS ...args) { - RETURN (EmptyClass::*ptr)(ARGS...) = BuildMFP(::KHook::GetOriginalFunction()); + auto ptr = ::KHook::BuildMFP(::KHook::GetOriginalFunction()); if constexpr(std::is_same::value) { (((EmptyClass*)this)->*ptr)(args...); ::KHook::__internal__savereturnvalue(KHook::Return{ KHook::Action::Ignore }, true); @@ -1205,20 +1391,18 @@ class Member : public Hook { } }; -template -inline std::int32_t GetVtableIndex(RETURN (CLASS::*function)(ARGS...)); +template +inline std::int32_t GetVtableIndex(FUNC function); -template -inline std::int32_t GetVtableIndex(RETURN (CLASS::*function)(ARGS...) const); - -template -inline __mfp__ GetVtableFunction(CLASS* ptr, RETURN (CLASS::*mfp)(ARGS...)) { +template +inline FUNC GetVtableFunction(CLASS* ptr, FUNC mfp) { + static_assert(std::is_member_function_pointer::value, "Error: FUNC is not a member function pointer!"); void** vtable = *(void***)ptr; auto index = ::KHook::GetVtableIndex(mfp); if (index == -1) { return nullptr; } - return BuildMFP(vtable[index]); + return ::KHook::BuildMFP(vtable[index]); } template @@ -1228,15 +1412,17 @@ inline __mfp_const__ GetVtableFunction(const CLASS* ptr, if (index == -1) { return nullptr; } - return BuildMFP(vtable[index]); + return ::KHook::BuildMFP<__mfp_const__>(vtable[index]); } template inline __mfp__ GetVtableFunction(CLASS* ptr, std::uint32_t index) { void** vtable = *(void***)ptr; - return BuildMFP(vtable[index]); + return ::KHook::BuildMFP<__mfp__>(vtable[index]); } +using VirtualHookId_t = std::uint32_t; + template class Virtual : public Hook { static constexpr std::uint32_t INVALID_VTBL_INDEX = -1; @@ -1250,13 +1436,17 @@ class Virtual : public Hook { using fnCallbackConst = ::KHook::Return (*)(const CLASS*, ARGS...); using Self = ::KHook::Virtual; + Virtual() : + _pre_callback(nullptr), + _post_callback(nullptr), + _vtbl_index(INVALID_VTBL_INDEX), + _in_deletion(false) { + } + // CTOR - No function Virtual(fnCallback pre, fnCallback post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(INVALID_VTBL_INDEX), _in_deletion(false) { } @@ -1265,20 +1455,22 @@ class Virtual : public Hook { Virtual(fnCallbackConst pre, fnCallbackConst post) : _pre_callback(reinterpret_cast(pre)), _post_callback(reinterpret_cast(post)), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(INVALID_VTBL_INDEX), _in_deletion(false) { } + // CTOR - Function - NO PRE OR POST + Virtual(RETURN (CLASS::*function)(ARGS...)) : + _pre_callback(nullptr), + _post_callback(nullptr), + _vtbl_index(GetVtableIndex(function)), + _in_deletion(false) { + } + // CTOR - Function Virtual(RETURN (CLASS::*function)(ARGS...), fnCallback pre, fnCallback post) : _pre_callback(pre), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { } @@ -1287,9 +1479,6 @@ class Virtual : public Hook { Virtual(RETURN (CLASS::*function)(ARGS...), std::nullptr_t, fnCallback post) : _pre_callback(nullptr), _post_callback(post), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { } @@ -1298,9 +1487,6 @@ class Virtual : public Hook { Virtual(RETURN (CLASS::*function)(ARGS...), fnCallback pre, std::nullptr_t) : _pre_callback(pre), _post_callback(nullptr), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { } @@ -1309,9 +1495,6 @@ class Virtual : public Hook { Virtual(RETURN (CLASS::*function)(ARGS...) const, fnCallbackConst pre, fnCallbackConst post) : _pre_callback(reinterpret_cast(pre)), _post_callback(reinterpret_cast(post)), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { } @@ -1320,9 +1503,6 @@ class Virtual : public Hook { Virtual(RETURN (CLASS::*function)(ARGS...) const, std::nullptr_t, fnCallbackConst post) : _pre_callback(nullptr), _post_callback(reinterpret_cast(post)), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { } @@ -1331,155 +1511,212 @@ class Virtual : public Hook { Virtual(RETURN (CLASS::*function)(ARGS...) const, fnCallbackConst pre, std::nullptr_t) : _pre_callback(reinterpret_cast(pre)), _post_callback(nullptr), - _context(nullptr), - _context_pre_callback(nullptr), - _context_post_callback(nullptr), _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { } - - // CTOR - No Function - Context + + // CTOR - Function - Context template - Virtual(CONTEXT* context, fnContextCallback pre, fnContextCallback post) : + Virtual(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, fnContextCallback pre, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(INVALID_VTBL_INDEX), + _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } - - // CTOR - CONST - No Function - Context + + // CTOR - CONST - Function - Context template - Virtual(CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) : + Virtual(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(INVALID_VTBL_INDEX), + _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } - - // CTOR - No function - Context - NULL PRE + + // CTOR - Function - Context - NULL PRE template - Virtual(CONTEXT* context, std::nullptr_t, fnContextCallback post) : + Virtual(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, std::nullptr_t, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(INVALID_VTBL_INDEX), + _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } - // CTOR - CONST - No function - Context - NULL PRE + // CTOR - CONST - Function - Context - NULL PRE template - Virtual(CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) : + Virtual(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(INVALID_VTBL_INDEX), + _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } - // CTOR - No function - Context - NULL POST + // CTOR - Function - Context - NULL POST template - Virtual(CONTEXT* context, fnContextCallback pre, std::nullptr_t) : + Virtual(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, fnContextCallback pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), - _vtbl_index(INVALID_VTBL_INDEX), + _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; } - - // CTOR - CONST - No function - Context - NULL POST + + // CTOR - CONST - Function - Context - NULL POST template - Virtual(CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) : + Virtual(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), - _vtbl_index(INVALID_VTBL_INDEX), + _vtbl_index(GetVtableIndex(function)), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; } - - // CTOR - Function - Context - template - Virtual(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, fnContextCallback pre, fnContextCallback post) : + + // CTOR - VTable index + Virtual(std::uint32_t index, fnCallback pre, fnCallback post) : + _pre_callback(pre), + _post_callback(post), + _vtbl_index(index), + _in_deletion(false) { + } + + // CTOR - VTable index - NULL PRE + Virtual(std::uint32_t index, std::nullptr_t, fnCallback post) : _pre_callback(nullptr), + _post_callback(post), + _vtbl_index(index), + _in_deletion(false) { + } + + // CTOR - VTable index - NULL POST + Virtual(std::uint32_t index, fnCallback pre, std::nullptr_t) : + _pre_callback(pre), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(GetVtableIndex(function)), + _vtbl_index(index), _in_deletion(false) { } - // CTOR - CONST - Function - Context + // CTOR - CONST - VTable index + Virtual(std::uint32_t index, fnCallbackConst pre, fnCallbackConst post) : + _pre_callback(reinterpret_cast(pre)), + _post_callback(reinterpret_cast(post)), + _vtbl_index(index), + _in_deletion(false) { + } + + // CTOR - CONST - VTable index - NULL PRE + Virtual(std::uint32_t index, std::nullptr_t, fnCallbackConst post) : + _pre_callback(nullptr), + _post_callback(reinterpret_cast(post)), + _vtbl_index(index), + _in_deletion(false) { + } + + // CTOR - CONST - VTable index - NULL POST + Virtual(std::uint32_t index, fnCallbackConst pre, std::nullptr_t) : + _pre_callback(reinterpret_cast(pre)), + _post_callback(nullptr), + _vtbl_index(index), + _in_deletion(false) { + } + + // CTOR - VTable Index - Context template - Virtual(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) : + Virtual(std::uint32_t index, CONTEXT* context, fnContextCallback pre, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(GetVtableIndex(function)), + _vtbl_index(index), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } - // CTOR - Function - Context - NULL PRE + // CTOR - CONST - VTable Index - Context template - Virtual(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, std::nullptr_t, fnContextCallback post) : + Virtual(std::uint32_t index, CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(GetVtableIndex(function)), + _vtbl_index(index), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } - // CTOR - CONST - Function - Context - NULL PRE + // CTOR - VTable Index - Context - NULL PRE template - Virtual(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) : + Virtual(std::uint32_t index, CONTEXT* context, std::nullptr_t, fnContextCallback post) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(nullptr), - _context_post_callback(ExtractMFP(post)), - _vtbl_index(GetVtableIndex(function)), + _vtbl_index(index), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + // CTOR - CONST - VTable Index - Context - NULL PRE + template + Virtual(std::uint32_t index, CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) : + _pre_callback(nullptr), + _post_callback(nullptr), + _vtbl_index(index), + _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; } - // CTOR - Function - Context - NULL POST + // CTOR - VTable Index - Context - NULL POST template - Virtual(RETURN (CLASS::*function)(ARGS...), CONTEXT* context, fnContextCallback pre, std::nullptr_t) : + Virtual(std::uint32_t index, CONTEXT* context, fnContextCallback pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), - _vtbl_index(GetVtableIndex(function)), + _vtbl_index(index), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; } - - // CTOR - CONST - Function - Context - NULL POST + + // CTOR - CONST - VTable Index - Context - NULL POST template - Virtual(RETURN (CLASS::*function)(ARGS...) const, CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) : + Virtual(std::uint32_t index, CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) : _pre_callback(nullptr), _post_callback(nullptr), - _context(context), - _context_pre_callback(ExtractMFP(pre)), - _context_post_callback(nullptr), - _vtbl_index(GetVtableIndex(function)), + _vtbl_index(index), _in_deletion(false) { + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; } virtual ~Virtual() { @@ -1495,12 +1732,72 @@ class Virtual : public Hook { } } + template + void AddContext(CONTEXT* context, fnContextCallback pre, std::nullptr_t) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; + } + + template + void AddContext(CONTEXT* context, fnContextCallbackConst pre, std::nullptr_t) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + nullptr + }; + } + + template + void AddContext(CONTEXT* context, fnContextCallback pre, fnContextCallback post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void AddContext(CONTEXT* context, fnContextCallbackConst pre, fnContextCallbackConst post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + ::KHook::BuildMFP>(::KHook::ExtractMFP(pre)), + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void AddContext(CONTEXT* context, std::nullptr_t, fnContextCallback post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void AddContext(CONTEXT* context, std::nullptr_t, fnContextCallbackConst post) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs[(EmptyClass*)context] = { + nullptr, + ::KHook::BuildMFP>(::KHook::ExtractMFP(post)) + }; + } + + template + void RemoveContext(CONTEXT* context) { + std::lock_guard guard(this->_m_context_ptrs); + _context_ptrs.erase((EmptyClass*)context); + } + void Add(CLASS* this_ptr) { { std::lock_guard guard(_m_hooked_this); _hooked_this.insert(this_ptr); } - Configure(*(void***)this_ptr); + _Setup(*(void***)this_ptr); } void Remove(CLASS* this_ptr) { @@ -1510,13 +1807,28 @@ class Virtual : public Hook { } } + void AddGlobal(CLASS* this_ptr) { + { + std::lock_guard guard(_m_hooked_this); + _hooked_global.insert(*(void***)this_ptr); + } + _Setup(*(void***)this_ptr); + } + + void RemoveGlobal(CLASS* this_ptr) { + { + std::lock_guard guard(_m_hooked_this); + _hooked_global.erase(*(void***)this_ptr); + } + } + RETURN CallOriginal(CLASS* this_ptr, ARGS... args) { auto original_func = KHook::FindOriginalVirtual(*(void***)this_ptr, _vtbl_index); - auto mfp = KHook::BuildMFP(original_func); + auto mfp = KHook::BuildMFP(original_func); return (this_ptr->*mfp)(args...); } - void SetIndex(std::int32_t index) { + void Configure(std::int32_t index) { if (_vtbl_index == index) { return; } @@ -1536,13 +1848,43 @@ class Virtual : public Hook { } _vtbl_index = index; } + + void Configure(RETURN (CLASS::*function)(ARGS...)) { + std::int32_t index = KHook::GetVtableIndex(function); + if (index == -1) { + return; + } + Configure(index); + } + + void Configure(RETURN (CLASS::*function)(ARGS...) const) { + std::int32_t index = KHook::GetVtableIndex(function); + if (index == -1) { + return; + } + Configure(index); + } + + bool IsActive() { + std::lock_guard guard(this->_m_hooked_this); + return _hooked_this.size() != 0 || _hooked_global.size() != 0; + } + + void ClearHooks() { + std::lock_guard guard(this->_m_hooked_this); + _hooked_this.clear(); + _hooked_global.clear(); + } protected: - // Various filters to make MemberHook class useful fnCallback _pre_callback; fnCallback _post_callback; - void* _context; - void* _context_pre_callback; - void* _context_post_callback; + + struct __context_details { + __mfp__, CLASS*, ARGS...> pre; + __mfp__, CLASS*, ARGS...> post; + }; + std::mutex _m_context_ptrs; + std::unordered_map _context_ptrs; std::int32_t _vtbl_index; @@ -1553,17 +1895,20 @@ class Virtual : public Hook { std::mutex _m_hooked_this; std::unordered_set _hooked_this; + std::unordered_set _hooked_global; // Called by KHook - void _KHook_RemovedHook(HookID_t id) { - std::lock_guard guard(_hooks_stored); - auto it = _hook_ids_addr.find(id); - if (it != _hook_ids_addr.end()) { - _addr_hook_ids.erase(it->second); + static void _KHook_RemovedHook(HookID_t id) { + auto ctx = reinterpret_cast(KHook::GetContext()); + + std::lock_guard guard(ctx->_hooks_stored); + auto it = ctx->_hook_ids_addr.find(id); + if (it != ctx->_hook_ids_addr.end()) { + ctx->_addr_hook_ids.erase(it->second); } } - void Configure(void** vtable) { + void _Setup(void** vtable) { if (vtable == nullptr || _in_deletion || _vtbl_index == INVALID_VTBL_INDEX) { return; } @@ -1571,7 +1916,7 @@ class Virtual : public Hook { { std::lock_guard guard(_hooks_stored); // Retrieve the hookID with this vtable if it exists - if (_addr_hook_ids.find(vtable) != _addr_hook_ids.end()) { + if (_addr_hook_ids.find(vtable + _vtbl_index) != _addr_hook_ids.end()) { // Already hooked so ignore return; } @@ -1581,17 +1926,18 @@ class Virtual : public Hook { vtable, _vtbl_index, this, - ExtractMFP(&Self::_KHook_RemovedHook), + (void*)&Self::_KHook_RemovedHook, ExtractMFP(&Self::_KHook_Callback_PRE), // preMFP ExtractMFP(&Self::_KHook_Callback_POST), // postMFP ExtractMFP(&Self::_KHook_MakeReturn), // returnMFP, ExtractMFP(&Self::_KHook_MakeOriginalCall), // callOriginalMFP + Self::template _copy_stack_size(), true // For safety reasons we are adding hooks asynchronously. If performance is required, reimplement this class ); if (id != INVALID_HOOK) { std::lock_guard guard(_hooks_stored); - _hook_ids_addr[id] = vtable[_vtbl_index]; - _addr_hook_ids[vtable[_vtbl_index]] = id; + _hook_ids_addr[id] = vtable + _vtbl_index; + _addr_hook_ids[vtable + _vtbl_index] = id; } } @@ -1599,19 +1945,47 @@ class Virtual : public Hook { void _KHook_Callback_Fixed(bool post, CLASS* hooked_this, ARGS... args) { { std::lock_guard guard(this->_m_hooked_this); + // Did we hook this ptr if (_hooked_this.find(hooked_this) == _hooked_this.end()) { - return; + // This is perhaps a global hook instead + if (_hooked_global.find(*(void***)hooked_this) == _hooked_global.end()) { + return; + } } } - fnContextCallback context_callback = KHook::BuildMFP, CLASS*, ARGS...>((post) ? this->_context_post_callback : this->_context_pre_callback); - auto callback = (post) ? this->_post_callback : this->_pre_callback; + KHook::Return action; + action.action = KHook::Action::Ignore; - if (callback == nullptr && context_callback == nullptr) { - return; + if (post && this->_post_callback) { + action = (*this->_post_callback)(hooked_this, args...); + } else if (!post && this->_pre_callback) { + action = (*this->_pre_callback)(hooked_this, args...); } - Return action = (_context) ? (((EmptyClass*)_context)->*context_callback)(hooked_this, args...) : (*callback)(hooked_this, args...); + { + decltype(this->_context_ptrs) copied_ctxs; + { + // This can deadlock (in case of recalls) + // so make a deep-copy + std::lock_guard guard(this->_m_context_ptrs); + copied_ctxs = this->_context_ptrs; + } + for (const auto& context : copied_ctxs) { + auto context_ptr = context.first; + if (post && context.second.post) { + auto new_action = (context_ptr->*(context.second.post))(hooked_this, args...); + if (new_action.action > action.action) { + action = new_action; + } + } else if (!post && context.second.pre) { + auto new_action = (context_ptr->*(context.second.pre))(hooked_this, args...); + if (new_action.action > action.action) { + action = new_action; + } + } + } + } ::KHook::__internal__savereturnvalue(action, false); } @@ -1651,7 +2025,7 @@ class Virtual : public Hook { // Called if the hook wasn't superceded RETURN _KHook_MakeOriginalCall(ARGS ...args) { - RETURN (EmptyClass::*ptr)(ARGS...) = BuildMFP(::KHook::GetOriginalFunction()); + auto ptr = ::KHook::BuildMFP(::KHook::GetOriginalFunction()); if constexpr(std::is_same::value) { (((EmptyClass*)this)->*ptr)(args...); ::KHook::__internal__savereturnvalue(KHook::Return{ KHook::Action::Ignore }, true); @@ -1719,8 +2093,9 @@ struct __MFPInfo__ std::intptr_t delta; }; -template -inline std::int32_t GetVtableIndex(RETURN (CLASS::*function)(ARGS...)) { +template +inline std::int32_t GetVtableIndex(FUNC function) { + static_assert(std::is_member_function_pointer::value, "Error: FUNC is not a member function pointer!"); #ifdef _WIN32 return __GetVtableIndex__(reinterpret_cast(ExtractMFP(function))); #else @@ -1732,21 +2107,10 @@ inline std::int32_t GetVtableIndex(RETURN (CLASS::*function)(ARGS...)) { #endif } -template -inline std::int32_t GetVtableIndex(RETURN (CLASS::*function)(ARGS...) const) { -#ifdef _WIN32 - return __GetVtableIndex__(reinterpret_cast(ExtractMFP(function))); -#else - __MFPInfo__* info = (__MFPInfo__*)&function; - if (info->vtbl_index & 1) { - return (info->vtbl_index - 1) / sizeof(void*); - } - return -1; -#endif -} +template +inline std::invoke_result_t __MFP__CallOriginal(F function, CLASS&& this_ptr, ARGS&&... args) { + F dummy_func = nullptr; -template -inline RETURN CallOriginal(RETURN (CLASS::*function)(ARGS...), CLASS* this_ptr, ARGS... args) { auto vtbl_index = ::KHook::GetVtableIndex(function); void* func = nullptr; if (vtbl_index != -1) { @@ -1755,42 +2119,24 @@ inline RETURN CallOriginal(RETURN (CLASS::*function)(ARGS...), CLASS* this_ptr, else { func = ::KHook::FindOriginal(::KHook::ExtractMFP(function)); } - auto mfp = ::KHook::BuildMFP(func); - return (this_ptr->*mfp)(args...); + ::KHook::FillMFP(&dummy_func, func); + return (this_ptr->*dummy_func)(std::forward(args)...); } -template -inline RETURN CallOriginal(RETURN (CLASS::*function)(ARGS...) const, const CLASS* this_ptr, ARGS... args) { - auto vtbl_index = ::KHook::GetVtableIndex(function); - const void* func = nullptr; - if (vtbl_index != -1) { - func = ::KHook::FindOriginalVirtual(*(void***)this_ptr, vtbl_index); - } - else { - func = ::KHook::FindOriginal(::KHook::ExtractMFP(function)); +template +inline std::invoke_result_t CallOriginal(F f, ARGS&&... args) { + if constexpr (std::is_member_function_pointer::value) { + return ::KHook::__MFP__CallOriginal(f, std::forward(args)...); + } else { + F function = (decltype(f))::KHook::FindOriginal(f); + return (*function)(std::forward(args)...); } - auto mfp = ::KHook::BuildMFP(func); - return (this_ptr->*mfp)(args...); -} - -template -inline RETURN CallOriginal(void* func, CLASS* this_ptr, ARGS... args) { - func = ::KHook::FindOriginal(func); - auto mfp = ::KHook::BuildMFP(func); - return (this_ptr->*mfp)(args...); -} - -template -inline RETURN CallOriginal(const void* func, const CLASS* this_ptr, ARGS... args) { - func = (const void*)::KHook::FindOriginal((void*)func); - auto mfp = ::KHook::BuildMFP(func); - return (this_ptr->*mfp)(args...); } class IKHook { public: - virtual HookID_t SetupHook(void* function, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, bool async = false) = 0; - virtual HookID_t SetupVirtualHook(void** vtable, int index, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, bool async = false) = 0; + virtual HookID_t SetupHook(void* function, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, unsigned int stack_size, bool async = false) = 0; + virtual HookID_t SetupVirtualHook(void** vtable, int index, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, unsigned int stack_size, bool async = false) = 0; virtual void RemoveHook(HookID_t id, bool async = false) = 0; virtual void* GetContext() = 0; virtual void* GetOriginalFunction() = 0; @@ -1802,12 +2148,13 @@ class IKHook { virtual void* FindOriginalVirtual(void** vtable, int index) = 0; virtual void* DoRecall(KHook::Action action, void* ptr_to_return, std::size_t return_size, void* init_op, void* deinit_op) = 0; virtual void SaveReturnValue(KHook::Action action, void* ptr_to_return, std::size_t return_size, void* init_op, void* deinit_op, bool original) = 0; + virtual void* LookupSignature(void* start, std::size_t size, const char* signature) = 0; }; #ifndef KHOOK_STANDALONE // KHOOK is exposed by something extern IKHook* __exported__khook; -KHOOK_API HookID_t SetupHook(void* function, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, bool async) { +KHOOK_API HookID_t SetupHook(void* function, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, unsigned int stack_size, bool async) { // For some hooks this is too early if (__exported__khook == nullptr) { std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"; @@ -1818,10 +2165,10 @@ KHOOK_API HookID_t SetupHook(void* function, void* context, void* removed_functi std::cerr << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"; return INVALID_HOOK; } - return __exported__khook->SetupHook(function, context, removed_function, pre, post, make_return, make_call_original, async); + return __exported__khook->SetupHook(function, context, removed_function, pre, post, make_return, make_call_original, stack_size, async); } -KHOOK_API HookID_t SetupVirtualHook(void** vtable, int index, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, bool async) { +KHOOK_API HookID_t SetupVirtualHook(void** vtable, int index, void* context, void* removed_function, void* pre, void* post, void* make_return, void* make_call_original, unsigned int stack_size, bool async) { // For some hooks this is too early if (__exported__khook == nullptr) { std::cout << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"; @@ -1832,7 +2179,7 @@ KHOOK_API HookID_t SetupVirtualHook(void** vtable, int index, void* context, voi std::cerr << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"; return INVALID_HOOK; } - return __exported__khook->SetupVirtualHook(vtable, index, context, removed_function, pre, post, make_return, make_call_original, async); + return __exported__khook->SetupVirtualHook(vtable, index, context, removed_function, pre, post, make_return, make_call_original, stack_size, async); } KHOOK_API void RemoveHook(HookID_t id, bool async) { @@ -1879,6 +2226,10 @@ KHOOK_API void SaveReturnValue(KHook::Action action, void* ptr_to_return, std::s return __exported__khook->SaveReturnValue(action, ptr_to_return, return_size, init_op, deinit_op, original); } +KHOOK_API void* LookupSignature(void* start, std::size_t size, const char* signature) { + return __exported__khook->LookupSignature(start, size, signature); +} + #endif } diff --git a/include/khook/asm.hpp b/include/khook/asm.hpp index 7138f94..b4f5708 100644 --- a/include/khook/asm.hpp +++ b/include/khook/asm.hpp @@ -317,9 +317,8 @@ namespace KHook a waste of virtual address space (Windows’ VirtualAlloc has a granularity of 64K). - IMPORTANT: the memory that Alloc() returns is not a in a defined state! - It could be in read+exec OR read+write mode. - -> call SetRE() or SetRW() before using allocated memory! + Memory that Alloc() returns is mapped read+write+execute, so it can be + written to (code generation) and executed without any further protection changes. */ class CPageAlloc { @@ -346,7 +345,6 @@ namespace KHook bool isolated; std::size_t minAlignment; AUList allocUnits; - bool isRE; void CheckGap(std::size_t gap_begin, std::size_t gap_end, std::size_t reqsize, std::size_t &smallestgap_pos, std::size_t &smallestgap_size, std::size_t &outAlignBytes) @@ -431,22 +429,11 @@ namespace KHook void DebugCleanMemory(unsigned char* start, size_t size) { - bool wasRE = isRE; - if (isRE) - { - SetRW(); - } - unsigned char* end = start + size; for (unsigned char* p = start; p != end; ++p) { *p = 0xCC; } - - if (wasRE) - { - SetRE(); - } } bool Contains(void *addr) @@ -462,18 +449,6 @@ namespace KHook munmap(startPtr, size); #endif } - - void SetRE() - { - Memory::SetAccess(startPtr, size, Memory::Flags::READ | Memory::Flags::EXECUTE); - isRE = true; - } - - void SetRW() - { - Memory::SetAccess(startPtr, size, Memory::Flags::READ | Memory::Flags::WRITE); - isRE = false; - } }; typedef List ARList; @@ -496,14 +471,15 @@ namespace KHook newRegion.size += m_PageSize; #ifdef _WIN32 - newRegion.startPtr = VirtualAlloc(nullptr, newRegion.size, MEM_COMMIT, PAGE_READWRITE); + newRegion.startPtr = VirtualAlloc(nullptr, newRegion.size, MEM_COMMIT, PAGE_EXECUTE_READWRITE); #else - newRegion.startPtr = mmap(0, newRegion.size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0); + newRegion.startPtr = mmap(0, newRegion.size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_PRIVATE | MAP_ANON, -1, 0); + if (newRegion.startPtr == MAP_FAILED) + newRegion.startPtr = nullptr; #endif if (newRegion.startPtr) { - newRegion.SetRW(); m_Regions.push_back(newRegion); return true; } @@ -582,30 +558,6 @@ namespace KHook } } - void SetRE(void *ptr) - { - for (ARList::iterator iter = m_Regions.begin(); iter != m_Regions.end(); ++iter) - { - if (iter->Contains(ptr)) - { - iter->SetRE(); - break; - } - } - } - - void SetRW(void *ptr) - { - for (ARList::iterator iter = m_Regions.begin(); iter != m_Regions.end(); ++iter) - { - if (iter->Contains(ptr)) - { - iter->SetRW(); - break; - } - } - } - std::size_t GetPageSize() { return m_PageSize; @@ -647,10 +599,6 @@ namespace KHook m_AllocatedSize = 0; } - void SetRE() { - Allocator.SetRE(reinterpret_cast(m_pData)); - } - operator void *() { return reinterpret_cast(GetData()); } @@ -687,7 +635,6 @@ namespace KHook unsigned char *newBuf; newBuf = reinterpret_cast(Allocator.Alloc(m_AllocatedSize)); - Allocator.SetRW(newBuf); if (!newBuf) { assertm(false, "bad_alloc: couldn't allocate new bytes of memory\n"); return; @@ -695,8 +642,6 @@ namespace KHook std::memset((void*)newBuf, 0xCC, m_AllocatedSize); // :TODO: remove this ! std::memcpy((void*)newBuf, (const void*)m_pData, m_Size); if (m_pData) { - Allocator.SetRE(reinterpret_cast(m_pData)); - Allocator.SetRW(newBuf); Allocator.Free(reinterpret_cast(m_pData)); } m_pData = newBuf; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ff36fea..26eb0dc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(khook_lib SHARED "detour.cpp" + "ranges.cpp" ) target_include_directories(khook_lib PUBLIC diff --git a/src/detour.cpp b/src/detour.cpp index 20d3fbf..ca90ae2 100644 --- a/src/detour.cpp +++ b/src/detour.cpp @@ -1,4 +1,5 @@ #include "detour.hpp" +#include "ranges.hpp" #include #include @@ -7,8 +8,20 @@ namespace KHook { using namespace KHook::Asm; - -#define STACK_SAFETY_BUFFER 112 +#define STACK_SAFETY_BUFFER 128 + +#if defined(KHOOK_TESTS) || defined(KHOOK_DEBUG_PRINT) +#define DEBUG_PRINT(...) \ + printf(__VA_ARGS__); +#define DEBUG_ABORT_PRINT(...) \ + printf(__VA_ARGS__); \ + fflush(stdout); \ + ::std::this_thread::sleep_for(std::chrono::milliseconds(5)); \ + ::std::abort(); +#else +#define DEBUG_PRINT(...) +#define DEBUG_ABORT_PRINT(...) ::std::abort(); +#endif #ifdef KHOOK_X64 #define FUNCTION_ATTRIBUTE_PREFIX(ret) ret @@ -89,7 +102,7 @@ static FUNCTION_ATTRIBUTE_PREFIX(void) RecursiveLockUnlockShared(std::shared_mut } } else { if (it->second == 0) { - std::abort(); + DEBUG_ABORT_PRINT("Recursive mutex for detour was unlocked too many times!\n") } it->second--; @@ -100,12 +113,19 @@ static FUNCTION_ATTRIBUTE_PREFIX(void) RecursiveLockUnlockShared(std::shared_mut } } +enum class CallOriginalState : std::uintptr_t { + NotReady = 50, + InProgress, + Complete, + CompleteSkipped +}; + struct AsmLoopDetails { // Current iterated hook std::uintptr_t linked_list_it; std::uintptr_t pre_loop_started; std::uintptr_t pre_loop_over; - std::uintptr_t original_call_over; + CallOriginalState original_call_state; std::uintptr_t post_loop_over; std::uintptr_t post_loop_started; std::uintptr_t recall_count; @@ -150,14 +170,14 @@ static thread_local AsmLoopDetails g_last_loop; static FUNCTION_ATTRIBUTE_PREFIX(void) EndDetour(AsmLoopDetails* loop, bool no_callback) FUNCTION_ATTRIBUTE_SUFFIX { if (g_saved_params.top() != loop || g_is_in_recall) { // Something went horribly wrong with the stack - std::abort(); + DEBUG_ABORT_PRINT("Call stack is corrupted!\n") } if (no_callback) { if (loop->recall_count != 0) { // If this is a recall, and we somehow have no callback then something horribly wrong happened // Terminate the program right now - std::abort(); + DEBUG_ABORT_PRINT("A detour recall happened with no callbacks!\n") } RecursiveLockUnlockShared(&loop->capsule->_detour_mutex, false); // Detour was early ended, unlock the mutex and pop the asm details @@ -165,6 +185,13 @@ static FUNCTION_ATTRIBUTE_PREFIX(void) EndDetour(AsmLoopDetails* loop, bool no_c } else { // Natural end of a detour, setup everything because AsmLoopDetails is about to go invalid (due to stack being freed) g_last_loop = *loop; + + // Ensure the original function was skipped OR executed, no in-between state + if (loop->original_call_state != CallOriginalState::Complete + && loop->original_call_state != CallOriginalState::CompleteSkipped) { + DEBUG_ABORT_PRINT("Detour was left in unknown state %d\n", loop->original_call_state) + } + if (loop->recall_count != 0) { RecursiveLockUnlockShared(&loop->capsule->_detour_mutex, false); } @@ -192,7 +219,7 @@ static FUNCTION_ATTRIBUTE_PREFIX(AsmLoopDetails*) BeginDetour( if (capsule != loop->capsule) { // Not the same detour somehow - std::abort(); + DEBUG_ABORT_PRINT("Different detour capsule (should be impossible)!\n") } if (loop->pre_loop_over == false) { @@ -203,10 +230,6 @@ static FUNCTION_ATTRIBUTE_PREFIX(AsmLoopDetails*) BeginDetour( if (loop->linked_list_it == 0x0) { loop->pre_loop_over = true; } - } else if (loop->original_call_over == false) { - // Recall happened in the original call, this is technically impossible - // But we're going to support it anywways, this will avoid infinite-loops - loop->original_call_over = true; } else if (loop->post_loop_over == false) { // Recall happened in a post-loop, move to next iterator auto hook = reinterpret_cast(loop->linked_list_it); @@ -217,7 +240,7 @@ static FUNCTION_ATTRIBUTE_PREFIX(AsmLoopDetails*) BeginDetour( } } else { // A recall happened outside of a hook - std::abort(); + DEBUG_ABORT_PRINT("Recall outside hook callbacks!\n") } @@ -233,7 +256,7 @@ static FUNCTION_ATTRIBUTE_PREFIX(AsmLoopDetails*) BeginDetour( new_loop->linked_list_it = 0x0; new_loop->pre_loop_over = false; new_loop->pre_loop_started = false; - new_loop->original_call_over = false; + new_loop->original_call_state = CallOriginalState::NotReady; new_loop->post_loop_over = false; new_loop->post_loop_started = false; new_loop->recall_count = 0; @@ -284,7 +307,7 @@ static FUNCTION_ATTRIBUTE_PREFIX(void) PushRsp(std::uintptr_t rsp) FUNCTION_ATTR static FUNCTION_ATTRIBUTE_PREFIX(std::uintptr_t) PeekRsp(std::uintptr_t rsp) FUNCTION_ATTRIBUTE_SUFFIX { auto internal_rsp = rsp_values.top(); - assert((internal_rsp + STACK_SAFETY_BUFFER) > rsp); + //assert((internal_rsp + STACK_SAFETY_BUFFER) > rsp); return internal_rsp; } @@ -296,6 +319,7 @@ static FUNCTION_ATTRIBUTE_PREFIX(std::uintptr_t) PeekRbp(std::uintptr_t rsp) FUN return reinterpret_cast(g_saved_params.top()); } +#ifdef KHOOK_DEBUG_PRINT static FUNCTION_ATTRIBUTE_PREFIX(void) PrintRSP(std::uintptr_t rsp) FUNCTION_ATTRIBUTE_SUFFIX { #ifdef KHOOK_X64 printf("RSP/ESP : 0x%lX\n", rsp); @@ -309,7 +333,9 @@ static FUNCTION_ATTRIBUTE_PREFIX(void) PrintRSP(std::uintptr_t rsp) FUNCTION_ATT << std::endl; }*/ } +#endif +#ifdef KHOOK_DEBUG_PRINT static FUNCTION_ATTRIBUTE_PREFIX(void) PrintRegister(std::uintptr_t reg, const char* name) FUNCTION_ATTRIBUTE_SUFFIX { #ifdef KHOOK_X64 printf("%s : 0x%lX\n", name, reg); @@ -317,7 +343,9 @@ static FUNCTION_ATTRIBUTE_PREFIX(void) PrintRegister(std::uintptr_t reg, const c printf("%s : 0x%X\n", name, reg); #endif } +#endif +#ifdef KHOOK_DEBUG_PRINT static FUNCTION_ATTRIBUTE_PREFIX(void) PrintEntryExitRSP(std::uintptr_t rsp, bool entry) FUNCTION_ATTRIBUTE_SUFFIX { #ifdef KHOOK_X64 //printf("%s RSP/ESP : 0x%lX\n", (entry) ? "ENTRY" : "EXIT", rsp); @@ -325,6 +353,7 @@ static FUNCTION_ATTRIBUTE_PREFIX(void) PrintEntryExitRSP(std::uintptr_t rsp, boo //printf("%s RSP/ESP : 0x%X\n", (entry) ? "ENTRY" : "EXIT", rsp); #endif } +#endif KHOOK_API void* GetContext() { return g_current_hook.top(); @@ -337,9 +366,8 @@ KHOOK_API void SaveReturnValue(KHook::Action action, void* ptr_to_return, std::s auto loop = g_saved_params.top(); if (original) { // Save original value - if (loop->original_return_ptr != 0) { - // Value has already been saved, what the fuck - std::abort(); + if (loop->original_call_state != CallOriginalState::InProgress) { + DEBUG_ABORT_PRINT("Attempting to save original return value, outside of original call window!\n") } if (return_size != 0) { auto new_return = new std::uint8_t[return_size]; @@ -349,6 +377,7 @@ KHOOK_API void SaveReturnValue(KHook::Action action, void* ptr_to_return, std::s init_copy_return fn = reinterpret_cast(init_op); (*fn)(new_return, ptr_to_return); } + loop->original_call_state = CallOriginalState::Complete; } if (action > (KHook::Action)loop->action) { loop->action = (std::uintptr_t)action; @@ -360,9 +389,8 @@ KHOOK_API void SaveReturnValue(KHook::Action action, void* ptr_to_return, std::s // Free it delete[] reinterpret_cast(loop->override_return_ptr); - if (return_size != 0) { - // What are you doing ????? - std::abort(); + if (return_size == 0) { + DEBUG_ABORT_PRINT("Attempting to save a 0 sized return value as override!\n") } } if (return_size != 0) { @@ -438,13 +466,17 @@ KHOOK_API void* GetCurrentValuePtr(bool pop) { } } +KHOOK_API bool WasOriginalFunctionSkipped() { + return g_saved_params.top()->original_call_state == CallOriginalState::CompleteSkipped; +} + /*void memcpy_detour(std::uintptr_t dst, std::uintptr_t src, std::uintptr_t size) { std::cout << std::hex << "dst 0x" << dst << " src 0x" << src << " size 0x" << size << std::endl; }*/ void memcpy_debug(void* dest, const void* src, std::size_t count) { //printf("dst: %p src: %p\n", dest, src); - float* fstack = reinterpret_cast(dest); + //float* fstack = reinterpret_cast(dest); memcpy(dest, src, count); /*printf("Dest ESP: %p | Src ESP: %p\n", dest, src); for (int i = 0; i < 10; i++) { @@ -494,13 +526,15 @@ void copy_stack(DetourCapsule::AsmJit& jit, std::int32_t offset, std::int32_t st #endif } -DetourCapsule::DetourCapsule() : +DetourCapsule::DetourCapsule(std::uint32_t stack_size) : _in_deletion(false), _start_callbacks(nullptr), _end_callbacks(nullptr), _jit_func_ptr(0), _original_function(0), - _stack_size(STACK_SAFETY_BUFFER) { + _stack_size(((stack_size + 0xF) & ~0xF)) { + DEBUG_PRINT("DetourCapsule::ctor(_stack_size: %hd)\n", _stack_size) + // Because we want to be call agnostic we must get clever // No register can be used to call a function, so here's the plan // mov rax, 0xStart Address of JIT function @@ -514,8 +548,8 @@ DetourCapsule::DetourCapsule() : #ifdef KHOOK_X64 using namespace Asm; - static auto print_register = [](DetourCapsule::AsmJit& jit, x86_64_Reg reg, const char* name) { #ifdef KHOOK_DEBUG_PRINT + static auto print_register = [](DetourCapsule::AsmJit& jit, x86_64_Reg reg, const char* name) { WIN_ONLY(jit.sub(rsp, 32)); LINUX_ONLY(jit.mov(rdi, reg)); @@ -528,11 +562,10 @@ DetourCapsule::DetourCapsule() : jit.call(rax); WIN_ONLY(jit.add(rsp, 32)); -#endif }; - - static auto print_rsp = [](DetourCapsule::AsmJit& jit, std::uint32_t offset = 0) { +#endif #ifdef KHOOK_DEBUG_PRINT + static auto print_rsp = [](DetourCapsule::AsmJit& jit, std::uint32_t offset = 0) { WIN_ONLY(jit.sub(rsp, 32)); jit.push(rdi); @@ -556,8 +589,8 @@ DetourCapsule::DetourCapsule() : jit.pop(rdi); WIN_ONLY(jit.add(rsp, 32)); -#endif }; +#endif static auto begin_detour = [](DetourCapsule::AsmJit& jit, std::uint32_t offset_to_loop_params, std::uint32_t offset_to_regs, std::uint32_t offset_to_stack, std::int32_t stack_size, DetourCapsule* capsule) { WIN_ONLY(static constexpr size_t shadowspace = 48); @@ -849,13 +882,15 @@ DetourCapsule::DetourCapsule() : // Call original (maybe) // RBP which we have set much earlier still contains our local variables // it should have been saved across all calls as per linux & win callconvs - _jit.mov(rax, rbp(offsetof(AsmLoopDetails, original_call_over))); - _jit.test(rax, rax); - _jit.jnz(INT32_MAX);{auto jnz = _jit.get_outputpos(); { + _jit.mov(rax, rbp(offsetof(AsmLoopDetails, original_call_state))); + _jit.cmp(rax, (std::int32_t)CallOriginalState::NotReady); + _jit.jne(INT32_MAX);{auto jnz = _jit.get_outputpos(); { + _jit.mov(rbp(offsetof(AsmLoopDetails, original_call_state)), (std::uint32_t)CallOriginalState::CompleteSkipped); _jit.mov(rax, rbp(offsetof(AsmLoopDetails, action))); _jit.cmp(rax, (std::int32_t)Action::Supersede); _jit.je(INT32_MAX); auto if_not_supersede = _jit.get_outputpos(); { + _jit.mov(rbp(offsetof(AsmLoopDetails, original_call_state)), (std::uint32_t)CallOriginalState::InProgress); // MAKE ORIGINAL CALL _jit.mov(rax, reinterpret_cast(&_jit_func_ptr)); _jit.mov(rax, rax()); @@ -873,12 +908,11 @@ DetourCapsule::DetourCapsule() : _jit.rewrite(make_pre_call_return - sizeof(std::uint32_t), _jit.get_outputpos()); peek_rsp(_jit); peek_rbp(_jit); + _jit.mov(rbp(offsetof(AsmLoopDetails, original_call_state)), (std::uint32_t)CallOriginalState::Complete); } _jit.rewrite(if_not_supersede - sizeof(std::int32_t), _jit.get_outputpos() - if_not_supersede); } _jit.rewrite(jnz - sizeof(std::int32_t), _jit.get_outputpos() - jnz);} - // Call original is over - _jit.mov(rbp(offsetof(AsmLoopDetails, original_call_over)), true); //print_register(_jit, rbp, "POST-RBP"); // Prelude to POST LOOP @@ -981,9 +1015,8 @@ DetourCapsule::DetourCapsule() : jit.pop(eax); // +4 #endif }; - - static auto print_entry_rsp = [](DetourCapsule::AsmJit& jit, bool b) { #ifdef KHOOK_DEBUG_PRINT + static auto print_entry_rsp = [](DetourCapsule::AsmJit& jit, bool b) { jit.push(eax); // -4 jit.lea(eax, esp(4)); @@ -995,8 +1028,8 @@ DetourCapsule::DetourCapsule() : jit.add(esp, sizeof(void*) * 2); // +8 jit.pop(eax); // +4 -#endif }; +#endif static auto begin_detour = [](DetourCapsule::AsmJit& jit, std::uint32_t offset_to_loop_params, std::uint32_t offset_to_regs, std::uint32_t offset_to_stack, std::int32_t stack_size, DetourCapsule* capsule) { auto param_size = sizeof(void*) * 7; @@ -1276,13 +1309,15 @@ DetourCapsule::DetourCapsule() : // Call original (maybe) // RBP which we have set much earlier still contains our local variables // it should have been saved across all calls as per linux & win callconvs - _jit.mov(eax, ebp(offsetof(AsmLoopDetails, original_call_over))); - _jit.test(eax, eax); - _jit.jnz(INT32_MAX);{auto jnz = _jit.get_outputpos(); { + _jit.mov(eax, ebp(offsetof(AsmLoopDetails, original_call_state))); + _jit.cmp(eax, (std::int32_t)CallOriginalState::NotReady); + _jit.jne(INT32_MAX);{auto jnz = _jit.get_outputpos(); { + _jit.mov(ebp(offsetof(AsmLoopDetails, original_call_state)), (std::uint32_t)CallOriginalState::CompleteSkipped); _jit.mov(eax, ebp(offsetof(AsmLoopDetails, action))); _jit.cmp(eax, (std::int32_t)Action::Supersede); _jit.je(INT32_MAX); auto if_not_supersede = _jit.get_outputpos(); { + _jit.mov(ebp(offsetof(AsmLoopDetails, original_call_state)), (std::uint32_t)CallOriginalState::InProgress); // MAKE ORIGINAL CALL _jit.mov(eax, reinterpret_cast(&_jit_func_ptr)); _jit.mov(eax, eax()); @@ -1301,12 +1336,11 @@ DetourCapsule::DetourCapsule() : _jit.rewrite(make_pre_call_return - sizeof(std::uint32_t), _jit.get_outputpos()); peek_rsp(_jit); peek_rbp(_jit); + _jit.mov(ebp(offsetof(AsmLoopDetails, original_call_state)), (std::uint32_t)CallOriginalState::Complete); } _jit.rewrite(if_not_supersede - sizeof(std::int32_t), _jit.get_outputpos() - if_not_supersede); } _jit.rewrite(jnz - sizeof(std::int32_t), _jit.get_outputpos() - jnz);} - // Call original is over - _jit.mov(ebp(offsetof(AsmLoopDetails, original_call_over)), true); // Prelude to POST LOOP // Hooks with a post callback are enqueued at the end of linked list @@ -1380,7 +1414,6 @@ DetourCapsule::DetourCapsule() : //_jit.breakpoint(); _jit.retn(); #endif - _jit.SetRE(); void* bridge = _jit; _jit_func_ptr = reinterpret_cast(bridge); } @@ -1395,8 +1428,12 @@ DetourCapsule::~DetourCapsule() { // Iterate through all existing hooks and kill them for (auto& callback : _callbacks) { auto& hook = callback.second; - auto mfp = BuildMFP(reinterpret_cast(hook->hook_fn_remove)); - (((EmptyClass*)(hook->hook_ptr))->*mfp)(callback.first); + if (hook->hook_fn_remove) { + auto fn = reinterpret_cast(hook->hook_fn_remove); + PushPopCurrentHook(reinterpret_cast(hook->hook_ptr), true); + fn(callback.first); + PushPopCurrentHook(reinterpret_cast(hook->hook_ptr), false); + } } _callbacks.clear(); _start_callbacks = nullptr; @@ -1437,9 +1474,6 @@ bool DetourCapsule::InsertHook(HookID_t id, const DetourCapsule::InsertHookDetai } } else { // Okay iterate through the list and add it in the middle - LinkedList* prev = nullptr; - LinkedList* next = nullptr; - LinkedList* curr = _start_callbacks; while (curr && curr->fn_make_post == 0 && curr->next) { curr = curr->next; @@ -1473,7 +1507,7 @@ void DetourCapsule::RemoveHook(HookID_t id) { auto linked_it = _start_callbacks; while (linked_it != hook) { - linked_it = linked_it->next; + linked_it = linked_it->next; } if (linked_it->prev) { @@ -1489,11 +1523,18 @@ void DetourCapsule::RemoveHook(HookID_t id) { if (hook == _end_callbacks) { _end_callbacks = _end_callbacks->prev; } - + + auto remove_fn = hook->hook_fn_remove; + auto ctx_ptr = hook->hook_ptr; + _callbacks.erase(it); - auto mfp = BuildMFP(reinterpret_cast(hook->hook_fn_remove)); - (((EmptyClass*)(hook->hook_ptr))->*mfp)(id); + if (remove_fn) { + auto fn = reinterpret_cast(remove_fn); + PushPopCurrentHook(reinterpret_cast(ctx_ptr), true); + fn(id); + PushPopCurrentHook(reinterpret_cast(ctx_ptr), false); + } } } @@ -1596,6 +1637,7 @@ HookID_t __Setup__Hook( void* post, void* make_return, void* make_call_original, + std::uint32_t stack_size, bool async, bool (DetourCapsule::*setup_hook)(Args...), Args... args @@ -1616,7 +1658,7 @@ HookID_t __Setup__Hook( g_hooks_detour_mutex.unlock_shared(); g_hooks_detour_mutex.lock(); - auto insert = g_hooks_detour.insert_or_assign(unique_identifier, std::make_unique()); + auto insert = g_hooks_detour.insert_or_assign(unique_identifier, std::make_unique(stack_size)); if (insert.second) { auto detour = insert.first->second.get(); // Hook setup failed, so early abort... @@ -1681,6 +1723,7 @@ KHOOK_API HookID_t SetupHook( void* post, void* make_return, void* make_call_original, + unsigned int stack_size, bool async ) { return __Setup__Hook( @@ -1691,6 +1734,7 @@ KHOOK_API HookID_t SetupHook( post, make_return, make_call_original, + stack_size, async, &DetourCapsule::SetupAddress, function @@ -1706,6 +1750,7 @@ KHOOK_API HookID_t SetupVirtualHook( void* post, void* make_return, void* make_call_original, + unsigned int stack_size, bool async ) { return __Setup__Hook( @@ -1716,6 +1761,7 @@ KHOOK_API HookID_t SetupVirtualHook( post, make_return, make_call_original, + stack_size, async, &DetourCapsule::SetupVirtual, vtable, @@ -1734,7 +1780,10 @@ KHOOK_API void RemoveHook( continue; } - // Hook not yet been inserted, remove it right now + // Hook not yet been inserted, remove it right now. + // Capture the remove-callback details before erase frees the node. + auto remove_fn = it->second.hook_fn_remove; + auto ctx_ptr = it->second.hook_ptr; g_insert_hooks.erase(it); // Disassociate from the detour @@ -1744,10 +1793,13 @@ KHOOK_API void RemoveHook( } // Invoke remove callback - auto& hook = it->second; - auto mfp = BuildMFP(reinterpret_cast(hook.hook_fn_remove)); - (((EmptyClass*)(hook.hook_ptr))->*mfp)(id); - return; + if (remove_fn) { + auto fn = reinterpret_cast(remove_fn); + PushPopCurrentHook(reinterpret_cast(ctx_ptr), true); + fn(id); + PushPopCurrentHook(reinterpret_cast(ctx_ptr), false); + } + break; } } @@ -1804,4 +1856,8 @@ KHOOK_API void* FindOriginalVirtual(void** vtable, int index) { return vtable[index]; } +KHOOK_API void* LookupSignature(void* start, std::size_t size, const char* signature) { + return reinterpret_cast(KHook::Ranges::Lookup(reinterpret_cast(start), size, std::string(signature))); +} + } diff --git a/src/detour.hpp b/src/detour.hpp index 06bd66b..bb6b080 100644 --- a/src/detour.hpp +++ b/src/detour.hpp @@ -18,6 +18,7 @@ #include #include +#include "ranges.hpp" #include "safetyhook.hpp" #ifdef KHOOK_X64 @@ -53,7 +54,7 @@ namespace KHook { using AsmJit = Asm::x86_Jit; #endif - DetourCapsule(); + DetourCapsule(std::uint32_t stack_size); ~DetourCapsule(); struct InsertHookDetails { @@ -79,8 +80,10 @@ namespace KHook { } bool SetupAddress(void* detour_address) { + auto range = std::make_unique(reinterpret_cast(detour_address), reinterpret_cast(detour_address) + 0xA); + auto result = safetyhook::InlineHook::create(detour_address, _jit_func_ptr); - if (result) { + if (result && Ranges::Add(std::move(range))) { // Successfully detour'd the function _safetyhook = std::move(result.value()); _original_function = reinterpret_cast(_safetyhook.original()); diff --git a/src/ranges.cpp b/src/ranges.cpp new file mode 100644 index 0000000..d7a019b --- /dev/null +++ b/src/ranges.cpp @@ -0,0 +1,142 @@ +#include "ranges.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace KHook::Ranges { + +std::vector> g_ranges; +std::shared_mutex g_mutex; + +bool Add(std::unique_ptr range) { + std::unique_lock lock(g_mutex); + + if (range->begin > range->end) { + return false; + } + + auto it = std::lower_bound( + g_ranges.begin(), + g_ranges.end(), + range->begin, + [](const std::unique_ptr& r, std::uintptr_t value) { + return r->begin < value; + }); + + // Check overlap with previous range + if (it != g_ranges.begin()) { + auto prev = std::prev(it); + + if (range->begin <= (*prev)->end) { + return false; + } + } + + // Check overlap with next range + if (it != g_ranges.end()) { + if (range->end >= (*it)->begin) { + return false; + } + } + + g_ranges.insert(it, std::move(range)); + return true; +} + +std::uintptr_t Lookup(std::uintptr_t start, std::size_t size, const std::string& bytes) { + // Parse the bytes sequence + std::vector sequence; + auto c_string = bytes.c_str(); + + static auto p = [](const char c) { + return ('0' <= c && c <= '9') ? c - '0' + : ('a' <= c && c <= 'f') ? 10 + (c - 'a') + : ('A' <= c && c <= 'F') ? 10 + (c - 'A') : -1; + }; + + for (int i = 0; i <= bytes.size(); i++) { + if (i == bytes.size() || bytes[i] == ' ') { + // New bytes/end of string. Process what we parsed + if (c_string[i - 2] == '?' && c_string[i - 1] == '?') { + // Wildcard + sequence.push_back(0xFFFF); + } else if (p(c_string[i - 1]) != -1 && p(c_string[i - 2]) != -1) { + std::uint16_t parsed = p(c_string[i - 1]) + (16 * p(c_string[i - 2])); + sequence.push_back(parsed); + } else { + return 0; + } + } + } + + /*std::cout << "Searching: " << std::setfill('0'); + for (int i = 0; i < sequence.size(); i++) { + if (sequence[i] == 0xFFFF) { + std::cout << "??"; + } else { + std::cout << std::hex << std::setw(2) << sequence[i]; + } + std::cout << " "; + } + std::cout << std::endl;*/ + + std::shared_lock lock(g_mutex); + + auto it = std::upper_bound( + g_ranges.begin(), + g_ranges.end(), + start + size, + [start](std::uintptr_t v, const std::unique_ptr& r) { + return r->begin <= v && r->begin >= start; + }); + + const auto& it_end = g_ranges.end(); + for (auto lookup = start, lookup_end = lookup + size; lookup < lookup_end; lookup++) { + // Move the iterator further until we intersect again + while (it != it_end && (*it)->end < lookup) { + it++; + } + + auto read_it = it; + bool found = true; + for (std::size_t i = 0; i < sequence.size() && found; i++) { + if (sequence[i] == 0xFFFF) { + // Wildcard, skip + continue; + } + + auto read = lookup + i; + + // Move the iterator further until we intersect again + while (read_it != it_end && (*read_it)->end < read) { + read_it++; + } + + if (it_end != read_it) { + if (read >= (*read_it)->begin && read <= (*read_it)->end) { + auto diff = read - (*read_it)->begin; + read = reinterpret_cast(&((*read_it)->og_bytes[diff])); + std::cout << " og: " << std::hex << static_cast((*read_it)->og_bytes[diff]); + } + } + + // Ensure bytes are matching + found &= (*reinterpret_cast(read) == sequence[i]); + } + + if (found) { + std::cout << "found" << std::endl; + return lookup; + } + } + + // Failure + return 0; +} + +} \ No newline at end of file diff --git a/src/ranges.hpp b/src/ranges.hpp new file mode 100644 index 0000000..f9cdb93 --- /dev/null +++ b/src/ranges.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include +#include +#include + +#include + +namespace KHook::Ranges { + +struct Range { + using Self = Range; + Range(const Self&) = delete; + Range& operator= (const Self&) = delete; + Range(uintptr_t begin, uintptr_t end) : begin(begin), end(end) { + auto len = (end - begin) + 1; + og_bytes = new std::uint8_t[len]; + + auto read = reinterpret_cast(begin); + for (decltype(len) i = 0; i < len; i++, read++) { + og_bytes[i] = *read; + } + } + ~Range() { + delete[] og_bytes; + } + + uintptr_t begin; + uintptr_t end; + std::uint8_t* og_bytes; +}; + +bool Add(std::unique_ptr range); + +// 0 - If lookup fails +std::uintptr_t Lookup(std::uintptr_t start, std::size_t size, const std::string& bytes); + +} \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..3166dc7 --- /dev/null +++ b/test.py @@ -0,0 +1,64 @@ +import argparse +import os +import subprocess +import sys + +SCRIPT_DIR = os.path.dirname(__file__) + +PACKAGE_DIRNAME = "package" +GTEST_PARALLEL_PATH = os.path.join(SCRIPT_DIR, "third_party", "gtest-parallel", "gtest-parallel") +CONFIGURE_SCRIPT_PATH = os.path.join(SCRIPT_DIR, "configure.py") + +def configure(build_dir, target): + if not os.path.isdir(build_dir): + os.mkdir(build_dir) + result = subprocess.run([sys.executable, CONFIGURE_SCRIPT_PATH, "--targets", target, "--enable-tests"], cwd=build_dir) + if result.returncode != 0: + sys.exit(result.returncode) + +def build(build_dir): + result = subprocess.run(["ambuild"], cwd=build_dir) + if result.returncode != 0: + sys.exit(result.returncode) + +def run_tests(build_dir, target): + test_binary = os.path.join(build_dir, PACKAGE_DIRNAME, target, "testrunner") + if sys.platform == "win32": + test_binary += ".exe" + elif sys.platform == "linux": + pass + else: + raise OSError(f"Unsupported platform: {sys.platform}") + + if not os.path.isfile(test_binary): + print(f"Test binary not found: {test_binary}") + sys.exit(1) + gtest_parallel = os.path.abspath(GTEST_PARALLEL_PATH) + result = subprocess.run([sys.executable, gtest_parallel, test_binary]) + sys.exit(result.returncode) + +def main(): + parser = argparse.ArgumentParser(description="Build and run the test suite.") + parser.add_argument("--target", choices=["x86", "x86_64"], required=True, + help="Target architecture to build and test.") + parser.add_argument("--build-dir", default="build", + help="Build directory.") + parser.add_argument("--skip-build", action="store_true", + help="Skip build step.") + parser.add_argument("--skip-configure", action="store_true", + help="Skip build configuration step.") + args = parser.parse_args() + + build_dir = os.path.abspath(args.build_dir) + target = args.target + + if not args.skip_build: + if not args.skip_configure: + configure(build_dir, target) + + build(build_dir) + + run_tests(build_dir, target) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test/AMBuilder b/test/AMBuilder new file mode 100644 index 0000000..df45cb8 --- /dev/null +++ b/test/AMBuilder @@ -0,0 +1,55 @@ +import os + +projectName = 'testrunner' + +for compiler in TestRunner.all_targets: + binary = compiler.Program(projectName) + compiler = binary.compiler + + compiler.cxxincludes += [os.path.join(builder.sourcePath, 'include')] + compiler.cxxdefines += ['KHOOK_STANDALONE','KHOOK_EXPORT'] + + for task in TestRunner.libkhook: + if task.target.arch == compiler.target.arch: + compiler.linkflags += [task.binary] + for task in TestRunner.libsafetyhook: + if task.target.arch == compiler.target.arch: + compiler.linkflags += [task.binary] + + compiler.includes += [ + os.path.join(builder.sourcePath, 'third_party', 'googletest', 'googletest', 'include') + ] + + for task in TestRunner.libgtest: + if task.target.arch == compiler.target.arch: + compiler.linkflags += [task.binary] + + binary.sources += [ + 'main.cpp', + 'test_functions.cpp', + 'test_members.cpp', + 'test_regressions.cpp', + 'test_static_hook.cpp', + 'test_virtuals.cpp', + 'test_virtual_hook.cpp', + 'test_ranges.cpp' + ] + + if compiler.like('gcc'): + pass + elif compiler.family == 'msvc': + compiler.cxxflags += [ + '/EHsc' + ] + + if compiler.target.platform == 'linux': + compiler.linkflags += [ + '-pthread' + ] + pass + elif compiler.target.platform == 'mac': + pass + elif compiler.target.platform == 'windows': + pass + + TestRunner.test_binaries += [builder.Add(binary)] \ No newline at end of file diff --git a/test/AMBuilder.gtest b/test/AMBuilder.gtest new file mode 100644 index 0000000..f48ae80 --- /dev/null +++ b/test/AMBuilder.gtest @@ -0,0 +1,31 @@ +import os + +libgtest = builder.StaticLibraryProject('libgtest') +libgtest.sources += [ + os.path.join(builder.sourcePath, 'third_party', 'googletest', 'googletest', 'src', 'gtest-all.cc'), +] + +for compiler in GTest.all_targets: + binary = libgtest.Configure(compiler, libgtest.name, 'Release - {0}'.format(compiler.target.arch)) + compiler = binary.compiler + + compiler.includes += [ + os.path.join(builder.sourcePath, 'third_party', 'googletest', 'googletest', 'include'), + os.path.join(builder.sourcePath, 'third_party', 'googletest', 'googletest'), + ] + + if compiler.like('gcc'): + pass + elif compiler.family == 'msvc': + compiler.cxxflags += [ + '/EHsc' + ] + + if compiler.target.platform == 'linux': + pass + elif compiler.target.platform == 'mac': + pass + elif compiler.target.platform == 'windows': + pass + +GTest.libgtest = builder.Add(libgtest) \ No newline at end of file diff --git a/test/helpers.hpp b/test/helpers.hpp new file mode 100644 index 0000000..b6f51bc --- /dev/null +++ b/test/helpers.hpp @@ -0,0 +1,115 @@ +#include + +#if defined(_MSC_VER) + #define NOINLINE __declspec(noinline) +#elif defined(__GNUC__) || defined(__clang__) + #define NOINLINE __attribute__((noinline)) + #define __thiscall +#else + #define NOINLINE +#endif + +template +class FunctionContext { +public: + using fnContextCallback = ::KHook::Return (FunctionContext::*)(ARGS...); + using fnCallback = ::KHook::Return (*)(ARGS...); + using fnHooked = RETURN (*)(ARGS...); + + FunctionContext() : m_pre(nullptr), m_post(nullptr) {} + FunctionContext(fnCallback callback) : m_pre(callback), m_post(callback) {} + FunctionContext(fnCallback pre, fnCallback post) : m_pre(pre), m_post(post) {} + + void Configure(fnCallback pre, fnCallback post) { + m_pre = pre; + m_post = post; + } + + ::KHook::Return OnPre(ARGS... args) { + if (m_pre) { + return m_pre(args...); + } + return {KHook::Action::Ignore}; + } + + ::KHook::Return OnPost(ARGS... args) { + if (m_post) { + return m_post(args...); + } + return {KHook::Action::Ignore}; + } + +private: + fnCallback m_pre; + fnCallback m_post; +}; + +template +class MemberContext { +public: + using fnContextCallback = ::KHook::Return (MemberContext::*)(TargetClass*, ARGS...); + using fnCallback = ::KHook::Return (*)(TargetClass*, ARGS...); + using fnHooked = RETURN (*)(TargetClass*, ARGS...); + + MemberContext() : m_pre(nullptr), m_post(nullptr) {} + MemberContext(fnCallback callback) : m_pre(callback), m_post(callback) {} + MemberContext(fnCallback pre, fnCallback post) : m_pre(pre), m_post(post) {} + + void Configure(fnCallback pre, fnCallback post) { + m_pre = pre; + m_post = post; + } + + ::KHook::Return OnPre(TargetClass* _this, ARGS... args) { + if (m_pre) { + return m_pre(_this, args...); + } + return {KHook::Action::Ignore}; + } + + ::KHook::Return OnPost(TargetClass* _this, ARGS... args) { + if (m_post) { + return m_post(_this, args...); + } + return {KHook::Action::Ignore}; + } + +private: + fnCallback m_pre; + fnCallback m_post; +}; + +template +class VirtualContext { +public: + using fnContextCallback = ::KHook::Return (VirtualContext::*)(TargetClass*, ARGS...); + using fnCallback = ::KHook::Return (*)(TargetClass*, ARGS...); + using fnHooked = RETURN (*)(TargetClass*, ARGS...); + + VirtualContext() : m_pre(nullptr), m_post(nullptr) {} + VirtualContext(fnCallback callback) : m_pre(callback), m_post(callback) {} + VirtualContext(fnCallback pre, fnCallback post) : m_pre(pre), m_post(post) {} + + void Configure(fnCallback pre, fnCallback post) { + m_pre = pre; + m_post = post; + } + + ::KHook::Return OnPre(TargetClass* _this, ARGS... args) { + if (m_pre) { + return m_pre(_this, args...); + } + return {KHook::Action::Ignore}; + } + + ::KHook::Return OnPost(TargetClass* _this, ARGS... args) { + if (m_post) { + return m_post(_this, args...); + } + return {KHook::Action::Ignore}; + } + +private: + fnCallback m_pre; + fnCallback m_post; +}; \ No newline at end of file diff --git a/test/main.cpp b/test/main.cpp new file mode 100644 index 0000000..822065b --- /dev/null +++ b/test/main.cpp @@ -0,0 +1,12 @@ +#include +#include + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + + int result = RUN_ALL_TESTS(); + + KHook::Shutdown(); + + return result; +} \ No newline at end of file diff --git a/test/test_functions.cpp b/test/test_functions.cpp new file mode 100644 index 0000000..35c00a1 --- /dev/null +++ b/test/test_functions.cpp @@ -0,0 +1,215 @@ +#include +#include +#include "helpers.hpp" + +class FunctionTests : public ::testing::Test { +protected: + typedef FunctionContext ContextType; + + FunctionTests() + { + } + + static void TargetFunction(int& value) { + value = 0xDEADBEEF; + } + + static KHook::Return HookFunction(int& value) { + std::cout << "Here"; + return {KHook::Action::Ignore}; + } + + static KHook::Return HookFunctionSupersede(int& value) { + return {KHook::Action::Supersede}; + } +}; + +TEST_F(FunctionTests, NoopPreCallback) { + KHook::Function hook(&TargetFunction, &HookFunction, nullptr); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + TargetFunction(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; +} + +TEST_F(FunctionTests, NoopPostCallback) { + KHook::Function hook(&TargetFunction, nullptr, &HookFunction); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + TargetFunction(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; +} + +TEST_F(FunctionTests, NoopPreAndPostCallbacks) { + KHook::Function hook(&TargetFunction, &HookFunction, &HookFunction); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + TargetFunction(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "HereHere") << "No response from callbacks"; +} + +TEST_F(FunctionTests, SupersedePreCallback) { + KHook::Function hook(&TargetFunction, &HookFunctionSupersede, nullptr); + + int value = 0x1337; + TargetFunction(value); + + EXPECT_EQ(value, 0x1337) << "Behaviour of hooked function was modified"; +} + +TEST_F(FunctionTests, ContextPreCallback) { + ContextType context(&HookFunction); + KHook::Function hook(&TargetFunction, &context, &ContextType::OnPre, nullptr); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + TargetFunction(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + TargetFunction(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callback after context removal"; +} + +TEST_F(FunctionTests, ContextPostCallback) { + ContextType context(&HookFunction); + KHook::Function hook(&TargetFunction, &context, nullptr, &ContextType::OnPost); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + TargetFunction(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + TargetFunction(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callback after context removal"; +} + +TEST_F(FunctionTests, ContextPreAndPostCallbacks) { + ContextType context(&HookFunction); + KHook::Function hook(&TargetFunction, &context, &ContextType::OnPre, &ContextType::OnPost); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + TargetFunction(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "HereHere") << "No response from callbacks"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + TargetFunction(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callbacks after context removal"; +} + +TEST_F(FunctionTests, ContextSupersedePreCallback) { + ContextType context(&HookFunctionSupersede); + KHook::Function hook(&TargetFunction, &context, &ContextType::OnPre, nullptr); + + int value = 0x1337; + TargetFunction(value); + + ASSERT_EQ(value, 0x1337) << "Behaviour of hooked function was modified"; + + hook.RemoveContext(&context); + + value = 0x1337; + TargetFunction(value); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified"; +} + +TEST_F(FunctionTests, MultipleContextsPreCallback) { + ContextType context1(&HookFunctionSupersede); + ContextType context2(&HookFunction); + KHook::Function hook(&TargetFunction, &context1, &ContextType::OnPre, nullptr); + hook.AddContext(&context2, &ContextType::OnPre, nullptr); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + TargetFunction(value); + + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0x1337) << "Behaviour of hooked function was modified"; + ASSERT_EQ(output, "Here") << "Incorrect response from multiple contexts"; + + hook.RemoveContext(&context1); + + testing::internal::CaptureStdout(); + + value = 0x1337; + TargetFunction(value); + + output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified after context removal"; + ASSERT_EQ(output, "Here") << "Incorrect response from remaining context after context removal"; + + hook.RemoveContext(&context2); + + testing::internal::CaptureStdout(); + + value = 0x1337; + TargetFunction(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked function was modified after context removal"; + EXPECT_EQ(output, "") << "Unexpected response from contexts after context removal"; +} \ No newline at end of file diff --git a/test/test_members.cpp b/test/test_members.cpp new file mode 100644 index 0000000..4132ce2 --- /dev/null +++ b/test/test_members.cpp @@ -0,0 +1,229 @@ +#include +#include +#include "helpers.hpp" + +class MemberTests : public ::testing::Test { +protected: + class TargetClass { + public: + void TargetMethod(int& value) { + value = 0xDEADBEEF; + } + }; + + typedef MemberContext ContextType; + + static KHook::Return HookMethod(TargetClass* _this, int& value) { + std::cout << "Here"; + return {KHook::Action::Ignore}; + } + + static KHook::Return HookMethodSupersede(TargetClass* _this, int& value) { + return {KHook::Action::Supersede}; + } +}; + +TEST_F(MemberTests, PreCallback) { + TargetClass instance; + + KHook::Member hook(&TargetClass::TargetMethod, &HookMethod, nullptr); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance.TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; +} + +TEST_F(MemberTests, PostCallback) { + TargetClass instance; + + KHook::Member hook(&TargetClass::TargetMethod, nullptr, &HookMethod); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance.TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; +} + +TEST_F(MemberTests, PreAndPostCallbacks) { + TargetClass instance; + + KHook::Member hook(&TargetClass::TargetMethod, &HookMethod, &HookMethod); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance.TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "HereHere") << "No response from callbacks"; +} + +TEST_F(MemberTests, SupersedePreCallback) { + TargetClass instance; + + KHook::Member hook(&TargetClass::TargetMethod, &HookMethodSupersede, nullptr); + + int value = 0x1337; + instance.TargetMethod(value); + + EXPECT_EQ(value, 0x1337) << "Behaviour of hooked method was modified"; +} + +TEST_F(MemberTests, ContextPreCallback) { + TargetClass instance; + + ContextType context(&HookMethod); + KHook::Member hook(&TargetClass::TargetMethod, &context, &ContextType::OnPre, nullptr); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance.TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + ASSERT_EQ(output, "Here") << "No response from callback"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance.TargetMethod(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callback after context removal"; +} + +TEST_F(MemberTests, ContextPostCallback) { + TargetClass instance; + + ContextType context(&HookMethod); + KHook::Member hook(&TargetClass::TargetMethod, &context, nullptr, &ContextType::OnPost); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance.TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + ASSERT_EQ(output, "Here") << "No response from callback"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance.TargetMethod(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callback after context removal"; +} + +TEST_F(MemberTests, ContextPreAndPostCallbacks) { + TargetClass instance; + + ContextType context(&HookMethod); + KHook::Member hook(&TargetClass::TargetMethod, &context, &ContextType::OnPre, &ContextType::OnPost); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance.TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + ASSERT_EQ(output, "HereHere") << "No response from callbacks"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance.TargetMethod(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callbacks after context removal"; +} + +TEST_F(MemberTests, ContextSupersedePreCallback) { + TargetClass instance; + + ContextType context(&HookMethodSupersede); + KHook::Member hook(&TargetClass::TargetMethod, &context, &ContextType::OnPre, nullptr); + + int value = 0x1337; + instance.TargetMethod(value); + + ASSERT_EQ(value, 0x1337) << "Behaviour of hooked method was modified"; + + hook.RemoveContext(&context); + + value = 0x1337; + instance.TargetMethod(value); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; +} + +TEST_F(MemberTests, MultipleContextsPreCallback) { + TargetClass instance; + + ContextType context1(&HookMethodSupersede); + ContextType context2(&HookMethod); + KHook::Member hook(&TargetClass::TargetMethod, &context1, &ContextType::OnPre, nullptr); + hook.AddContext(&context2, &ContextType::OnPre, nullptr); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance.TargetMethod(value); + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0x1337) << "Behaviour of hooked method was modified"; + ASSERT_EQ(output, "Here") << "Incorrect response from multiple contexts"; + + hook.RemoveContext(&context1); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance.TargetMethod(value); + output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified after context removal"; + ASSERT_EQ(output, "Here") << "Incorrect response from remaining context after context removal"; + + hook.RemoveContext(&context2); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance.TargetMethod(value); + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified after context removal"; + EXPECT_EQ(output, "") << "Unexpected response from callbacks after context removal"; +} \ No newline at end of file diff --git a/test/test_ranges.cpp b/test/test_ranges.cpp new file mode 100644 index 0000000..5638fab --- /dev/null +++ b/test/test_ranges.cpp @@ -0,0 +1,218 @@ +#include +#include +#include "helpers.hpp" + +#include +#include +#include +#include +#include + +struct MemoryRegion +{ + uintptr_t start; + uintptr_t end; +}; + +#ifndef WIN32 + +std::vector GetReadableRegions() +{ + std::vector regions; + + std::ifstream maps("/proc/self/maps"); + std::string line; + + while (std::getline(maps, line)) + { + char perms[5]; + unsigned long long start, end; + + if (sscanf(line.c_str(), + "%llx-%llx %4s", + &start, + &end, + perms) == 3) + { + if (perms[0] == 'r') + { + regions.push_back({ + static_cast(start), + static_cast(end) + }); + } + } + } + return regions; +} + +#else + +#include + +bool IsReadable(DWORD protect) +{ + protect &= 0xff; + + switch (protect) + { + case PAGE_READONLY: + case PAGE_READWRITE: + case PAGE_WRITECOPY: + case PAGE_EXECUTE_READ: + case PAGE_EXECUTE_READWRITE: + case PAGE_EXECUTE_WRITECOPY: + return true; + + default: + return false; + } +} + +std::vector GetReadableRegions() +{ + std::vector regions; + + SYSTEM_INFO si; + GetSystemInfo(&si); + + uintptr_t addr = + reinterpret_cast(si.lpMinimumApplicationAddress); + + uintptr_t maxAddr = + reinterpret_cast(si.lpMaximumApplicationAddress); + + MEMORY_BASIC_INFORMATION mbi; + + while (addr < maxAddr) + { + SIZE_T result = + VirtualQuery( + reinterpret_cast(addr), + &mbi, + sizeof(mbi)); + + if (result == 0) + break; + + if (mbi.State == MEM_COMMIT && + !(mbi.Protect & PAGE_GUARD) && + !(mbi.Protect & PAGE_NOACCESS) && + IsReadable(mbi.Protect)) + { + regions.push_back({ + reinterpret_cast(mbi.BaseAddress), + reinterpret_cast(mbi.BaseAddress) + + mbi.RegionSize + }); + } + + addr = + reinterpret_cast(mbi.BaseAddress) + + mbi.RegionSize; + } + + return regions; +} + +#endif + + +std::uint64_t do_nothing_big(std::uint64_t x) +{ + volatile std::uint64_t sink; + + x = x * 3 + 7; x ^= x << 13; + x = x * 5 + 11; x ^= x >> 17; + x = x * 7 + 13; x ^= x << 9; + x = x * 11 + 17; x ^= x >> 21; + x = x * 13 + 19; x ^= x << 5; + x = x * 17 + 23; x ^= x >> 11; + x = x * 19 + 29; x ^= x << 7; + x = x * 23 + 31; x ^= x >> 13; + x = x * 29 + 37; x ^= x << 17; + x = x * 31 + 41; x ^= x >> 19; + + sink = x; + + return x; +} + +class RangesTest : public ::testing::Test { +public: + static const constexpr std::size_t SIGNATURE_BYTES = 20; + RangesTest() { + // Make a signature of 50 bytes + _function = reinterpret_cast(&do_nothing_big); + + std::uint8_t* read = reinterpret_cast(_function); + + for (std::size_t i = 0; i < SIGNATURE_BYTES; i++) { + auto byte = read[i]; + auto lower = byte % 16; + auto upper = (byte - lower) / 16; + + _signature[i * 3 + 0] = (upper >= 10) ? 'A' + (upper - 10) : '0' + upper; + _signature[i * 3 + 1] = (lower >= 10) ? 'A' + (lower - 10) : '0' + lower; + _signature[i * 3 + 2] = ' '; + } + _signature[sizeof(_signature) - 1] = '\0'; + + std::cout << std::hex << reinterpret_cast(_function) << " | Crafted signature: " << _signature << std::endl; + } + + bool Lookup() { + auto regions = GetReadableRegions(); + for (const auto& region : regions) { + auto ptr = KHook::LookupSignature(reinterpret_cast(region.start), static_cast(region.end - region.start), _signature); + if (_function == ptr) { + std::cout << "Function lookup success!" << std::endl; + return true; + } + } + return false; + } + + char _signature[SIGNATURE_BYTES * 3]; + void* _function; +}; + +TEST_F(RangesTest, Lookup) { + EXPECT_EQ(Lookup(), true) << "Failed to sig scan our function"; +} + +TEST_F(RangesTest, LookupWildcard) { + static const constexpr int byte1 = 5; + static_assert(byte1 <= RangesTest::SIGNATURE_BYTES); + static const constexpr int byte2 = 17; + static_assert(byte2 <= RangesTest::SIGNATURE_BYTES); + static const constexpr int byte3 = 8; + static_assert(byte3 <= RangesTest::SIGNATURE_BYTES); + + _signature[byte1 * 3 + 0] = '?'; + _signature[byte1 * 3 + 1] = '?'; + _signature[byte2 * 3 + 0] = '?'; + _signature[byte2 * 3 + 1] = '?'; + _signature[byte3 * 3 + 0] = '?'; + _signature[byte3 * 3 + 1] = '?'; + + EXPECT_EQ(Lookup(), true) << "Failed to sig scan our function"; +} + +TEST_F(RangesTest, LookupWithHook) { + static auto nothing = [](std::uint64_t x){ return x; }; + + EXPECT_NE(KHook::SetupHook( + reinterpret_cast(&do_nothing_big), + nullptr, + nullptr, + reinterpret_cast(¬hing), + reinterpret_cast(¬hing), + reinterpret_cast(¬hing), + reinterpret_cast(¬hing), + 100, + false + ), KHook::INVALID_HOOK) << "Failed to setup hook"; + + EXPECT_EQ(Lookup(), true) << "Failed to sig scan our function"; +} \ No newline at end of file diff --git a/test/test_regressions.cpp b/test/test_regressions.cpp new file mode 100644 index 0000000..3bff38f --- /dev/null +++ b/test/test_regressions.cpp @@ -0,0 +1,277 @@ +#include +#include +#include +#include +#include "helpers.hpp" + +namespace { + // Tests high-level classes Function, Virtual, and Member and verifies + // that, after hooking, that all arguments and return values are + // passed to the detour callback and the original function unchanged, + // especially when calling conventions dictate that arguments be passed + // onto the stack. + // + // This verifies that enough of the stack is being copied by the hook + // classes, prior to execution being passed onto callbacks and the + // originally hooked function. + namespace stack_copy_size { + template + class TestCase; + + using Case1 = TestCase<1, void, bool, bool, bool, void*>; + using Case2 = TestCase<2, void, + void*, int, int, const char*, + float, float, int, int, int, + void*, void*, void*, bool, float, int>; + using Case3 = TestCase<3, void, bool, void*>; + using Case4 = TestCase<4, void, bool, bool, bool, bool, bool, int>; + using Case5 = TestCase<5, int, int>; + using Case6 = TestCase<6, std::tuple, int, int>; + using Case7 = TestCase<7, std::tuple>; + using Case8 = TestCase<8, void, + void*, int, int, const char*, + float, float, int, int, int, + void*, void*, void*, bool, float, int, + bool, bool, bool, bool, int>; + using AllCases = ::testing::Types< + Case1, + Case2, + Case3, + Case4, + Case5, + Case6, + Case7, + Case8>; + + template + class TestCase { + public: + using Self = TestCase; + using ArgsTuple = std::tuple; + using Return = RETURN; + using FunctionHook = ::KHook::Function; + using MemberHook = ::KHook::Member; + using VirtualHook = ::KHook::Virtual; + + static constexpr typename Self::ArgsTuple GetExpectedArgs() { + if constexpr(std::is_same_v) { + return std::make_tuple(false, true, false, nullptr); + } + else if constexpr(std::is_same_v) { + return std::make_tuple( + reinterpret_cast((std::uintptr_t)0xDEADBEEF), + 17, + 5, + "42", + 0.25f, + 0.68f, + 19, + 123, + 7, + nullptr, + reinterpret_cast((std::uintptr_t)0x1002), + reinterpret_cast((std::uintptr_t)0x80099), + false, + 3.5f, + 1337); + } + else if constexpr(std::is_same_v) { + return std::make_tuple(true, reinterpret_cast((std::uintptr_t)0xDEADBEEF)); + } + else if constexpr(std::is_same_v) { + return std::make_tuple(true, false, false, true, true, 1337); + } + else if constexpr(std::is_same_v) { + return std::make_tuple(12); + } + else if constexpr(std::is_same_v) { + return std::make_tuple(45, 22); + } + else if constexpr(std::is_same_v) { + return std::make_tuple(); + } + else if constexpr(std::is_same_v) { + return std::make_tuple( + reinterpret_cast((std::uintptr_t)0xDEADBEEF), + 17, + 5, + "42", + 0.25f, + 0.68f, + 19, + 123, + 7, + nullptr, + reinterpret_cast((std::uintptr_t)0x1002), + reinterpret_cast((std::uintptr_t)0x80099), + false, + 3.5f, + 1337, + true, + false, + true, + true, + 70023); + } + throw std::runtime_error("expected arguments not set for test case"); + } + + static constexpr RETURN GetExpectedReturn() { + if constexpr(std::is_void_v) { + return; + } + else if constexpr(std::is_same_v) { + return 25; + } + else if constexpr(std::is_same_v) { + return std::make_tuple( + reinterpret_cast((std::uintptr_t)0xDEADBEEF), + reinterpret_cast((std::uintptr_t)0x1002), + reinterpret_cast((std::uintptr_t)0x80099) + ); + } + else if constexpr(std::is_same_v) { + return std::make_tuple( + reinterpret_cast((std::uintptr_t)0xDEADBEEF), + reinterpret_cast((std::uintptr_t)0xBEEFDEAD), + reinterpret_cast((std::uintptr_t)0x1200056) + ); + } + throw std::runtime_error("expected return not set for test case"); + } + + TestCase() : _orig_calls(0), _hook_calls(0) { + _expected_args = GetExpectedArgs(); + } + + static NOINLINE RETURN FunctionCall(ARGS... args) { + SCOPED_TRACE("original method"); + auto actual = std::make_tuple(args...); + _orig_static_calls++; + EXPECT_EQ(_expected_args, actual); + if constexpr(!std::is_void_v) { + return GetExpectedReturn(); + } + } + + NOINLINE RETURN MemberCall(ARGS... args) { + SCOPED_TRACE("original method"); + auto actual = std::make_tuple(args...); + _orig_calls++; + EXPECT_EQ(_expected_args, actual); + if constexpr(!std::is_void_v) { + return GetExpectedReturn(); + } + } + + virtual RETURN VirtualCall(ARGS... args) { + SCOPED_TRACE("original method"); + auto actual = std::make_tuple(args...); + _orig_calls++; + EXPECT_EQ(_expected_args, actual); + if constexpr(!std::is_void_v) { + return GetExpectedReturn(); + } + } + + static KHook::Return Callback(Self* hookedThis, ARGS... args) { + SCOPED_TRACE("hook callback"); + auto actual = std::make_tuple(args...); + hookedThis->_hook_calls++; + EXPECT_EQ(hookedThis->_expected_args, actual); + return { KHook::Action::Ignore }; + } + + KHook::Return Callback(ARGS... args) { + SCOPED_TRACE("hook callback"); + auto actual = std::make_tuple(args...); + _hook_calls++; + EXPECT_EQ(_expected_args, actual); + return { KHook::Action::Ignore }; + } + public: + static std::uint32_t GetNumOrigStaticCalls() { return _orig_static_calls; } + std::uint32_t GetNumOrigCalls() const { return _orig_calls; } + std::uint32_t GetNumHookCalls() const { return _hook_calls; } + private: + static inline std::uint32_t _orig_static_calls = 0; + static inline ArgsTuple _expected_args; + std::uint32_t _orig_calls; + std::uint32_t _hook_calls; + }; + + template + class Regression_StackCopySizeTests : public ::testing::Test { + protected: + void SetUp() override { + _case = new CASE; + } + + void TearDown() override { + delete _case; + } + + CASE* _case; + }; + + TYPED_TEST_SUITE(Regression_StackCopySizeTests, AllCases); + + TYPED_TEST(Regression_StackCopySizeTests, Function) { + typename TypeParam::FunctionHook hook(&TypeParam::FunctionCall, this->_case, &TypeParam::Callback, nullptr); + + auto expected = TypeParam::GetExpectedArgs(); + std::apply([&](auto&&... args) { + if constexpr(std::is_void_v) { + TypeParam::FunctionCall(std::forward(args)...); + } + else { + auto expected_ret = TypeParam::GetExpectedReturn(); + auto actual_ret = TypeParam::FunctionCall(std::forward(args)...); + EXPECT_EQ(expected_ret, actual_ret); + } + }, expected); + + EXPECT_EQ(this->_case->GetNumHookCalls(), 1) << "Pre-hook should run exactly once"; + EXPECT_EQ(this->_case->GetNumOrigStaticCalls(), 1) << "Original method should still run after Ignore"; + } + + TYPED_TEST(Regression_StackCopySizeTests, Member) { + typename TypeParam::MemberHook hook(&TypeParam::MemberCall, &TypeParam::Callback, nullptr); + + auto expected = TypeParam::GetExpectedArgs(); + std::apply([&](auto&&... args) { + if constexpr(std::is_void_v) { + this->_case->MemberCall(std::forward(args)...); + } + else { + auto expected_ret = TypeParam::GetExpectedReturn(); + auto actual_ret = this->_case->MemberCall(std::forward(args)...); + EXPECT_EQ(expected_ret, actual_ret); + } + }, expected); + + EXPECT_EQ(this->_case->GetNumHookCalls(), 1) << "Pre-hook should run exactly once"; + EXPECT_EQ(this->_case->GetNumOrigCalls(), 1) << "Original method should still run after Ignore"; + } + + TYPED_TEST(Regression_StackCopySizeTests, Virtual) { + typename TypeParam::VirtualHook hook(&TypeParam::VirtualCall, &TypeParam::Callback, nullptr); + hook.Add(this->_case); + + auto expected = TypeParam::GetExpectedArgs(); + std::apply([&](auto&&... args) { + if constexpr(std::is_void_v) { + this->_case->VirtualCall(std::forward(args)...); + } + else { + auto expected_ret = TypeParam::GetExpectedReturn(); + auto actual_ret = this->_case->VirtualCall(std::forward(args)...); + EXPECT_EQ(expected_ret, actual_ret); + } + }, expected); + + EXPECT_EQ(this->_case->GetNumHookCalls(), 1) << "Pre-hook should run exactly once"; + EXPECT_EQ(this->_case->GetNumOrigCalls(), 1) << "Original method should still run after Ignore"; + } + } +} \ No newline at end of file diff --git a/test/test_static_hook.cpp b/test/test_static_hook.cpp new file mode 100644 index 0000000..098984d --- /dev/null +++ b/test/test_static_hook.cpp @@ -0,0 +1,685 @@ +#include +#include +#include +#include +#include "helpers.hpp" + +template +class NoopStaticHookTemplate { + public: + static NOINLINE Ret PrePostNoop(Args... args); + static NOINLINE Ret CallOriginal(Args... args); + static NOINLINE Ret MakeReturn(Args... args); + static NOINLINE void OnRemoved(int hookId); +}; + +template +NOINLINE Ret NoopStaticHookTemplate::PrePostNoop(Args... args) { + std::cout << "PrePostNoop()" << std::endl; + KHook::SaveReturnValue( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr, + false + ); + if constexpr (std::is_same::value) { + return; + } else { + return Ret(); + } +} + +template +NOINLINE Ret NoopStaticHookTemplate::CallOriginal(Args... args) { + std::cout << "CallOriginal()" << std::endl; + auto original = + reinterpret_cast(KHook::GetOriginalFunction()); + if constexpr (std::is_same::value) { + original(args...); + KHook::SaveReturnValue( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr, + true + ); + return; + } else { + Ret result = original(args...); + KHook::SaveReturnValue( + KHook::Action::Ignore, + &result, + sizeof(Ret), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + true + ); + return result; + } +} + +template +NOINLINE Ret NoopStaticHookTemplate::MakeReturn(Args... args) { + std::cout << "MakeReturn()" << std::endl; + if constexpr (std::is_same::value) { + KHook::DestroyReturnValue(); + return; + } else { + Ret result = *((Ret*)KHook::GetCurrentValuePtr(true)); + KHook::DestroyReturnValue(); + return result; + } +} + +template +NOINLINE void NoopStaticHookTemplate::OnRemoved(int hookId) { + std::cout << "OnRemoved(" << std::dec << hookId << ")" << std::endl; +} + +class StaticHookTest: public ::testing::Test { + protected: + class TestObject { + public: + int m_testValue; + }; + + class HookedClass { + public: + NOINLINE static bool IsAllowed(TestObject* obj) { + std::cout << "HookedClass::IsAllowed()" << std::endl; + return true; + } + + NOINLINE static int SetObjectValue(TestObject* obj, int value) { + std::cout << "HookedClass::SetObjectValue()" << std::endl; + obj->m_testValue = value; + return value; + } + + NOINLINE static void MyVoid(TestObject* obj) { + std::cout << "HookedClass::MyVoid()" << std::endl; + } + }; + + using IsAllowedNoopHook = NoopStaticHookTemplate; + using SetObjectValueNoopHook = NoopStaticHookTemplate; + using MyVoidNoopHook = NoopStaticHookTemplate; + + class FakeClass { + public: + NOINLINE static bool OverrideIsAllowedReturnValue(TestObject* obj) { + std::cout << "OverrideReturnValue()" << std::endl; + bool result = false; + KHook::SaveReturnValue( + KHook::Action::Override, + &result, + sizeof(bool), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + false + ); + return false; + } + + NOINLINE static bool SupersedeIsAllowedReturnValue(TestObject* obj) { + std::cout << "SupersedeReturnValue()" << std::endl; + bool result = false; + KHook::SaveReturnValue( + KHook::Action::Supersede, + &result, + sizeof(bool), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + false + ); + return false; + } + + NOINLINE static int OverrideSetObjectValue(TestObject* obj, int value) { + auto recall = reinterpret_cast( + KHook::DoRecall( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr + ) + ); + recall(obj, 1337); + return 0; + } + + NOINLINE static int SupersedeSetObjectValue(TestObject* obj, int value) { + std::cout << "SupersedeSetObjectValue()" << std::endl; + int newValue = 9001; + KHook::SaveReturnValue( + KHook::Action::Supersede, + &newValue, + sizeof(int), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + false + ); + return 0; + } + + NOINLINE static int HookInsideSetObjectValue(TestObject* obj, int value) { + m_hookId = KHook::SetupHook( + (void*)&HookedClass::IsAllowed, + nullptr, + (void*)&IsAllowedNoopHook::OnRemoved, + (void*)&IsAllowedNoopHook::PrePostNoop, + (void*)&IsAllowedNoopHook::PrePostNoop, + (void*)&IsAllowedNoopHook::MakeReturn, + (void*)&IsAllowedNoopHook::CallOriginal, + false + ); + KHook::SaveReturnValue( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr, + false + ); + return 0; + } + + NOINLINE static void SupersedeMyVoid(TestObject* obj) { + std::cout << "SupersedeMyVoid()" << std::endl; + KHook::SaveReturnValue( + KHook::Action::Supersede, + nullptr, + 0, + nullptr, + nullptr, + false + ); + } + }; + + protected: + void SetUp() override { + obj = new TestObject(); + } + + void TearDown() override { + if (obj) { + delete obj; + obj = nullptr; + } + } + + TestObject* obj = nullptr; + static int m_hookId; +}; + +int StaticHookTest::m_hookId = KHook::INVALID_HOOK; + +TEST_F(StaticHookTest, Noop) { + int hookId = KHook::SetupHook( + (void*)&HookedClass::IsAllowed, + nullptr, + (void*)&IsAllowedNoopHook::OnRemoved, + (void*)&IsAllowedNoopHook::PrePostNoop, + (void*)&IsAllowedNoopHook::PrePostNoop, + (void*)&IsAllowedNoopHook::MakeReturn, + (void*)&IsAllowedNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + bool overriddenResult = HookedClass::IsAllowed(obj); + + KHook::RemoveHook(hookId, false); + + bool originalResult = HookedClass::IsAllowed(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::IsAllowed()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + expected << "OnRemoved(" << std::dec << hookId << ")" << std::endl; + expected << "HookedClass::IsAllowed()" << std::endl; + + EXPECT_EQ(output, expected.str()) + << "Callback functions should be called in the expected order"; + EXPECT_TRUE(overriddenResult) + << "Method should return original value when hooked"; + EXPECT_TRUE(originalResult) + << "Method should return original value after hook removal"; +} + +TEST_F(StaticHookTest, NoopVoid) { + int hookId = KHook::SetupHook( + (void*)&HookedClass::MyVoid, + nullptr, + (void*)&MyVoidNoopHook::OnRemoved, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::MakeReturn, + (void*)&MyVoidNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + HookedClass::MyVoid(obj); + + KHook::RemoveHook(hookId, false); + + HookedClass::MyVoid(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + expected << "OnRemoved(" << std::dec << hookId << ")" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + + ASSERT_EQ(output, expected.str()) + << "Callback functions should be called in the expected order"; +} + +TEST_F(StaticHookTest, OverrideReturnValuePre) { + int hookId = KHook::SetupHook( + (void*)&HookedClass::IsAllowed, + nullptr, + (void*)&IsAllowedNoopHook::OnRemoved, + (void*)&FakeClass::OverrideIsAllowedReturnValue, + (void*)&IsAllowedNoopHook::PrePostNoop, + (void*)&IsAllowedNoopHook::MakeReturn, + (void*)&IsAllowedNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + bool result = HookedClass::IsAllowed(obj); + EXPECT_FALSE(result) << "Method should return false when hooked"; + + KHook::RemoveHook(hookId, false); + + result = HookedClass::IsAllowed(obj); + EXPECT_TRUE(result) << "Method should return true after hook removal"; +} + +TEST_F(StaticHookTest, OverrideReturnValuePost) { + int hookId = KHook::SetupHook( + (void*)&HookedClass::IsAllowed, + nullptr, + (void*)&IsAllowedNoopHook::OnRemoved, + (void*)&IsAllowedNoopHook::PrePostNoop, + (void*)&FakeClass::OverrideIsAllowedReturnValue, + (void*)&IsAllowedNoopHook::MakeReturn, + (void*)&IsAllowedNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + bool result = HookedClass::IsAllowed(obj); + EXPECT_FALSE(result) << "Method should return false when hooked"; + + KHook::RemoveHook(hookId, false); + + result = HookedClass::IsAllowed(obj); + EXPECT_TRUE(result) << "Method should return true after hook removal"; +} + +TEST_F(StaticHookTest, SupersedeReturnValue) { + int hookId = KHook::SetupHook( + (void*)&HookedClass::IsAllowed, + nullptr, + (void*)&IsAllowedNoopHook::OnRemoved, + (void*)&FakeClass::SupersedeIsAllowedReturnValue, + (void*)&IsAllowedNoopHook::PrePostNoop, + (void*)&IsAllowedNoopHook::MakeReturn, + (void*)&IsAllowedNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + bool overriddenResult = HookedClass::IsAllowed(obj); + + KHook::RemoveHook(hookId, false); + + bool originalResult = HookedClass::IsAllowed(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(output.find("CallOriginal()"), std::string::npos) + << "Original method should not be called"; + EXPECT_FALSE(overriddenResult) << "Method should return false when hooked"; + EXPECT_TRUE(originalResult) + << "Method should return true after hook removal"; +} + +TEST_F(StaticHookTest, SupersedeVoidReturnValue) { + int hookId = KHook::SetupHook( + (void*)&HookedClass::MyVoid, + nullptr, + (void*)&MyVoidNoopHook::OnRemoved, + (void*)&FakeClass::SupersedeMyVoid, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::MakeReturn, + (void*)&MyVoidNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + HookedClass::MyVoid(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "SupersedeMyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + ASSERT_EQ(expected.str(), output) + << "Callbacks should be called in the correct order"; + } + + KHook::RemoveHook(hookId, false); + + testing::internal::CaptureStdout(); + + HookedClass::MyVoid(obj); + + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "HookedClass::MyVoid()" << std::endl; + ASSERT_EQ(expected.str(), output) + << "No callbacks should be called after hook removal"; + } +} + +TEST_F(StaticHookTest, OverrideParameterWithRecall) { + int hookId = KHook::SetupHook( + (void*)&HookedClass::SetObjectValue, + nullptr, + (void*)&SetObjectValueNoopHook::OnRemoved, + (void*)&FakeClass::OverrideSetObjectValue, + (void*)&SetObjectValueNoopHook::PrePostNoop, + (void*)&SetObjectValueNoopHook::MakeReturn, + (void*)&SetObjectValueNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int result = HookedClass::SetObjectValue(obj, 0xDEADBEEF); + EXPECT_EQ(obj->m_testValue, 1337) + << "Method should set value to hooked value"; + EXPECT_EQ(result, 1337) << "Method should return hooked value"; + + KHook::RemoveHook(hookId, false); + + result = HookedClass::SetObjectValue(obj, 0xDEADBEEF); + EXPECT_EQ(obj->m_testValue, 0xDEADBEEF) + << "Method should set value to original value after hook removal"; + EXPECT_EQ(result, 0xDEADBEEF) + << "Method should return original value after hook removal"; +} + +TEST_F(StaticHookTest, SupersedeThenNoopVoidHooksOnSameFunction) { + int firstHookId = KHook::SetupHook( + (void*)&HookedClass::MyVoid, + nullptr, + (void*)&MyVoidNoopHook::OnRemoved, + (void*)&FakeClass::SupersedeMyVoid, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::MakeReturn, + (void*)&MyVoidNoopHook::CallOriginal, + false + ); + + ASSERT_NE(firstHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int secondHookId = KHook::SetupHook( + (void*)&HookedClass::MyVoid, + nullptr, + (void*)&MyVoidNoopHook::OnRemoved, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::MakeReturn, + (void*)&MyVoidNoopHook::CallOriginal, + false + ); + + ASSERT_NE(secondHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + HookedClass::MyVoid(obj); + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "SupersedeMyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(firstHookId, false); + + testing::internal::CaptureStdout(); + HookedClass::MyVoid(obj); + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } +} + +TEST_F(StaticHookTest, MultipleNoopVoidHooksOnSameFunction) { + int firstHookId = KHook::SetupHook( + (void*)&HookedClass::MyVoid, + nullptr, + (void*)&MyVoidNoopHook::OnRemoved, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::MakeReturn, + (void*)&MyVoidNoopHook::CallOriginal, + false + ); + + ASSERT_NE(firstHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int secondHookId = KHook::SetupHook( + (void*)&HookedClass::MyVoid, + nullptr, + (void*)&MyVoidNoopHook::OnRemoved, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::PrePostNoop, + (void*)&MyVoidNoopHook::MakeReturn, + (void*)&MyVoidNoopHook::CallOriginal, + false + ); + + ASSERT_NE(secondHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + HookedClass::MyVoid(obj); + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(firstHookId, false); + + testing::internal::CaptureStdout(); + HookedClass::MyVoid(obj); + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } +} + +TEST_F(StaticHookTest, SupersedeThenNoopHooksOnSameFunction) { + int firstHookId = KHook::SetupHook( + (void*)&HookedClass::SetObjectValue, + nullptr, + (void*)&SetObjectValueNoopHook::OnRemoved, + (void*)&FakeClass::SupersedeSetObjectValue, + (void*)&SetObjectValueNoopHook::PrePostNoop, + (void*)&SetObjectValueNoopHook::MakeReturn, + (void*)&SetObjectValueNoopHook::CallOriginal, + false + ); + + ASSERT_NE(firstHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int secondHookId = KHook::SetupHook( + (void*)&HookedClass::SetObjectValue, + nullptr, + (void*)&SetObjectValueNoopHook::OnRemoved, + (void*)&SetObjectValueNoopHook::PrePostNoop, + (void*)&SetObjectValueNoopHook::PrePostNoop, + (void*)&SetObjectValueNoopHook::MakeReturn, + (void*)&SetObjectValueNoopHook::CallOriginal, + false + ); + + ASSERT_NE(secondHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + obj->m_testValue = 0x9600; + + testing::internal::CaptureStdout(); + int result = HookedClass::SetObjectValue(obj, 0xDEADBEEF); + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "SupersedeSetObjectValue()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(firstHookId, false); + + obj->m_testValue = 0x9600; + + testing::internal::CaptureStdout(); + result = HookedClass::SetObjectValue(obj, 0xDEADBEEF); + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::SetObjectValue()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(secondHookId, false); + + obj->m_testValue = 0x9600; + + testing::internal::CaptureStdout(); + result = HookedClass::SetObjectValue(obj, 0xDEADBEEF); + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(output.find("CallOriginal()"), std::string::npos) + << "CallOriginal() should not be called after all hooks removed"; + EXPECT_EQ(obj->m_testValue, 0xDEADBEEF) + << "Method should set value to original value after recall hook " + "removal"; + EXPECT_EQ(result, 0xDEADBEEF) + << "Method should return original value after recall hook removal"; +} + +TEST_F(StaticHookTest, HookIsAllowedInsideSetObjectValue) { + m_hookId = KHook::INVALID_HOOK; + + int hookId = KHook::SetupHook( + (void*)&HookedClass::SetObjectValue, + nullptr, + (void*)&SetObjectValueNoopHook::OnRemoved, + (void*)&FakeClass::HookInsideSetObjectValue, + (void*)&SetObjectValueNoopHook::PrePostNoop, + (void*)&SetObjectValueNoopHook::MakeReturn, + (void*)&SetObjectValueNoopHook::CallOriginal, + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int firstResult = HookedClass::SetObjectValue(obj, 42); + + EXPECT_EQ(obj->m_testValue, 42) + << "SetObjectValue should set value to 42 (original behavior)"; + EXPECT_EQ(firstResult, 42) + << "SetObjectValue should return 42 (original value)"; + ASSERT_NE(m_hookId, KHook::INVALID_HOOK) + << "IsAllowed hook should have been set up inside SetObjectValue"; + + bool result = HookedClass::IsAllowed(obj); + EXPECT_TRUE(result) << "IsAllowed should return true (original value)"; +} \ No newline at end of file diff --git a/test/test_virtual_hook.cpp b/test/test_virtual_hook.cpp new file mode 100644 index 0000000..ba733ca --- /dev/null +++ b/test/test_virtual_hook.cpp @@ -0,0 +1,707 @@ +#include +#include +#include +#include +#include "helpers.hpp" + +template +class NoopMemberHookTemplate { + public: + NOINLINE Ret PrePostNoop(Args... args); + NOINLINE Ret CallOriginal(Args... args); + NOINLINE Ret MakeReturn(Args... args); + static NOINLINE void OnRemoved(int hookId); +}; + +template +NOINLINE Ret NoopMemberHookTemplate::PrePostNoop(Args... args) { + std::cout << "PrePostNoop()" << std::endl; + KHook::SaveReturnValue( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr, + false + ); + if constexpr (std::is_same::value) { + return; + } else { + return Ret(); + } +} + +template +NOINLINE Ret NoopMemberHookTemplate::CallOriginal(Args... args) { + std::cout << "CallOriginal()" << std::endl; + auto original = reinterpret_cast( + KHook::GetOriginalFunction() + ); + if constexpr (std::is_same::value) { + original(this, args...); + KHook::SaveReturnValue( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr, + true + ); + return; + } else { + Ret result = original(this, args...); + KHook::SaveReturnValue( + KHook::Action::Ignore, + &result, + sizeof(Ret), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + true + ); + return result; + } +} + +template +NOINLINE Ret NoopMemberHookTemplate::MakeReturn(Args... args) { + std::cout << "MakeReturn()" << std::endl; + if constexpr (std::is_same::value) { + KHook::DestroyReturnValue(); + return; + } else { + Ret result = *((Ret*)KHook::GetCurrentValuePtr(true)); + KHook::DestroyReturnValue(); + return result; + } +} + +template +NOINLINE void NoopMemberHookTemplate::OnRemoved(int hookId) { + std::cout << "OnRemoved(" << std::dec << hookId << ")" << std::endl; +} + +class VirtualHookTests: public ::testing::Test { + protected: + class TestObject { + public: + int m_testValue; + }; + + class HookedClass { + public: + virtual bool IsAllowed(TestObject* obj) { + std::cout << "HookedClass::IsAllowed()" << std::endl; + return true; + } + + virtual int SetObjectValue(TestObject* obj, int value) { + std::cout << "HookedClass::SetObjectValue()" << std::endl; + obj->m_testValue = value; + return value; + } + + virtual void MyVoid(TestObject* obj) { + std::cout << "HookedClass::MyVoid()" << std::endl; + } + }; + + using IsAllowedNoopHook = NoopMemberHookTemplate; + using SetObjectValueNoopHook = NoopMemberHookTemplate; + using MyVoidNoopHook = NoopMemberHookTemplate; + + class FakeClass { + public: + NOINLINE bool OverrideIsAllowedReturnValue(TestObject* obj) { + std::cout << "OverrideReturnValue()" << std::endl; + bool result = false; + KHook::SaveReturnValue( + KHook::Action::Override, + &result, + sizeof(bool), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + false + ); + return false; + } + + NOINLINE bool SupersedeIsAllowedReturnValue(TestObject* obj) { + std::cout << "SupersedeReturnValue()" << std::endl; + bool result = false; + KHook::SaveReturnValue( + KHook::Action::Supersede, + &result, + sizeof(bool), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + false + ); + return false; + } + + NOINLINE int OverrideSetObjectValue(TestObject* obj, int value) { + auto recall = KHook::BuildMFP( + KHook::DoRecall( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr + ) + ); + (this->*recall)(obj, 1337); + return 0; + } + + NOINLINE int SupersedeSetObjectValue(TestObject* obj, int value) { + std::cout << "SupersedeSetObjectValue()" << std::endl; + int newValue = 9001; + KHook::SaveReturnValue( + KHook::Action::Supersede, + &newValue, + sizeof(int), + (void*)KHook::init_operator, + (void*)KHook::deinit_operator, + false + ); + return 0; + } + + NOINLINE int HookInsideSetObjectValue(TestObject* obj, int value) { + m_hookId = KHook::SetupVirtualHook( + *(void***)(this), + KHook::GetVtableIndex(&HookedClass::IsAllowed), + nullptr, + (void*)(&IsAllowedNoopHook::OnRemoved), + KHook::ExtractMFP(&IsAllowedNoopHook::PrePostNoop), + KHook::ExtractMFP(&IsAllowedNoopHook::PrePostNoop), + KHook::ExtractMFP(&IsAllowedNoopHook::MakeReturn), + KHook::ExtractMFP(&IsAllowedNoopHook::CallOriginal), + false + ); + KHook::SaveReturnValue( + KHook::Action::Ignore, + nullptr, + 0, + nullptr, + nullptr, + false + ); + return 0; + } + + NOINLINE void SupersedeMyVoid(TestObject* obj) { + std::cout << "SupersedeMyVoid()" << std::endl; + KHook::SaveReturnValue( + KHook::Action::Supersede, + nullptr, + 0, + nullptr, + nullptr, + false + ); + } + }; + + protected: + void SetUp() override { + target = new HookedClass(); + obj = new TestObject(); + } + + void TearDown() override { + if (obj) { + delete obj; + obj = nullptr; + } + if (target) { + delete target; + target = nullptr; + } + } + + HookedClass* target = nullptr; + TestObject* obj = nullptr; + static int m_hookId; +}; + +int VirtualHookTests::m_hookId = KHook::INVALID_HOOK; + +TEST_F(VirtualHookTests, Noop) { + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::IsAllowed), + nullptr, + (void*)(&IsAllowedNoopHook::OnRemoved), + KHook::ExtractMFP(&IsAllowedNoopHook::PrePostNoop), + KHook::ExtractMFP(&IsAllowedNoopHook::PrePostNoop), + KHook::ExtractMFP(&IsAllowedNoopHook::MakeReturn), + KHook::ExtractMFP(&IsAllowedNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + bool overriddenResult = target->IsAllowed(obj); + + KHook::RemoveHook(hookId, false); + + bool originalResult = target->IsAllowed(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::IsAllowed()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + expected << "OnRemoved(" << std::dec << hookId << ")" << std::endl; + expected << "HookedClass::IsAllowed()" << std::endl; + + EXPECT_EQ(output, expected.str()) + << "Callback functions should be called in the expected order"; + EXPECT_TRUE(overriddenResult) + << "Method should return original value when hooked"; + EXPECT_TRUE(originalResult) + << "Method should return original value after hook removal"; +} + +TEST_F(VirtualHookTests, NoopVoid) { + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::MyVoid), + nullptr, + (void*)(&MyVoidNoopHook::OnRemoved), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::MakeReturn), + KHook::ExtractMFP(&MyVoidNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + target->MyVoid(obj); + + KHook::RemoveHook(hookId, false); + + target->MyVoid(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + expected << "OnRemoved(" << std::dec << hookId << ")" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + + ASSERT_EQ(output, expected.str()) + << "Callback functions should be called in the expected order"; +} + +TEST_F(VirtualHookTests, OverrideReturnValuePre) { + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::IsAllowed), + nullptr, + (void*)(&IsAllowedNoopHook::OnRemoved), + KHook::ExtractMFP(&FakeClass::OverrideIsAllowedReturnValue), + KHook::ExtractMFP(&IsAllowedNoopHook::PrePostNoop), + KHook::ExtractMFP(&IsAllowedNoopHook::MakeReturn), + KHook::ExtractMFP(&IsAllowedNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + bool result = target->IsAllowed(obj); + EXPECT_FALSE(result) << "Method should return false when hooked"; + + KHook::RemoveHook(hookId, false); + + result = target->IsAllowed(obj); + EXPECT_TRUE(result) << "Method should return true after hook removal"; +} + +TEST_F(VirtualHookTests, OverrideReturnValuePost) { + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::IsAllowed), + nullptr, + (void*)(&IsAllowedNoopHook::OnRemoved), + KHook::ExtractMFP(&IsAllowedNoopHook::PrePostNoop), + KHook::ExtractMFP(&FakeClass::OverrideIsAllowedReturnValue), + KHook::ExtractMFP(&IsAllowedNoopHook::MakeReturn), + KHook::ExtractMFP(&IsAllowedNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + bool result = target->IsAllowed(obj); + EXPECT_FALSE(result) << "Method should return false when hooked"; + + KHook::RemoveHook(hookId, false); + + result = target->IsAllowed(obj); + EXPECT_TRUE(result) << "Method should return true after hook removal"; +} + +TEST_F(VirtualHookTests, SupersedeReturnValue) { + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::IsAllowed), + nullptr, + (void*)(&IsAllowedNoopHook::OnRemoved), + KHook::ExtractMFP(&FakeClass::SupersedeIsAllowedReturnValue), + KHook::ExtractMFP(&IsAllowedNoopHook::PrePostNoop), + KHook::ExtractMFP(&IsAllowedNoopHook::MakeReturn), + KHook::ExtractMFP(&IsAllowedNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + bool overriddenResult = target->IsAllowed(obj); + + KHook::RemoveHook(hookId, false); + + bool originalResult = target->IsAllowed(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(output.find("CallOriginal()"), std::string::npos) + << "Original method should not be called"; + EXPECT_FALSE(overriddenResult) << "Method should return false when hooked"; + EXPECT_TRUE(originalResult) + << "Method should return true after hook removal"; +} + +TEST_F(VirtualHookTests, SupersedeVoidReturnValue) { + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::MyVoid), + nullptr, + (void*)(&MyVoidNoopHook::OnRemoved), + KHook::ExtractMFP(&FakeClass::SupersedeMyVoid), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::MakeReturn), + KHook::ExtractMFP(&MyVoidNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + + target->MyVoid(obj); + + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "SupersedeMyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + ASSERT_EQ(expected.str(), output) + << "Callbacks should be called in the correct order"; + } + + KHook::RemoveHook(hookId, false); + + testing::internal::CaptureStdout(); + + target->MyVoid(obj); + + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "HookedClass::MyVoid()" << std::endl; + ASSERT_EQ(expected.str(), output) + << "No callbacks should be called after hook removal"; + } +} + +TEST_F(VirtualHookTests, OverrideParameterWithRecall) { + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::SetObjectValue), + nullptr, + (void*)(&SetObjectValueNoopHook::OnRemoved), + KHook::ExtractMFP(&FakeClass::OverrideSetObjectValue), + KHook::ExtractMFP(&SetObjectValueNoopHook::PrePostNoop), + KHook::ExtractMFP(&SetObjectValueNoopHook::MakeReturn), + KHook::ExtractMFP(&SetObjectValueNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int result = target->SetObjectValue(obj, 0xDEADBEEF); + EXPECT_EQ(obj->m_testValue, 1337) + << "Method should set value to hooked value"; + EXPECT_EQ(result, 1337) << "Method should return hooked value"; + + KHook::RemoveHook(hookId, false); + + result = target->SetObjectValue(obj, 0xDEADBEEF); + EXPECT_EQ(obj->m_testValue, 0xDEADBEEF) + << "Method should set value to original value after hook removal"; + EXPECT_EQ(result, 0xDEADBEEF) + << "Method should return original value after hook removal"; +} + +TEST_F(VirtualHookTests, SupersedeThenNoopVoidHooksOnSameIndex) { + int firstHookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::MyVoid), + nullptr, + (void*)(&MyVoidNoopHook::OnRemoved), + KHook::ExtractMFP(&FakeClass::SupersedeMyVoid), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::MakeReturn), + KHook::ExtractMFP(&MyVoidNoopHook::CallOriginal), + false + ); + + ASSERT_NE(firstHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int secondHookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::MyVoid), + nullptr, + (void*)(&MyVoidNoopHook::OnRemoved), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::MakeReturn), + KHook::ExtractMFP(&MyVoidNoopHook::CallOriginal), + false + ); + + ASSERT_NE(secondHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + target->MyVoid(obj); + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "SupersedeMyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(firstHookId, false); + + testing::internal::CaptureStdout(); + target->MyVoid(obj); + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } +} + +TEST_F(VirtualHookTests, MultipleNoopVoidHooksOnSameIndex) { + int firstHookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::MyVoid), + nullptr, + (void*)(&MyVoidNoopHook::OnRemoved), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::MakeReturn), + KHook::ExtractMFP(&MyVoidNoopHook::CallOriginal), + false + ); + + ASSERT_NE(firstHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int secondHookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::MyVoid), + nullptr, + (void*)(&MyVoidNoopHook::OnRemoved), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::PrePostNoop), + KHook::ExtractMFP(&MyVoidNoopHook::MakeReturn), + KHook::ExtractMFP(&MyVoidNoopHook::CallOriginal), + false + ); + + ASSERT_NE(secondHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + testing::internal::CaptureStdout(); + target->MyVoid(obj); + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(firstHookId, false); + + testing::internal::CaptureStdout(); + target->MyVoid(obj); + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::MyVoid()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } +} + +TEST_F(VirtualHookTests, SupersedeThenNoopHooksOnSameIndex) { + int firstHookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::SetObjectValue), + nullptr, + (void*)(&SetObjectValueNoopHook::OnRemoved), + KHook::ExtractMFP(&FakeClass::SupersedeSetObjectValue), + KHook::ExtractMFP(&SetObjectValueNoopHook::PrePostNoop), + KHook::ExtractMFP(&SetObjectValueNoopHook::MakeReturn), + KHook::ExtractMFP(&SetObjectValueNoopHook::CallOriginal), + false + ); + + ASSERT_NE(firstHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int secondHookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::SetObjectValue), + nullptr, + (void*)(&SetObjectValueNoopHook::OnRemoved), + KHook::ExtractMFP(&SetObjectValueNoopHook::PrePostNoop), + KHook::ExtractMFP(&SetObjectValueNoopHook::PrePostNoop), + KHook::ExtractMFP(&SetObjectValueNoopHook::MakeReturn), + KHook::ExtractMFP(&SetObjectValueNoopHook::CallOriginal), + false + ); + + ASSERT_NE(secondHookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + obj->m_testValue = 0x9600; + + testing::internal::CaptureStdout(); + int result = target->SetObjectValue(obj, 0xDEADBEEF); + std::string output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "SupersedeSetObjectValue()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(firstHookId, false); + + obj->m_testValue = 0x9600; + + testing::internal::CaptureStdout(); + result = target->SetObjectValue(obj, 0xDEADBEEF); + output = testing::internal::GetCapturedStdout(); + + { + std::ostringstream expected; + expected << "PrePostNoop()" << std::endl; + expected << "CallOriginal()" << std::endl; + expected << "HookedClass::SetObjectValue()" << std::endl; + expected << "PrePostNoop()" << std::endl; + expected << "MakeReturn()" << std::endl; + + ASSERT_EQ(expected.str(), output) + << "Callback functions should be called in the correct order"; + } + + KHook::RemoveHook(secondHookId, false); + + obj->m_testValue = 0x9600; + + testing::internal::CaptureStdout(); + result = target->SetObjectValue(obj, 0xDEADBEEF); + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(output.find("CallOriginal()"), std::string::npos) + << "CallOriginal() should not be called after all hooks removed"; + EXPECT_EQ(obj->m_testValue, 0xDEADBEEF) + << "Method should set value to original value after recall hook " + "removal"; + EXPECT_EQ(result, 0xDEADBEEF) + << "Method should return original value after recall hook removal"; +} + +TEST_F(VirtualHookTests, HookIsAllowedInsideSetObjectValue) { + m_hookId = KHook::INVALID_HOOK; + + int hookId = KHook::SetupVirtualHook( + *(void***)(target), + KHook::GetVtableIndex(&HookedClass::SetObjectValue), + nullptr, + (void*)(&SetObjectValueNoopHook::OnRemoved), + KHook::ExtractMFP(&FakeClass::HookInsideSetObjectValue), + KHook::ExtractMFP(&SetObjectValueNoopHook::PrePostNoop), + KHook::ExtractMFP(&SetObjectValueNoopHook::MakeReturn), + KHook::ExtractMFP(&SetObjectValueNoopHook::CallOriginal), + false + ); + + ASSERT_NE(hookId, KHook::INVALID_HOOK) << "Hook setup should succeed"; + + int firstResult = target->SetObjectValue(obj, 42); + + EXPECT_EQ(obj->m_testValue, 42) + << "SetObjectValue should set value to 42 (original behavior)"; + EXPECT_EQ(firstResult, 42) + << "SetObjectValue should return 42 (original value)"; + ASSERT_NE(m_hookId, KHook::INVALID_HOOK) + << "IsAllowed hook should have been set up inside SetObjectValue"; + + bool result = target->IsAllowed(obj); + EXPECT_TRUE(result) << "IsAllowed should return true (original value)"; +} \ No newline at end of file diff --git a/test/test_virtuals.cpp b/test/test_virtuals.cpp new file mode 100644 index 0000000..c15cfda --- /dev/null +++ b/test/test_virtuals.cpp @@ -0,0 +1,293 @@ +#include +#include +#include +#include "helpers.hpp" + +class VirtualTests : public ::testing::Test { +protected: + class BaseClass { + public: + virtual void TargetMethod(int& value) { + value = 0xBEEFDEAD; + } + }; + + class TargetClass : public BaseClass { + public: + virtual void TargetMethod(int& value) override { + value = 0xDEADBEEF; + } + }; + + typedef VirtualContext ContextType; + + static KHook::Return HookMethod(BaseClass* _this, int& value) { + std::cout << "Here"; + return {KHook::Action::Ignore}; + } + + static KHook::Return HookMethodSupersede(BaseClass* _this, int& value) { + return {KHook::Action::Supersede}; + } +}; + +TEST_F(VirtualTests, PreCallback) { + std::unique_ptr instance = std::make_unique(); + + KHook::Virtual hook(&BaseClass::TargetMethod, &HookMethod, nullptr); + hook.Add(instance.get()); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance->TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; +} + +TEST_F(VirtualTests, PostCallback) { + std::unique_ptr instance = std::make_unique(); + + KHook::Virtual hook(&BaseClass::TargetMethod, nullptr, &HookMethod); + hook.Add(instance.get()); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance->TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "Here") << "No response from callback"; +} + +TEST_F(VirtualTests, PreAndPostCallbacks) { + std::unique_ptr instance = std::make_unique(); + + KHook::Virtual hook(&BaseClass::TargetMethod, &HookMethod, &HookMethod); + hook.Add(instance.get()); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance->TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "HereHere") << "No response from callbacks"; +} + +TEST_F(VirtualTests, SupersedePreCallback) { + std::unique_ptr instance = std::make_unique(); + + KHook::Virtual hook(&BaseClass::TargetMethod, &HookMethodSupersede, nullptr); + hook.Add(instance.get()); + + int value = 0x1337; + instance->TargetMethod(value); + + EXPECT_EQ(value, 0x1337) << "Callback did not supersede original method"; +} + +TEST_F(VirtualTests, InstanceOnly) { + std::unique_ptr instance1 = std::make_unique(); + std::unique_ptr instance2 = std::make_unique(); + + KHook::Virtual hook(&BaseClass::TargetMethod, &HookMethodSupersede, nullptr); + hook.Add(instance1.get()); + + testing::internal::CaptureStdout(); + + int value1 = 0x1337; + int value2 = 0x1337; + instance1->TargetMethod(value1); + instance2->TargetMethod(value2); + + std::string output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value1, 0x1337) << "Behaviour of hooked instance was modified"; + EXPECT_EQ(value2, 0xDEADBEEF) << "Behaviour of unhooked instance was modified"; +} + +TEST_F(VirtualTests, ContextPreCallback) { + std::unique_ptr instance = std::make_unique(); + + ContextType context(&HookMethod); + KHook::Virtual hook(&BaseClass::TargetMethod, &context, &ContextType::OnPre, nullptr); + hook.Add(instance.get()); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance->TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + ASSERT_EQ(output, "Here") << "No response from callback"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance->TargetMethod(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callback after context removal"; +} + +TEST_F(VirtualTests, ContextPostCallback) { + std::unique_ptr instance = std::make_unique(); + + ContextType context(&HookMethod); + KHook::Virtual hook(&BaseClass::TargetMethod, &context, nullptr, &ContextType::OnPost); + hook.Add(instance.get()); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance->TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + ASSERT_EQ(output, "Here") << "No response from callback"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance->TargetMethod(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callback after context removal"; +} + +TEST_F(VirtualTests, ContextPreAndPostCallbacks) { + std::unique_ptr instance = std::make_unique(); + + ContextType context(&HookMethod); + KHook::Virtual hook(&BaseClass::TargetMethod, &context, &ContextType::OnPre, &ContextType::OnPost); + hook.Add(instance.get()); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance->TargetMethod(value); + + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + ASSERT_EQ(output, "HereHere") << "No response from callbacks"; + + hook.RemoveContext(&context); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance->TargetMethod(value); + + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked method was modified"; + EXPECT_EQ(output, "") << "Unexpected response from callbacks after context removal"; +} + +TEST_F(VirtualTests, ContextSupersedePreCallback) { + std::unique_ptr instance = std::make_unique(); + + ContextType context(&HookMethodSupersede); + KHook::Virtual hook(&BaseClass::TargetMethod, &context, &ContextType::OnPre, nullptr); + hook.Add(instance.get()); + + int value = 0x1337; + instance->TargetMethod(value); + + ASSERT_EQ(value, 0x1337) << "Callback did not supersede original method"; + + hook.RemoveContext(&context); + + value = 0x1337; + instance->TargetMethod(value); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of unhooked method was modified"; +} + +TEST_F(VirtualTests, ContextInstanceOnly) { + std::unique_ptr instance1 = std::make_unique(); + std::unique_ptr instance2 = std::make_unique(); + + ContextType context(&HookMethodSupersede); + KHook::Virtual hook(&BaseClass::TargetMethod, &context, &ContextType::OnPre, nullptr); + hook.Add(instance1.get()); + + int value1 = 0x1337; + int value2 = 0x1337; + instance1->TargetMethod(value1); + instance2->TargetMethod(value2); + + ASSERT_EQ(value1, 0x1337) << "Callback did not supersede original method"; + ASSERT_EQ(value2, 0xDEADBEEF) << "Behaviour of unhooked instance was modified"; + + hook.RemoveContext(&context); + + value1 = 0x1337; + value2 = 0x1337; + instance1->TargetMethod(value1); + instance2->TargetMethod(value2); + + EXPECT_EQ(value1, 0xDEADBEEF) << "Behaviour of unhooked instance was modified after context removal"; + EXPECT_EQ(value2, 0xDEADBEEF) << "Behaviour of unhooked instance was modified after context removal"; +} + +TEST_F(VirtualTests, MultipleContextsPreCallback) { + std::unique_ptr instance = std::make_unique(); + + ContextType context1(&HookMethodSupersede); + ContextType context2(&HookMethod); + KHook::Virtual hook(&BaseClass::TargetMethod, &context1, &ContextType::OnPre, nullptr); + hook.Add(instance.get()); + hook.AddContext(&context2, &ContextType::OnPre, nullptr); + + testing::internal::CaptureStdout(); + + int value = 0x1337; + instance->TargetMethod(value); + std::string output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0x1337) << "Behaviour of hooked instance was modified"; + ASSERT_EQ(output, "Here") << "Incorrect response from multiple contexts"; + + hook.RemoveContext(&context1); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance->TargetMethod(value); + output = testing::internal::GetCapturedStdout(); + + ASSERT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked instance was modified after context removal"; + ASSERT_EQ(output, "Here") << "Incorrect response from remaining context after context removal"; + + hook.RemoveContext(&context2); + + testing::internal::CaptureStdout(); + + value = 0x1337; + instance->TargetMethod(value); + output = testing::internal::GetCapturedStdout(); + + EXPECT_EQ(value, 0xDEADBEEF) << "Behaviour of hooked instance was modified after context removal"; + EXPECT_EQ(output, "") << "Unexpected response from callbacks after context removal"; +} \ No newline at end of file diff --git a/third_party/googletest b/third_party/googletest new file mode 160000 index 0000000..73a63ea --- /dev/null +++ b/third_party/googletest @@ -0,0 +1 @@ +Subproject commit 73a63ea05dc8ca29ec1d2c1d66481dd0de1950f1 diff --git a/third_party/gtest-parallel b/third_party/gtest-parallel new file mode 160000 index 0000000..cd488bd --- /dev/null +++ b/third_party/gtest-parallel @@ -0,0 +1 @@ +Subproject commit cd488bdedc1d2cffb98201a17afc1b298b0b90f1 diff --git a/third_party/safetyhook b/third_party/safetyhook index 8c6692c..ec3f698 160000 --- a/third_party/safetyhook +++ b/third_party/safetyhook @@ -1 +1 @@ -Subproject commit 8c6692c85a6c41f5d89f744da57b5ba43515b4ec +Subproject commit ec3f698a1d9936d72c57c639536fbbedab6d7c8a