Counter-Strike: Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

276 lines
8.4 KiB

  1. //========= Copyright © 1996-2002, Valve LLC, All rights reserved. ============
  2. //
  3. // Purpose: Function Detouring code used by the overlay
  4. //
  5. // $NoKeywords: $
  6. //=============================================================================
  7. #ifndef DETOURFUNC_H
  8. #define DETOURFUNC_H
  9. #ifdef _WIN32
  10. #pragma once
  11. #endif
  12. void * HookFunc( BYTE *pRealFunctionAddr, const BYTE *pHookFunctionAddr, int nJumpsToFollowBeforeHooking = 0 );
  13. bool HookFuncSafe( BYTE *pRealFunctionAddr, const BYTE *pHookFunctionAddr, void ** ppRelocFunctionAddr, int nJumpsToFollowBeforeHooking = 0 );
  14. bool bIsFuncHooked( BYTE *pRealFunctionAddr, void *pHookFunc = NULL );
  15. void UnhookFunc( BYTE *pRealFunctionAddr, BYTE *pOriginalFunctionAddr_DEPRECATED );
  16. void UnhookFunc( BYTE *pRealFunctionAddr, bool bLogFailures = true );
  17. void UnhookFuncByRelocAddr( BYTE *pRelocFunctionAddr, bool bLogFailures = true );
  18. void RegregisterTrampolines();
  19. void DetectUnloadedHooks();
  20. #if defined( _WIN32 ) && DEBUG_ENABLE_DETOUR_RECORDING
  21. template <typename T, int k_nCountElements >
  22. class CCallRecordSet
  23. {
  24. public:
  25. typedef T ElemType_t;
  26. CCallRecordSet()
  27. {
  28. m_cElements = 0;
  29. m_cElementPostWrite = 0;
  30. m_cElementMax = k_nCountElements;
  31. m_cubElements = sizeof(m_rgElements);
  32. memset( m_rgElements, 0, sizeof(m_rgElements) );
  33. }
  34. // if return value is >= 0, then we matched an existing record
  35. int AddFunctionCallRecord( const ElemType_t &fcr )
  36. {
  37. // if we are full, dont bother searching any more
  38. // this reduces our perf impact to near zero if these functions are
  39. // called a lot more than we expect
  40. int cElements = m_cElements;
  41. if ( cElements >= k_nCountElements )
  42. {
  43. return -2;
  44. }
  45. // search backwards through the log
  46. for( int i = cElements-1; i >= 0; i-- )
  47. {
  48. if ( m_rgElements[i] == fcr )
  49. return i;
  50. }
  51. cElements = ++m_cElements;
  52. if ( cElements <= k_nCountElements )
  53. {
  54. m_rgElements[cElements-1] = fcr;
  55. }
  56. // if an external reader sees m_cElements != m_cElementPostWrite
  57. // they know the last item(s) may not be complete
  58. m_cElementPostWrite++;
  59. return -1;
  60. }
  61. CInterlockedIntT< int > m_cElements;
  62. CInterlockedIntT< int > m_cElementPostWrite;
  63. int m_cElementMax;
  64. int m_cubElements;
  65. ElemType_t m_rgElements[k_nCountElements];
  66. };
  67. class CRecordDetouredCalls
  68. {
  69. public:
  70. CRecordDetouredCalls();
  71. void SetMasterSwitchOn() { m_bMasterSwitch = true; }
  72. bool BIsMasterSwitchOn() { return m_bMasterSwitch; }
  73. bool BShouldRecordProtectFlags( DWORD flProtect );
  74. void RecordGetAsyncKeyState( DWORD vKey,
  75. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  76. );
  77. void RecordVirtualAlloc( LPVOID lpAddress, SIZE_T dwSize, DWORD flAllocationType, DWORD flProtect,
  78. LPVOID lpvResult, DWORD dwGetLastError,
  79. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  80. );
  81. void RecordVirtualProtect( LPVOID lpAddress, SIZE_T dwSize, DWORD flNewProtect, DWORD flOldProtect,
  82. BOOL bResult, DWORD dwGetLastError,
  83. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  84. );
  85. void RecordVirtualAllocEx( HANDLE hProcess, LPVOID lpAddress, SIZE_T dwSize, DWORD flAllocationType, DWORD flProtect,
  86. LPVOID lpvResult, DWORD dwGetLastError,
  87. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  88. );
  89. void RecordVirtualProtectEx( HANDLE hProcess, LPVOID lpAddress, SIZE_T dwSize, DWORD flNewProtect, DWORD flOldProtect,
  90. BOOL bResult, DWORD dwGetLastError,
  91. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  92. );
  93. void RecordLoadLibraryW(
  94. LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags,
  95. HMODULE hModule, DWORD dwGetLastError,
  96. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  97. );
  98. void RecordLoadLibraryA(
  99. LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags,
  100. HMODULE hModule, DWORD dwGetLastError,
  101. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  102. );
  103. private:
  104. struct FunctionCallRecordBase_t
  105. {
  106. void SharedInit(
  107. DWORD dwResult, DWORD dwGetLastError,
  108. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  109. );
  110. DWORD m_dwResult;
  111. DWORD m_dwGetLastError;
  112. LPVOID m_lpFirstCallersAddress;
  113. LPVOID m_lpLastCallerAddress;
  114. };
  115. // for GetAsyncKeyState the only thing we care about is the call site
  116. // dont care about results or params
  117. struct GetAsyncKeyStateCallRecord_t : public FunctionCallRecordBase_t
  118. {
  119. GetAsyncKeyStateCallRecord_t()
  120. {}
  121. void InitGetAsyncKeyState( DWORD vKey,
  122. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  123. );
  124. bool operator==( const FunctionCallRecordBase_t &rhs ) const
  125. {
  126. // compare callers only, dont care about results or params
  127. return
  128. m_lpFirstCallersAddress == rhs.m_lpFirstCallersAddress &&
  129. m_lpLastCallerAddress == rhs.m_lpLastCallerAddress;
  130. }
  131. };
  132. struct VirtualAllocCallRecord_t : public FunctionCallRecordBase_t
  133. {
  134. VirtualAllocCallRecord_t()
  135. {}
  136. // VirtualAlloc
  137. void InitVirtualAlloc( LPVOID lpAddress, SIZE_T dwSize, DWORD flAllocationType, DWORD flProtect,
  138. LPVOID lpvResult, DWORD dwGetLastError,
  139. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  140. );
  141. // VirtualAllocEx
  142. void InitVirtualAllocEx( HANDLE hProcess, LPVOID lpAddress, SIZE_T dwSize, DWORD flAllocationType, DWORD flProtect,
  143. LPVOID lpvResult, DWORD dwGetLastError,
  144. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  145. );
  146. // VirtualProtect
  147. void InitVirtualProtect( LPVOID lpAddress, SIZE_T dwSize, DWORD flNewProtect, DWORD flOldProtect,
  148. BOOL bResult, DWORD dwGetLastError,
  149. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  150. );
  151. // VirtualProtectEx
  152. void InitVirtualProtectEx( HANDLE hProcess, LPVOID lpAddress, SIZE_T dwSize, DWORD flNewProtect, DWORD flOldProtect,
  153. BOOL bResult, DWORD dwGetLastError,
  154. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  155. );
  156. bool operator==( const VirtualAllocCallRecord_t &rhs ) const
  157. {
  158. // compare everything
  159. return
  160. m_dwResult == rhs.m_dwResult &&
  161. m_dwGetLastError == rhs.m_dwGetLastError &&
  162. m_dwProcessId == rhs.m_dwProcessId &&
  163. m_lpAddress == rhs.m_lpAddress &&
  164. m_dwSize == rhs.m_dwSize &&
  165. m_flProtect == rhs.m_flProtect &&
  166. m_dw2 == rhs.m_dw2 &&
  167. m_lpFirstCallersAddress == rhs.m_lpFirstCallersAddress &&
  168. m_lpLastCallerAddress == rhs.m_lpLastCallerAddress;
  169. }
  170. DWORD m_dwProcessId;
  171. LPVOID m_lpAddress;
  172. SIZE_T m_dwSize;
  173. DWORD m_flProtect;
  174. DWORD m_dw2;
  175. };
  176. // for LoadLibrary just log everything, params and call sites
  177. struct LoadLibraryCallRecord_t : public FunctionCallRecordBase_t
  178. {
  179. LoadLibraryCallRecord_t() {}
  180. // LoadLibraryExW or LoadLibraryW
  181. void InitLoadLibraryW(
  182. LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags,
  183. HMODULE hModule, DWORD dwGetLastError,
  184. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  185. );
  186. void InitLoadLibraryA(
  187. LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags,
  188. HMODULE hModule, DWORD dwGetLastError,
  189. PVOID lpCallersAddress, PVOID lpCallersCallerAddress
  190. );
  191. bool operator==( const LoadLibraryCallRecord_t &rhs ) const
  192. {
  193. // compare the result ( hModule ) but not the callers
  194. // we arent going to have a perfect history of every caller
  195. if ( m_dwResult != rhs.m_dwResult )
  196. {
  197. return false;
  198. }
  199. // and then what we have of the actual filename
  200. return ( memcmp( m_rgubFileName, &rhs.m_rgubFileName, sizeof(m_rgubFileName) ) == 0 );
  201. }
  202. uint8 m_rgubFileName[128];
  203. HANDLE m_hFile;
  204. DWORD m_dwFlags;
  205. };
  206. // These GUIDs are constants, and it is how we find this structure when looking through the data section
  207. // when we are trying to read this data with an external process
  208. GUID m_guidMarkerBegin;
  209. // some helpers for parsing the structure externally
  210. int m_nVersionNumber;
  211. int m_cubRecordDetouredCalls;
  212. int m_cubGetAsyncKeyStateCallRecord;
  213. int m_cubVirtualAllocCallRecord;
  214. int m_cubVirtualProtectCallRecord;
  215. int m_cubLoadLibraryCallRecord;
  216. // these numbers were chosen by profiling CS:GO a bunch
  217. CCallRecordSet< GetAsyncKeyStateCallRecord_t, 50 > m_GetAsyncKeyStateCallRecord;
  218. CCallRecordSet< VirtualAllocCallRecord_t, 300 > m_VirtualAllocCallRecord;
  219. CCallRecordSet< VirtualAllocCallRecord_t, 500 > m_VirtualProtectCallRecord;
  220. CCallRecordSet< LoadLibraryCallRecord_t, 200 > m_LoadLibraryCallRecord;
  221. bool m_bMasterSwitch;
  222. // These GUIDs are constants, and it is how we find this structure when looking through the data section
  223. GUID m_guidMarkerEnd;
  224. };
  225. extern CRecordDetouredCalls g_RecordDetouredCalls;
  226. typedef PVOID (WINAPI *RtlGetCallersAddress_t)( PVOID *CallersAddress, PVOID *CallersCaller );
  227. extern RtlGetCallersAddress_t g_pRtlGetCallersAddress;
  228. #endif // _WIN32
  229. #endif // DETOURFUNC_H