Source code of Windows XP (NT5)
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

577 lines
17 KiB

  1. ////////////////////////////////////////////////////////////////////////////////////
  2. //
  3. // File: shim2.c
  4. //
  5. // History: May-99 clupu Created.
  6. // Aug-99 v-johnwh Various bug fixes.
  7. // 23-Nov-99 markder Support for multiple shim DLLs, chaining
  8. // of hooks, DLL loads/unloads. General clean-up.
  9. // Jan-00 markder Windows 9x support added.
  10. // Mar-00 a-batjar Changed to support whistler format on w2k
  11. // May-00 v-johnwh Modified to work in the profiler
  12. // Desc: Contains all code to facilitate hooking of APIs by replacing entries
  13. // in the import tables of loaded modules.
  14. //
  15. ////////////////////////////////////////////////////////////////////////////////////
  16. #include <windows.h>
  17. #include <stdlib.h>
  18. #include <psapi.h>
  19. #include <tlhelp32.h>
  20. #include <imagehlp.h>
  21. #include <stdio.h>
  22. #include "shimdb.h"
  23. #include "shim2.h"
  // HOOKAPI.dwFlags values used while building hook chains.
  // HAF_RESOLVED is assigned by ConstructChain to every entry it links into a
  // chain; ResolveAPIs skips entries that carry it without HAF_BOTTOM_OF_CHAIN.
  24. #define HAF_RESOLVED 0x0001
  // HAF_BOTTOM_OF_CHAIN is assigned (replacing the other flag) to the last entry
  // of a chain, so ResolveAPIs keeps re-resolving that entry's pfnOld.
  25. #define HAF_BOTTOM_OF_CHAIN 0x0002
  // Function-pointer types for a shim entry point and for the Win32 APIs named in
  // the hook index enum. None of them is referenced in the visible part of this
  // file -- presumably used elsewhere or left over; confirm before removing.
  26. typedef PHOOKAPI (*PFNNEWGETHOOKAPIS)(DWORD dwGetProcAddress, DWORD dwLoadLibraryA, DWORD dwFreeLibrary, DWORD* pdwHookAPICount);
  27. typedef LPSTR (*PFNGETCOMMANDLINEA)(VOID);
  28. typedef LPWSTR (*PFNGETCOMMANDLINEW)(VOID);
  29. typedef PVOID (*PFNGETPROCADDRESS)(HMODULE hMod, char* pszProc);
  30. typedef HINSTANCE (*PFNLOADLIBRARYA)(LPCSTR lpLibFileName);
  31. typedef HINSTANCE (*PFNLOADLIBRARYW)(LPCWSTR lpLibFileName);
  32. typedef HINSTANCE (*PFNLOADLIBRARYEXA)(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);
  33. typedef HINSTANCE (*PFNLOADLIBRARYEXW)(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags);
  34. typedef BOOL (*PFNFREELIBRARY)(HMODULE hLibModule);
  35. // Global Variables
  36. // Disable build warnings due to Print macro in free builds
  37. #pragma warning( disable : 4002 )
  // Capacity of every per-shim-DLL and per-hooked-module table declared below.
  38. #define MAX_MODULES 512
  // Name of the export that _LoadPatchDll looks up in each shim DLL.
  39. #define SHIM_GETHOOKAPIS "GetHookAPIs"
  40. ////////////////////////////////////////////////////////////////////////////////////
  41. // API hook count & indices
  42. ////////////////////////////////////////////////////////////////////////////////////
  // Symbolic indices, numbered 0..7 in declaration order, one per Win32 API
  // named below. They are not referenced in the visible part of this file.
  // NOTE(review): presumably they index a built-in HOOKAPI table -- confirm.
  43. enum
  44. {
  45. hookGetProcAddress,
  46. hookLoadLibraryA,
  47. hookLoadLibraryW,
  48. hookLoadLibraryExA,
  49. hookLoadLibraryExW,
  50. hookFreeLibrary,
  51. hookGetCommandLineA,
  52. hookGetCommandLineW
  53. };
  54. ////////////////////////////////////////////////////////////////////////////////////
  55. // Global variables
  56. ////////////////////////////////////////////////////////////////////////////////////
  57. // Parallel tables describing the registered shim DLLs. AddHookAPIs fills slot
  58. // g_nShimDllCount of each table and then increments the count.
  // Number of shim DLLs registered so far (valid slots in the four tables below).
  59. LONG g_nShimDllCount;
  // Module handle of each shim DLL; Shim2PatchNewModules uses it to avoid
  // patching the shim DLLs themselves.
  60. HMODULE g_hShimDlls[MAX_MODULES];
  // HOOKAPI array supplied by each shim DLL's GetHookAPIs export.
  61. PHOOKAPI g_rgpHookAPIs[MAX_MODULES];
  // Number of HOOKAPI entries in the corresponding g_rgpHookAPIs array.
  62. LONG g_rgnHookAPICount[MAX_MODULES];
  // Optional include/exclude module list per shim DLL (NULL = no filtering).
  // HookImports walks it as consecutive NUL-terminated strings ended by an empty
  // string; a "%" entry starts the exclude section and "*" inside that section
  // matches every module.
  63. LPTSTR g_rgnHookDllList[MAX_MODULES];
  // Modules whose import tables Shim2PatchNewModules has already passed to
  // HookImports, and how many of them there are.
  64. HMODULE g_hHookedModules[MAX_MODULES];
  65. LONG g_nHookedModuleCount;
  // Defined elsewhere. When non-zero, function addresses are passed through
  // ValidateAddress before being compared or stored.
  66. extern BOOL g_bIsWin9X;
  // Toolhelp module snapshots owned by Shim2PatchNewModules for the duration of
  // one call: g_hSnapshot drives the module enumeration, g_hValidationSnapshot
  // is the one walked by ValidateAddress.
  67. HANDLE g_hSnapshot = NULL;
  68. HANDLE g_hValidationSnapshot = NULL;
  69. ////////////////////////////////////////////////////////////////////////////////////
  70. //
  71. // Func: ValidateAddress
  72. //
  73. // Params: pfnOld Original API function pointer to validate.
  74. //
  75. // Return: Potentially massaged pfnOld.
  76. //
  77. // Desc: Win9x thunks system API entry points for some reason. The
  78. // shim mechanism has to work around this to get to the
  79. // 'real' pointer so that it can make valid comparisons.
  80. //
  81. PVOID ValidateAddress( PVOID pfnOld )
  82. {
  83. MODULEENTRY32 ModuleEntry32;
  84. BOOL bRet;
  85. long i, j;
  86. // Make sure the address isn't a shim thunk
  87. for( i = g_nShimDllCount - 1; i >= 0; i-- )
  88. {
  89. for( j = 0; j < g_rgnHookAPICount[i]; j++ )
  90. {
  91. if( g_rgpHookAPIs[i][j].pfnOld == pfnOld )
  92. {
  93. if( pfnOld == g_rgpHookAPIs[i][j].pfnNew )
  94. return pfnOld;
  95. }
  96. }
  97. }
  98. ModuleEntry32.dwSize = sizeof( ModuleEntry32 );
  99. bRet = Module32First( g_hValidationSnapshot, &ModuleEntry32 );
  100. while( bRet )
  101. {
  102. if( pfnOld >= (PVOID) ModuleEntry32.modBaseAddr &&
  103. pfnOld <= (PVOID) ( ModuleEntry32.modBaseAddr + ModuleEntry32.modBaseSize ) )
  104. {
  105. return pfnOld;
  106. }
  107. bRet = Module32Next( g_hValidationSnapshot, &ModuleEntry32 );
  108. }
  109. // Hack for Win9x
  110. return *(PVOID *)( ((PBYTE)pfnOld)+1);
  111. }
  112. ////////////////////////////////////////////////////////////////////////////////////
  113. //
  114. // Func: ConstructChain
  115. //
  116. // Params: pfnOld Original API function pointer to resolve.
  117. //
  118. // Return: Top-of-chain PHOOKAPI structure.
  119. //
  120. // Desc: Scans HookAPI arrays for pfnOld and either constructs the
  121. // chain or returns the top-of-chain PHOOKAPI if the chain
  122. // already exists.
  123. //
  124. PHOOKAPI ConstructChain( PVOID pfnOld ,DWORD* DllListIndex)
  125. {
  126. LONG i, j;
  127. PHOOKAPI pTopHookAPI;
  128. PHOOKAPI pBottomHookAPI;
  129. pTopHookAPI = NULL;
  130. pBottomHookAPI = NULL;
  131. *DllListIndex=0;
  132. // Scan all HOOKAPI entries for corresponding function pointer
  133. for( i = g_nShimDllCount - 1; i >= 0; i-- )
  134. {
  135. for( j = 0; j < g_rgnHookAPICount[i]; j++ )
  136. {
  137. if( g_rgpHookAPIs[i][j].pfnOld == pfnOld )
  138. {
  139. if( pTopHookAPI )
  140. {
  141. // Already hooked! Chain them together.
  142. pBottomHookAPI->pfnOld = g_rgpHookAPIs[i][j].pfnNew;
  143. pBottomHookAPI = &( g_rgpHookAPIs[i][j] );
  144. pBottomHookAPI->pNextHook = pTopHookAPI;
  145. pBottomHookAPI->dwFlags = HAF_RESOLVED;
  146. }
  147. else
  148. {
  149. if( g_rgpHookAPIs[i][j].pNextHook )
  150. {
  151. // Chaining has already been constructed.
  152. pTopHookAPI = (PHOOKAPI) g_rgpHookAPIs[i][j].pNextHook;
  153. *DllListIndex=i;
  154. return pTopHookAPI;
  155. }
  156. // Not hooked yet. Set to top of chain.
  157. pTopHookAPI = &( g_rgpHookAPIs[i][j] );
  158. pTopHookAPI->pNextHook = pTopHookAPI;
  159. pTopHookAPI->dwFlags = HAF_RESOLVED;
  160. pBottomHookAPI = pTopHookAPI;
  161. }
  162. break;
  163. }
  164. }
  165. }
  166. if( pBottomHookAPI )
  167. {
  168. pBottomHookAPI->dwFlags = HAF_BOTTOM_OF_CHAIN;
  169. }
  170. *DllListIndex=i;
  171. return pTopHookAPI;
  172. } // ConstructChain
  173. ////////////////////////////////////////////////////////////////////////////////////
  174. //
  175. // Func: HookImports
  176. //
  177. // Params: dwBaseAddress Base address of module image to hook.
  178. //
  179. // szModName Name of module (for debug purposes only).
  180. //
  181. // Desc: This function is the workhorse of the shim: It scans the import
  182. // table of a module (specified by dwBaseAddress) looking for
  183. // function pointers that require hooking (according to HOOKAPI
  184. // entries in g_rgpHookAPIs). It then overwrites hooked function
  185. // pointers with the first stub function in the chain.
  186. //
  187. VOID HookImports(
  188. DWORD dwBaseAddress,
  189. LPTSTR szModName )
  190. {
  191. BOOL bAnyHooked = FALSE;
  192. PIMAGE_DOS_HEADER pIDH = (PIMAGE_DOS_HEADER) dwBaseAddress;
  193. PIMAGE_NT_HEADERS pINTH;
  194. PIMAGE_IMPORT_DESCRIPTOR pIID;
  195. PIMAGE_NT_HEADERS NtHeaders;
  196. DWORD dwTemp;
  197. DWORD dwImportTableOffset;
  198. PHOOKAPI pTopHookAPI;
  199. DWORD dwOldProtect;
  200. LONG i, j;
  201. PVOID pfnOld;
  202. // Get the import table
  203. pINTH = (PIMAGE_NT_HEADERS)(dwBaseAddress + pIDH->e_lfanew);
  204. dwImportTableOffset = pINTH->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress;
  205. if( dwImportTableOffset == 0 )
  206. return;
  207. pIID = (PIMAGE_IMPORT_DESCRIPTOR)(dwBaseAddress + dwImportTableOffset);
  208. // Loop through the import table and search for the APIs that we want to patch
  209. while( TRUE )
  210. {
  211. LPTSTR pszModule;
  212. PIMAGE_THUNK_DATA pITDA;
  213. // Return if no first thunk
  214. if (pIID->FirstThunk == 0) // (terminating condition)
  215. break;
  216. pszModule = (LPTSTR) ( dwBaseAddress + pIID->Name );
  217. // If we're not interested in this module jump to the next.
  218. bAnyHooked = FALSE;
  219. for( i = 0; i < g_nShimDllCount; i++ )
  220. {
  221. for( j = 0; j < g_rgnHookAPICount[i]; j++ )
  222. {
  223. if( lstrcmpi( g_rgpHookAPIs[i][j].pszModule, pszModule ) == 0 )
  224. {
  225. bAnyHooked = TRUE;
  226. goto ScanDone;
  227. }
  228. }
  229. }
  230. ScanDone:
  231. if( !bAnyHooked )
  232. {
  233. pIID++;
  234. continue;
  235. }
  236. // We have APIs to hook for this module!
  237. pITDA = (PIMAGE_THUNK_DATA)( dwBaseAddress + (DWORD)pIID->FirstThunk );
  238. while( TRUE )
  239. {
  240. DWORD DllListIndex = 0;
  241. pfnOld = (PVOID) pITDA->u1.Function;
  242. // Done with all the imports from this module?
  243. if( pITDA->u1.Ordinal == 0 ) // (terminating condition)
  244. break;
  245. if( g_bIsWin9X )
  246. pfnOld = ValidateAddress( pfnOld );
  247. pTopHookAPI = ConstructChain( (PVOID) pfnOld,&DllListIndex );
  248. if( ! pTopHookAPI )
  249. {
  250. pITDA++;
  251. continue;
  252. }
  253. /*
  254. * Check if we want to patch this API for this particular loaded module
  255. */
  256. if (NULL != g_rgnHookDllList[DllListIndex])
  257. {
  258. LPTSTR pszMod = g_rgnHookDllList[DllListIndex];
  259. BOOL b = FALSE; //gets set to true if the list is an exclude list
  260. while (*pszMod != 0) {
  261. if (lstrcmpi(pszMod, szModName) == 0)
  262. break;
  263. if(lstrcmpi(pszMod,TEXT("%")) == 0)
  264. b=TRUE;
  265. if(b && lstrcmpi(pszMod,TEXT("*")) == 0)
  266. {
  267. //this means it is exclude all and we already checked include list
  268. //skip this api
  269. break;
  270. }
  271. pszMod = pszMod + lstrlen(pszMod) + 1;
  272. }
  273. if(b && *pszMod != 0)
  274. {
  275. pITDA++;
  276. continue;
  277. }
  278. if (!b && *pszMod == 0)
  279. {
  280. pITDA++;
  281. continue;
  282. }
  283. }
  284. // Make the code page writable and overwrite new function pointer in import table
  285. if ( VirtualProtect( &pITDA->u1.Function,
  286. sizeof(DWORD),
  287. PAGE_READWRITE,
  288. &dwOldProtect) )
  289. {
  290. pITDA->u1.Function = (ULONG) pTopHookAPI->pfnNew;
  291. VirtualProtect( &pITDA->u1.Function,
  292. sizeof(DWORD),
  293. dwOldProtect,
  294. &dwTemp );
  295. }
  296. pITDA++;
  297. }
  298. pIID++;
  299. }
  300. } // HookImports
  301. ////////////////////////////////////////////////////////////////////////////////////
  302. //
  303. // Func: ResolveAPIs
  304. //
  305. // Desc: Each time a module is loaded, the pfnOld members of each HOOKAPI
  306. // structure in g_rgpHookAPIs are resolved (by calling GetProcAddress).
  307. //
  308. VOID ResolveAPIs()
  309. {
  310. LONG i, j;
  311. PVOID pfnOld = NULL;
  312. PIMAGE_NT_HEADERS NtHeaders;
  313. for (i = 0; i < g_nShimDllCount; i++)
  314. {
  315. for (j = 0; j < g_rgnHookAPICount[i]; j++ )
  316. {
  317. HMODULE hMod;
  318. // We only care about HOOKAPIs at the bottom of a chain.
  319. if( ( g_rgpHookAPIs[i][j].dwFlags & HAF_RESOLVED ) &&
  320. ! ( g_rgpHookAPIs[i][j].dwFlags & HAF_BOTTOM_OF_CHAIN ) )
  321. continue;
  322. if( ( hMod = GetModuleHandle(g_rgpHookAPIs[i][j].pszModule) ) != NULL)
  323. {
  324. pfnOld = GetProcAddress( hMod, g_rgpHookAPIs[i][j].pszFunctionName );
  325. if( pfnOld == NULL )
  326. {
  327. // This is an ERROR. The hook DLL asked to patch a function
  328. // that doesn't exist !!!
  329. }
  330. else
  331. {
  332. if( g_bIsWin9X )
  333. pfnOld = ValidateAddress( pfnOld );
  334. g_rgpHookAPIs[i][j].pfnOld = pfnOld;
  335. }
  336. }
  337. }
  338. }
  339. } // ResolveAPIs
  340. ////////////////////////////////////////////////////////////////////////////////////
  341. //
  342. // Func: PatchNewModules
  343. //
  344. // Desc: This function is called at initialization and then each time a module
  345. // is loaded. It enumerates all loaded processes and calls HookImports
  346. // to overwrite appropriate function pointers.
  347. //
  348. void __stdcall Shim2PatchNewModules( VOID )
  349. {
  350. DWORD i;
  351. LONG j;
  352. BOOL bRet;
  353. HMODULE hMod;
  354. MODULEENTRY32 ModuleEntry32;
  355. // Enumerate all the loaded modules and hook their import tables
  356. g_hSnapshot = CreateToolhelp32Snapshot( TH32CS_SNAPMODULE, 0 );
  357. g_hValidationSnapshot = CreateToolhelp32Snapshot( TH32CS_SNAPMODULE, 0 );
  358. if( g_hSnapshot == NULL )
  359. {
  360. return;
  361. }
  362. // Resolve old APIs for loaded modules
  363. ResolveAPIs();
  364. ModuleEntry32.dwSize = sizeof( ModuleEntry32 );
  365. bRet = Module32First( g_hSnapshot, &ModuleEntry32 );
  366. while( bRet )
  367. {
  368. hMod = ModuleEntry32.hModule;
  369. if( hMod >= (HMODULE) 0x80000000 )
  370. {
  371. bRet = Module32Next( g_hSnapshot, &ModuleEntry32 );
  372. continue;
  373. }
  374. // we need to make sure we are not trying to shim ourselves
  375. for (j = 0; j < g_nShimDllCount; j++ )
  376. {
  377. if( hMod == g_hShimDlls[j] )
  378. {
  379. hMod = NULL;
  380. break;
  381. }
  382. }
  383. for (j = 0; j < g_nHookedModuleCount; j++ )
  384. {
  385. if( hMod == g_hHookedModules[ j ] )
  386. {
  387. hMod = NULL;
  388. break;
  389. }
  390. }
  391. if( hMod )
  392. {
  393. HookImports( (DWORD) hMod, ModuleEntry32.szModule );
  394. g_hHookedModules[ g_nHookedModuleCount++ ] = hMod;
  395. }
  396. bRet = Module32Next( g_hSnapshot, &ModuleEntry32 );
  397. }
  398. if( g_hSnapshot )
  399. {
  400. CloseHandle( g_hSnapshot );
  401. g_hSnapshot = NULL;
  402. }
  403. if( g_hValidationSnapshot )
  404. {
  405. CloseHandle( g_hValidationSnapshot );
  406. g_hValidationSnapshot = NULL;
  407. }
  408. } //PatchNewModules
  409. ////////////////////////////////////////////////////////////////////////////////////
  410. //
  411. // Func: AddHookAPIs
  412. //
  413. // Params: hShimDll Handle of new shim DLL.
  414. //
  415. // pHookAPIs Pointer to new HOOKAPI array.
  416. //
  417. // dwCount Number of entries in pHookAPIs.
  418. //
  419. // Desc: Stores away the pointer returned by a shim DLL's GetHookAPIs
  420. // function in our global arrays.
  421. //
  422. void AddHookAPIs( HMODULE hShimDll, PHOOKAPI pHookAPIs, DWORD dwCount,LPTSTR szIncExclDllList)
  423. {
  424. DWORD i;
  425. for( i = 0; i < dwCount; i++ )
  426. {
  427. pHookAPIs[i].dwFlags = 0;
  428. pHookAPIs[i].pNextHook = NULL;
  429. }
  430. g_rgpHookAPIs[ g_nShimDllCount ] = pHookAPIs;
  431. g_rgnHookAPICount[ g_nShimDllCount ] = dwCount;
  432. g_hShimDlls[ g_nShimDllCount ] = hShimDll;
  433. g_rgnHookDllList[g_nShimDllCount ] = szIncExclDllList;
  434. g_nShimDllCount++;
  435. } // AddHookAPIs
  436. ////////////////////////////////////////////////////////////////////////////////////
  437. //
  438. // Func: _LoadPatchDll
  439. //
  440. // Params: pwszPatchDll Name of shim DLL to be loaded.
  441. //
  442. // Return: TRUE if successful, FALSE if not.
  443. //
  444. // Desc: Loads a shim DLL and retrieves the hooking information via GetHookAPIs.
  445. //
  446. BOOL _LoadPatchDll(
  447. LPWSTR szPatchDll,LPSTR szCmdLine,LPSTR szIncExclDllList)
  448. {
  449. PHOOKAPI pHookAPIs = NULL;
  450. DWORD dwHookAPICount = 0;
  451. HMODULE hModHookDll;
  452. PFNGETHOOKAPIS pfnGetHookAPIs;
  453. hModHookDll = LoadLibraryW(szPatchDll);
  454. if (hModHookDll == NULL)
  455. {
  456. return FALSE;
  457. }
  458. pfnGetHookAPIs = (PFNGETHOOKAPIS) GetProcAddress( hModHookDll, SHIM_GETHOOKAPIS );
  459. if( pfnGetHookAPIs == NULL )
  460. {
  461. FreeLibrary( hModHookDll );
  462. return FALSE;
  463. }
  464. pHookAPIs = (*pfnGetHookAPIs)(szCmdLine, Shim2PatchNewModules, &dwHookAPICount );
  465. if( dwHookAPICount == 0 || pHookAPIs == NULL )
  466. {
  467. FreeLibrary( hModHookDll );
  468. return FALSE;
  469. }
  470. AddHookAPIs( hModHookDll, pHookAPIs, dwHookAPICount,szIncExclDllList);
  471. return TRUE;
  472. } // _LoadPatchDll
  473. // Re-enable build warnings.
  474. #pragma warning( default : 4002 )