//////////////////////////////////////////////////////////////////////////////////// // // File: shim2.c // // History: May-99 clupu Created. // Aug-99 v-johnwh Various bug fixes. // 23-Nov-99 markder Support for multiple shim DLLs, chaining // of hooks, DLL loads/unloads. General clean-up. // Jan-00 markder Windows 9x support added. // Mar-00 a-batjar Changed to support whistler format on w2k // May-00 v-johnwh Modified to work in the profiler // Desc: Contains all code to facilitate hooking of APIs by replacing entries // in the import tables of loaded modules. // //////////////////////////////////////////////////////////////////////////////////// #include #include #include #include #include #include #include "shimdb.h" #include "shim2.h" #define HAF_RESOLVED 0x0001 #define HAF_BOTTOM_OF_CHAIN 0x0002 typedef PHOOKAPI (*PFNNEWGETHOOKAPIS)(DWORD dwGetProcAddress, DWORD dwLoadLibraryA, DWORD dwFreeLibrary, DWORD* pdwHookAPICount); typedef LPSTR (*PFNGETCOMMANDLINEA)(VOID); typedef LPWSTR (*PFNGETCOMMANDLINEW)(VOID); typedef PVOID (*PFNGETPROCADDRESS)(HMODULE hMod, char* pszProc); typedef HINSTANCE (*PFNLOADLIBRARYA)(LPCSTR lpLibFileName); typedef HINSTANCE (*PFNLOADLIBRARYW)(LPCWSTR lpLibFileName); typedef HINSTANCE (*PFNLOADLIBRARYEXA)(LPCSTR lpLibFileName, HANDLE hFile, DWORD dwFlags); typedef HINSTANCE (*PFNLOADLIBRARYEXW)(LPCWSTR lpLibFileName, HANDLE hFile, DWORD dwFlags); typedef BOOL (*PFNFREELIBRARY)(HMODULE hLibModule); // Global Variables // Disable build warnings due to Print macro in free builds #pragma warning( disable : 4002 ) #define MAX_MODULES 512 #define SHIM_GETHOOKAPIS "GetHookAPIs" //////////////////////////////////////////////////////////////////////////////////// // API hook count & indices //////////////////////////////////////////////////////////////////////////////////// enum { hookGetProcAddress, hookLoadLibraryA, hookLoadLibraryW, hookLoadLibraryExA, hookLoadLibraryExW, hookFreeLibrary, hookGetCommandLineA, hookGetCommandLineW }; //////////////////////////////////////////////////////////////////////////////////// // Global variables //////////////////////////////////////////////////////////////////////////////////// // This array contains information used by the shim mechanism to describe // what API to hook with a particular stub function. LONG g_nShimDllCount; HMODULE g_hShimDlls[MAX_MODULES]; PHOOKAPI g_rgpHookAPIs[MAX_MODULES]; LONG g_rgnHookAPICount[MAX_MODULES]; LPTSTR g_rgnHookDllList[MAX_MODULES]; HMODULE g_hHookedModules[MAX_MODULES]; LONG g_nHookedModuleCount; extern BOOL g_bIsWin9X; HANDLE g_hSnapshot = NULL; HANDLE g_hValidationSnapshot = NULL; //////////////////////////////////////////////////////////////////////////////////// // // Func: ValidateAddress // // Params: pfnOld Original API function pointer to validate. // // Return: Potentially massaged pfnOld. // // Desc: Win9x thunks system API entry points for some reason. The // shim mechanism has to work around this to get to the // 'real' pointer so that it can make valid comparisons. // PVOID ValidateAddress( PVOID pfnOld ) { MODULEENTRY32 ModuleEntry32; BOOL bRet; long i, j; // Make sure the address isn't a shim thunk for( i = g_nShimDllCount - 1; i >= 0; i-- ) { for( j = 0; j < g_rgnHookAPICount[i]; j++ ) { if( g_rgpHookAPIs[i][j].pfnOld == pfnOld ) { if( pfnOld == g_rgpHookAPIs[i][j].pfnNew ) return pfnOld; } } } ModuleEntry32.dwSize = sizeof( ModuleEntry32 ); bRet = Module32First( g_hValidationSnapshot, &ModuleEntry32 ); while( bRet ) { if( pfnOld >= (PVOID) ModuleEntry32.modBaseAddr && pfnOld <= (PVOID) ( ModuleEntry32.modBaseAddr + ModuleEntry32.modBaseSize ) ) { return pfnOld; } bRet = Module32Next( g_hValidationSnapshot, &ModuleEntry32 ); } // Hack for Win9x return *(PVOID *)( ((PBYTE)pfnOld)+1); } //////////////////////////////////////////////////////////////////////////////////// // // Func: ConstructChain // // Params: pfnOld Original API function pointer to resolve. // // Return: Top-of-chain PHOOKAPI structure. // // Desc: Scans HookAPI arrays for pfnOld and either constructs the // chain or returns the top-of-chain PHOOKAPI if the chain // already exists. // PHOOKAPI ConstructChain( PVOID pfnOld ,DWORD* DllListIndex) { LONG i, j; PHOOKAPI pTopHookAPI; PHOOKAPI pBottomHookAPI; pTopHookAPI = NULL; pBottomHookAPI = NULL; *DllListIndex=0; // Scan all HOOKAPI entries for corresponding function pointer for( i = g_nShimDllCount - 1; i >= 0; i-- ) { for( j = 0; j < g_rgnHookAPICount[i]; j++ ) { if( g_rgpHookAPIs[i][j].pfnOld == pfnOld ) { if( pTopHookAPI ) { // Already hooked! Chain them together. pBottomHookAPI->pfnOld = g_rgpHookAPIs[i][j].pfnNew; pBottomHookAPI = &( g_rgpHookAPIs[i][j] ); pBottomHookAPI->pNextHook = pTopHookAPI; pBottomHookAPI->dwFlags = HAF_RESOLVED; } else { if( g_rgpHookAPIs[i][j].pNextHook ) { // Chaining has already been constructed. pTopHookAPI = (PHOOKAPI) g_rgpHookAPIs[i][j].pNextHook; *DllListIndex=i; return pTopHookAPI; } // Not hooked yet. Set to top of chain. pTopHookAPI = &( g_rgpHookAPIs[i][j] ); pTopHookAPI->pNextHook = pTopHookAPI; pTopHookAPI->dwFlags = HAF_RESOLVED; pBottomHookAPI = pTopHookAPI; } break; } } } if( pBottomHookAPI ) { pBottomHookAPI->dwFlags = HAF_BOTTOM_OF_CHAIN; } *DllListIndex=i; return pTopHookAPI; } // ConstructChain //////////////////////////////////////////////////////////////////////////////////// // // Func: HookImports // // Params: dwBaseAddress Base address of module image to hook. // // szModName Name of module (for debug purposes only). // // Desc: This function is the workhorse of the shim: It scans the import // table of a module (specified by dwBaseAddress) looking for // function pointers that require hooking (according to HOOKAPI // entries in g_rgpHookAPIs). It then overwrites hooked function // pointers with the first stub function in the chain. // VOID HookImports( DWORD dwBaseAddress, LPTSTR szModName ) { BOOL bAnyHooked = FALSE; PIMAGE_DOS_HEADER pIDH = (PIMAGE_DOS_HEADER) dwBaseAddress; PIMAGE_NT_HEADERS pINTH; PIMAGE_IMPORT_DESCRIPTOR pIID; PIMAGE_NT_HEADERS NtHeaders; DWORD dwTemp; DWORD dwImportTableOffset; PHOOKAPI pTopHookAPI; DWORD dwOldProtect; LONG i, j; PVOID pfnOld; // Get the import table pINTH = (PIMAGE_NT_HEADERS)(dwBaseAddress + pIDH->e_lfanew); dwImportTableOffset = pINTH->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT].VirtualAddress; if( dwImportTableOffset == 0 ) return; pIID = (PIMAGE_IMPORT_DESCRIPTOR)(dwBaseAddress + dwImportTableOffset); // Loop through the import table and search for the APIs that we want to patch while( TRUE ) { LPTSTR pszModule; PIMAGE_THUNK_DATA pITDA; // Return if no first thunk if (pIID->FirstThunk == 0) // (terminating condition) break; pszModule = (LPTSTR) ( dwBaseAddress + pIID->Name ); // If we're not interested in this module jump to the next. bAnyHooked = FALSE; for( i = 0; i < g_nShimDllCount; i++ ) { for( j = 0; j < g_rgnHookAPICount[i]; j++ ) { if( lstrcmpi( g_rgpHookAPIs[i][j].pszModule, pszModule ) == 0 ) { bAnyHooked = TRUE; goto ScanDone; } } } ScanDone: if( !bAnyHooked ) { pIID++; continue; } // We have APIs to hook for this module! pITDA = (PIMAGE_THUNK_DATA)( dwBaseAddress + (DWORD)pIID->FirstThunk ); while( TRUE ) { DWORD DllListIndex = 0; pfnOld = (PVOID) pITDA->u1.Function; // Done with all the imports from this module? if( pITDA->u1.Ordinal == 0 ) // (terminating condition) break; if( g_bIsWin9X ) pfnOld = ValidateAddress( pfnOld ); pTopHookAPI = ConstructChain( (PVOID) pfnOld,&DllListIndex ); if( ! pTopHookAPI ) { pITDA++; continue; } /* * Check if we want to patch this API for this particular loaded module */ if (NULL != g_rgnHookDllList[DllListIndex]) { LPTSTR pszMod = g_rgnHookDllList[DllListIndex]; BOOL b = FALSE; //gets set to true if the list is an exclude list while (*pszMod != 0) { if (lstrcmpi(pszMod, szModName) == 0) break; if(lstrcmpi(pszMod,TEXT("%")) == 0) b=TRUE; if(b && lstrcmpi(pszMod,TEXT("*")) == 0) { //this means it is exclude all and we already checked include list //skip this api break; } pszMod = pszMod + lstrlen(pszMod) + 1; } if(b && *pszMod != 0) { pITDA++; continue; } if (!b && *pszMod == 0) { pITDA++; continue; } } // Make the code page writable and overwrite new function pointer in import table if ( VirtualProtect( &pITDA->u1.Function, sizeof(DWORD), PAGE_READWRITE, &dwOldProtect) ) { pITDA->u1.Function = (ULONG) pTopHookAPI->pfnNew; VirtualProtect( &pITDA->u1.Function, sizeof(DWORD), dwOldProtect, &dwTemp ); } pITDA++; } pIID++; } } // HookImports //////////////////////////////////////////////////////////////////////////////////// // // Func: ResolveAPIs // // Desc: Each time a module is loaded, the pfnOld members of each HOOKAPI // structure in g_rgpHookAPIs are resolved (by calling GetProcAddress). // VOID ResolveAPIs() { LONG i, j; PVOID pfnOld = NULL; PIMAGE_NT_HEADERS NtHeaders; for (i = 0; i < g_nShimDllCount; i++) { for (j = 0; j < g_rgnHookAPICount[i]; j++ ) { HMODULE hMod; // We only care about HOOKAPIs at the bottom of a chain. if( ( g_rgpHookAPIs[i][j].dwFlags & HAF_RESOLVED ) && ! ( g_rgpHookAPIs[i][j].dwFlags & HAF_BOTTOM_OF_CHAIN ) ) continue; if( ( hMod = GetModuleHandle(g_rgpHookAPIs[i][j].pszModule) ) != NULL) { pfnOld = GetProcAddress( hMod, g_rgpHookAPIs[i][j].pszFunctionName ); if( pfnOld == NULL ) { // This is an ERROR. The hook DLL asked to patch a function // that doesn't exist !!! } else { if( g_bIsWin9X ) pfnOld = ValidateAddress( pfnOld ); g_rgpHookAPIs[i][j].pfnOld = pfnOld; } } } } } // ResolveAPIs //////////////////////////////////////////////////////////////////////////////////// // // Func: PatchNewModules // // Desc: This function is called at initialization and then each time a module // is loaded. It enumerates all loaded processes and calls HookImports // to overwrite appropriate function pointers. // void __stdcall Shim2PatchNewModules( VOID ) { DWORD i; LONG j; BOOL bRet; HMODULE hMod; MODULEENTRY32 ModuleEntry32; // Enumerate all the loaded modules and hook their import tables g_hSnapshot = CreateToolhelp32Snapshot( TH32CS_SNAPMODULE, 0 ); g_hValidationSnapshot = CreateToolhelp32Snapshot( TH32CS_SNAPMODULE, 0 ); if( g_hSnapshot == NULL ) { return; } // Resolve old APIs for loaded modules ResolveAPIs(); ModuleEntry32.dwSize = sizeof( ModuleEntry32 ); bRet = Module32First( g_hSnapshot, &ModuleEntry32 ); while( bRet ) { hMod = ModuleEntry32.hModule; if( hMod >= (HMODULE) 0x80000000 ) { bRet = Module32Next( g_hSnapshot, &ModuleEntry32 ); continue; } // we need to make sure we are not trying to shim ourselves for (j = 0; j < g_nShimDllCount; j++ ) { if( hMod == g_hShimDlls[j] ) { hMod = NULL; break; } } for (j = 0; j < g_nHookedModuleCount; j++ ) { if( hMod == g_hHookedModules[ j ] ) { hMod = NULL; break; } } if( hMod ) { HookImports( (DWORD) hMod, ModuleEntry32.szModule ); g_hHookedModules[ g_nHookedModuleCount++ ] = hMod; } bRet = Module32Next( g_hSnapshot, &ModuleEntry32 ); } if( g_hSnapshot ) { CloseHandle( g_hSnapshot ); g_hSnapshot = NULL; } if( g_hValidationSnapshot ) { CloseHandle( g_hValidationSnapshot ); g_hValidationSnapshot = NULL; } } //PatchNewModules //////////////////////////////////////////////////////////////////////////////////// // // Func: AddHookAPIs // // Params: hShimDll Handle of new shim DLL. // // pHookAPIs Pointer to new HOOKAPI array. // // dwCount Number of entries in pHookAPIs. // // Desc: Stores away the pointer returned by a shim DLL's GetHookAPIs // function in our global arrays. // void AddHookAPIs( HMODULE hShimDll, PHOOKAPI pHookAPIs, DWORD dwCount,LPTSTR szIncExclDllList) { DWORD i; for( i = 0; i < dwCount; i++ ) { pHookAPIs[i].dwFlags = 0; pHookAPIs[i].pNextHook = NULL; } g_rgpHookAPIs[ g_nShimDllCount ] = pHookAPIs; g_rgnHookAPICount[ g_nShimDllCount ] = dwCount; g_hShimDlls[ g_nShimDllCount ] = hShimDll; g_rgnHookDllList[g_nShimDllCount ] = szIncExclDllList; g_nShimDllCount++; } // AddHookAPIs //////////////////////////////////////////////////////////////////////////////////// // // Func: _LoadPatchDll // // Params: pwszPatchDll Name of shim DLL to be loaded. // // Return: TRUE if successful, FALSE if not. // // Desc: Loads a shim DLL and retrieves the hooking information via GetHookAPIs. // BOOL _LoadPatchDll( LPWSTR szPatchDll,LPSTR szCmdLine,LPSTR szIncExclDllList) { PHOOKAPI pHookAPIs = NULL; DWORD dwHookAPICount = 0; HMODULE hModHookDll; PFNGETHOOKAPIS pfnGetHookAPIs; hModHookDll = LoadLibraryW(szPatchDll); if (hModHookDll == NULL) { return FALSE; } pfnGetHookAPIs = (PFNGETHOOKAPIS) GetProcAddress( hModHookDll, SHIM_GETHOOKAPIS ); if( pfnGetHookAPIs == NULL ) { FreeLibrary( hModHookDll ); return FALSE; } pHookAPIs = (*pfnGetHookAPIs)(szCmdLine, Shim2PatchNewModules, &dwHookAPICount ); if( dwHookAPICount == 0 || pHookAPIs == NULL ) { FreeLibrary( hModHookDll ); return FALSE; } AddHookAPIs( hModHookDll, pHookAPIs, dwHookAPICount,szIncExclDllList); return TRUE; } // _LoadPatchDll // Re-enable build warnings. #pragma warning( default : 4002 )