#include "precomp.h"

IMPLEMENT_SHIM_BEGIN(TSPerUserFiles)
#include "ShimHookMacro.h"

// NOTE(review): the operands of the next two #include directives are missing
// in this copy of the file (they look like <...> headers whose names were
// stripped). They are kept exactly as found - restore the header names from
// source control before building.
#include
#include
#include "TSPerUserFiles_utils.h"

// Handle to \Registry\Machine; opened by CPerUserPaths::Init() and closed by
// ~CPerUserPaths().
HKEY HKLM = NULL;

//Functions - helpers.
DWORD RegKeyOpen(IN HKEY hKeyParent, IN LPCWSTR szKeyName, IN REGSAM samDesired, OUT HKEY *phKey );
DWORD RegLoadDWORD(IN HKEY hKey, IN LPCWSTR szValueName, OUT DWORD *pdwValue);
DWORD RegGetKeyInfo(IN HKEY hKey, OUT LPDWORD pcValues, OUT LPDWORD pcbMaxValueNameLen);
DWORD RegKeyEnumValues(IN HKEY hKey, IN DWORD iValue, OUT LPWSTR *pwszValueName, OUT LPBYTE *ppbData);

///////////////////////////////////////////////////////////////////////////////
//struct PER_USER_PATH
///////////////////////////////////////////////////////////////////////////////

/******************************************************************************
Routine Description:
    Loads wszFile and wszPerUserDir from one registry value (value name and
    value data respectively), records their lengths, and tries to create the
    ANSI copies szFile and szPerUserDir.
    If no wildcard is used, also constructs wszPerUserFile from wszPerUserDir
    and the last component of wszFile (and szPerUserFile from szPerUserDir and
    szFile when the ANSI conversion succeeded).
Arguments:
    IN HKEY hKey - key to load values from.
    IN DWORD dwIndex - index of value
Return Value:
    ERROR_SUCCESS, the error returned by RegKeyEnumValues,
    ERROR_INVALID_PARAMETER if wszFile has no '\\' after its first character,
    or ERROR_NOT_ENOUGH_MEMORY.
Note:
    Name of a registry value is loaded into wszFile;
    Data is loaded into wszPerUserDir and is treated as a wide string.
******************************************************************************/
DWORD
PER_USER_PATH::Init(
        IN HKEY hKey,
        IN DWORD dwIndex)
{
    DWORD err = RegKeyEnumValues(hKey, dwIndex, &wszFile,
                                 (LPBYTE *)&wszPerUserDir );
    if(err != ERROR_SUCCESS)
    {
        return err;
    }

    cFileLen = wcslen(wszFile);
    cPerUserDirLen = wcslen(wszPerUserDir);

    //Check if there is a '*' instead of file name.
    //'*' stands for any file name, not including the extension.
    //Only the last path component is examined: the scan stops at the
    //last '\\'.
    long i;
    for(i=cFileLen-1; i>=0 && wszFile[i] !=L'\\'; i--)
    {
        if(wszFile[i] == L'*')
        {
            bWildCardUsed = TRUE;
            break;
        }
    }

    //there must be at least one '\\' in wszFile
    //and it cannot be a first character
    if(i<=0)
    {
        return ERROR_INVALID_PARAMETER;
    }

    //The result is deliberately not checked here: a failure is recorded in
    //bInitANSIFailed, which is tested below and in PathForFileA().
    InitANSI();

    if(!bWildCardUsed)
    {
        //now wszFile+i points to the last '\\' in wszFile
        //Let's construct wszPerUserFile from
        //wszPerUserDir and wszFile
        DWORD cPerUserFile = (wcslen(wszPerUserDir)+cFileLen-i)+1;
        wszPerUserFile=(LPWSTR) LocalAlloc(LPTR,cPerUserFile*sizeof(WCHAR));
        if(!wszPerUserFile)
        {
            return ERROR_NOT_ENOUGH_MEMORY;
        }
        wcscpy(wszPerUserFile,wszPerUserDir);
        wcscat(wszPerUserFile,wszFile+i);

        if(!bInitANSIFailed)
        {
            szPerUserFile=(LPSTR) LocalAlloc(LPTR,cPerUserFile);
            if(!szPerUserFile)
            {
                return ERROR_NOT_ENOUGH_MEMORY;
            }
            strcpy(szPerUserFile,szPerUserDir);
            strcat(szPerUserFile,szFile+i);
        }
    }

    return ERROR_SUCCESS;
}

/******************************************************************************
Routine Description:
    Creates the ANSI strings szFile and szPerUserDir from wszFile and
    wszPerUserDir. The ANSI buffers are sized as one byte per wide character
    (plus terminator), so the conversion fails - and ANSI lookups are
    disabled - if a string needs more bytes than that.
Arguments:
    NONE
Return Value:
    TRUE if success (or if the strings already exist);
    FALSE on failure, in which case bInitANSIFailed is set and later calls
    return FALSE immediately.
******************************************************************************/
BOOL
PER_USER_PATH::InitANSI()
{
    //we already tried and failed
    //don't try again
    if(bInitANSIFailed)
    {
        return FALSE;
    }

    if(!szFile)
    {
        //allocate memory for ANSI strings
        DWORD cFile = cFileLen+1;
        szFile = (LPSTR) LocalAlloc(LPTR,cFile);
        if(!szFile)
        {
            bInitANSIFailed = TRUE;
            return FALSE;
        }

        DWORD cPerUserDir = cPerUserDirLen+1;
        szPerUserDir = (LPSTR) LocalAlloc(LPTR,cPerUserDir);
        if(!szPerUserDir)
        {
            bInitANSIFailed = TRUE;
            return FALSE;
        }

        //convert UNICODE wszFile and wszPerUserDir to ANSI
        if(!WideCharToMultiByte(
                  CP_ACP,            // code page
                  0,                 // performance and mapping flags
                  wszFile,           // wide-character string
                  -1,                // number of chars in string
                  szFile,            // buffer for new string
                  cFile,             // size of buffer
                  NULL,              // default for unmappable chars
                  NULL               // set when default char used
                ))
        {
            bInitANSIFailed = TRUE;
            return FALSE;
        }

        if(!WideCharToMultiByte(
                  CP_ACP,            // code page
                  0,                 // performance and mapping flags
                  wszPerUserDir,     // wide-character string
                  -1,                // number of chars in string
                  szPerUserDir,      // buffer for new string
                  cPerUserDir,       // size of buffer
                  NULL,              // default for unmappable chars
                  NULL               // set when default char used
                ))
        {
            bInitANSIFailed = TRUE;
            return FALSE;
        }
    }

    return TRUE;
}

/******************************************************************************
Routine Description:
    Returns the per-user path for szInFile if szInFile matches szFile
    (case-insensitive). With a wildcard, the file name part of szInFile (up to
    the extension) may be anything and the result is rebuilt on every call
    from szPerUserDir and the last component of szInFile.
Arguments:
    IN LPCSTR szInFile - original path
    IN DWORD cInLen - length of szInFile in characters
Return Value:
    szPerUserFile, or NULL if the path does not match or on failure.
Note:
    In the wildcard case the previously returned buffer is freed and replaced,
    so a pointer returned earlier must not be used after the next call.
******************************************************************************/
LPCSTR
PER_USER_PATH::PathForFileA(
        IN LPCSTR szInFile,
        IN DWORD cInLen)
{
    if(bInitANSIFailed)
    {
        return NULL;
    }

    if(!szInFile)
    {
        return NULL;
    }

    long j, i, k;

    if(bWildCardUsed)
    {
        //the input can never be shorter than the pattern;
        //this also guarantees j>=i in the loop below
        if(cInLen < cFileLen)
        {
            return NULL;
        }

        //if * is used, we need a special algorithm

        //the end of the path will more likely be different.
        //so start comparing from the end.
        for(j=cInLen-1, i=cFileLen-1; i>=0 && j>=0; i--, j--)
        {
            if(szFile[i] == '*')
            {
                i--;
                if(i<0 || szFile[i]!='\\')
                {
                    //the string in the registry was garbage
                    //the symbol previous to '*' must be '\\'
                    return NULL;
                }
                //skip all symbols in szInFile till next '\\'
                while(j>=0 && szInFile[j]!='\\') j--;

                //At this point both strings must be of exactly the same size.
                //If not - they are not equal
                if(j!=i)
                {
                    return NULL;
                }
                //no need to compare,
                //we already know that current symbols are equal
                break;
            }

            //tolower() is only defined for values representable as
            //unsigned char, so do not pass a (possibly negative) char.
            if(tolower((unsigned char)szFile[i])!=tolower((unsigned char)szInFile[j]))
            {
                return NULL;
            }
        }

        //the scan ran off the start of a string without reaching '*':
        //nothing to redirect
        if(i<0 || j<0)
        {
            return NULL;
        }

        //i and j are equal now.
        //no more * allowed
        //j we will remember as a position of '\\'
        //(position j itself is already known to match)
        for(k=j-1; k>=0; k--)
        {
            if(tolower((unsigned char)szFile[k])!=tolower((unsigned char)szInFile[k]))
            {
                return NULL;
            }
        }

        //Okay, the strings are equal
        //Now construct the output string
        if(szPerUserFile)
        {
            LocalFree(szPerUserFile);
            szPerUserFile=NULL;
        }

        DWORD cPerUserFile = cPerUserDirLen + cInLen - j + 1;
        szPerUserFile = (LPSTR) LocalAlloc(LPTR,cPerUserFile);
        if(!szPerUserFile)
        {
            return NULL;
        }

        //Append exactly the cInLen-j characters the buffer was sized for,
        //even if szInFile continues past cInLen.
        strcpy(szPerUserFile,szPerUserDir);
        strncat(szPerUserFile,szInFile+j,cInLen-j);
    }
    else
    {
        //first find if the input string has right size
        if(cInLen != cFileLen)
        {
            return NULL;
        }

        //the end of the path will more likely be different.
        //so start comparing from the end.
        for(i=cFileLen-1; i>=0; i--)
        {
            if(tolower((unsigned char)szFile[i])!=tolower((unsigned char)szInFile[i]))
            {
                return NULL;
            }
        }
    }

    return szPerUserFile;
}

/******************************************************************************
Routine Description:
    Returns the per-user path for wszInFile if wszInFile matches wszFile
    (case-insensitive). With a wildcard, the file name part of wszInFile (up
    to the extension) may be anything and the result is rebuilt on every call
    from wszPerUserDir and the last component of wszInFile.
Arguments:
    IN LPCWSTR wszInFile - original path
    IN DWORD cInLen - length of wszInFile in characters
Return Value:
    wszPerUserFile, or NULL if the path does not match or on failure.
Note:
    In the wildcard case the previously returned buffer is freed and replaced,
    so a pointer returned earlier must not be used after the next call.
******************************************************************************/
LPCWSTR
PER_USER_PATH::PathForFileW(
        IN LPCWSTR wszInFile,
        IN DWORD cInLen)
{
    if(!wszInFile)
    {
        return NULL;
    }

    long j, i, k;

    if(bWildCardUsed)
    {
        //the input can never be shorter than the pattern;
        //this also guarantees j>=i in the loop below
        if(cInLen < cFileLen)
        {
            return NULL;
        }

        //if * is used, we need a special algorithm

        //the end of the path will more likely be different.
        //so start comparing from the end.
        for(j=cInLen-1, i=cFileLen-1; i>=0 && j>=0; i--, j--)
        {
            if(wszFile[i] == '*')
            {
                i--;
                if(i<0 || wszFile[i]!='\\')
                {
                    //the string in the registry was garbage
                    //the symbol previous to '*' must be '\\'
                    return NULL;
                }
                //skip all symbols in wszInFile till next '\\'
                while(j>=0 && wszInFile[j]!='\\') j--;

                //At this point both strings must be of exactly the same size.
                //If not - they are not equal
                if(j!=i)
                {
                    return NULL;
                }
                //no need to compare,
                //we already know that current symbols are equal
                break;
            }

            if(towlower(wszFile[i])!=towlower(wszInFile[j]))
            {
                return NULL;
            }
        }

        //the scan ran off the start of a string without reaching '*':
        //nothing to redirect
        if(i<0 || j<0)
        {
            return NULL;
        }

        //i and j are equal now.
        //no more * allowed
        //j we will remember as a position of '\\'
        //(position j itself is already known to match)
        for(k=j-1; k>=0; k--)
        {
            if(towlower(wszFile[k])!=towlower(wszInFile[k]))
            {
                return NULL;
            }
        }

        //Okay, the strings are equal
        //Now construct the output string
        if(wszPerUserFile)
        {
            LocalFree(wszPerUserFile);
            wszPerUserFile=NULL;
        }

        DWORD cPerUserFile = cPerUserDirLen + cInLen - j + 1;
        wszPerUserFile = (LPWSTR) LocalAlloc(LPTR,cPerUserFile*sizeof(WCHAR));
        if(!wszPerUserFile)
        {
            return NULL;
        }

        //Append exactly the cInLen-j characters the buffer was sized for,
        //even if wszInFile continues past cInLen.
        wcscpy(wszPerUserFile,wszPerUserDir);
        wcsncat(wszPerUserFile,wszInFile+j,cInLen-j);
    }
    else
    {
        //first find if the input string has right size
        if(cInLen != cFileLen)
        {
            return NULL;
        }

        //the end of the path will more likely be different.
        //so start comparing from the end.
        for(i=cFileLen-1; i>=0; i--)
        {
            if(towlower(wszFile[i])!=towlower(wszInFile[i]))
            {
                return NULL;
            }
        }
    }

    return wszPerUserFile;
}

///////////////////////////////////////////////////////////////////////////////
//class CPerUserPaths
///////////////////////////////////////////////////////////////////////////////

// Starts with no redirection entries loaded.
CPerUserPaths::CPerUserPaths():
    m_pPaths(NULL), m_cPaths(0)
{
}

// Releases the PER_USER_PATH array and closes the global HKLM handle.
CPerUserPaths::~CPerUserPaths()
{
    if(m_pPaths)
    {
        delete[] m_pPaths;
    }

    if(HKLM)
    {
        NtClose(HKLM);
        HKLM = NULL;
    }
}

/******************************************************************************
Routine Description:
    Loads from the registry
    (HKLM\\Software\\Microsoft\\Windows NT\\CurrentVersion\\
    Terminal Server\\Compatibility\\PerUserFiles\\)
    information of files that need to be redirected to per-user directory.
Arguments:
    NONE
Return Value:
    TRUE if success
******************************************************************************/
BOOL
CPerUserPaths::Init()
{
    //Get the name of the current executable.
    LPCSTR szModule = COMMAND_LINE; //command line is CHAR[] so szModule needs to be CHAR too.

    DPF("TSPerUserFiles",eDbgLevelInfo," - App Name: %s\n",szModule);

    //Open HKLM for later use
    if(!HKLM)
    {
        if(RegKeyOpen(NULL, L"\\Registry\\Machine", KEY_READ, &HKLM )!=ERROR_SUCCESS)
        {
            DPF("TSPerUserFiles",eDbgLevelError," - FAILED: Cannot open HKLM!\n");
            return FALSE;
        }
    }

    //Check if TS App Compat is on.
    if(!IsAppCompatOn())
    {
        DPF("TSPerUserFiles",eDbgLevelError," - FAILED: TS App Compat is off!\n");
        return FALSE;
    }

    //Get files we need to redirect.
    DWORD err;
    HKEY hKey;

    WCHAR szKeyNameTemplate[] = L"Software\\Microsoft\\Windows NT\\CurrentVersion"
        L"\\Terminal Server\\Compatibility\\PerUserFiles\\%S";

    //sizeof() of the template already counts its terminator; add room for
    //one WCHAR per character of the module name.
    LPWSTR szKeyName = (LPWSTR) LocalAlloc(LPTR,
        sizeof(szKeyNameTemplate)+strlen(szModule)*sizeof(WCHAR));
    if(!szKeyName)
    {
        DPF("TSPerUserFiles",eDbgLevelError," - FAILED: cannot allocate key name\n");
        return FALSE;
    }

    swprintf(szKeyName,szKeyNameTemplate,szModule);

    err = RegKeyOpen(HKLM, szKeyName, KEY_QUERY_VALUE, &hKey );

    LocalFree(szKeyName);

    if(err == ERROR_SUCCESS)
    {
        err = RegGetKeyInfo(hKey, &m_cPaths, NULL);
        if(err == ERROR_SUCCESS)
        {
            DPF("TSPerUserFiles",eDbgLevelInfo," - %d file(s) need to be redirected\n",m_cPaths);

            //Allocate array of PER_USER_PATH structs
            m_pPaths = new PER_USER_PATH[m_cPaths];

            if(!m_pPaths)
            {
                NtClose(hKey);
                return FALSE;
            }

            // NOTE(review): this copy of the file is damaged from here. The
            // text between "i" and "Data);" is missing - apparently the rest
            // of this loop, the remainder of CPerUserPaths::Init and
            // everything up to the tail of a registry helper that returns
            // RtlNtStatusToDosError(Status). The surviving tokens are kept
            // exactly as found; restore the missing code from source control
            // before building.
            for(DWORD i=0;iData);
}
return RtlNtStatusToDosError( Status );
}

/******************************************************************************
Get key's number of values and max value name length.
Either output pointer may be NULL if that item is not wanted; the outputs are
written only when NtQueryKey succeeds.
Returns the NtQueryKey status translated by RtlNtStatusToDosError.
******************************************************************************/
DWORD
RegGetKeyInfo(
        IN HKEY hKey,
        OUT LPDWORD pcValues,
        OUT LPDWORD pcbMaxValueNameLen)
{
    NTSTATUS Status;
    KEY_CACHED_INFORMATION KeyInfo;
    DWORD dwRead;

    Status = NtQueryKey(
                    hKey,
                    KeyCachedInformation,
                    &KeyInfo,
                    sizeof(KeyInfo),
                    &dwRead);

    if (NT_SUCCESS(Status))
    {
        if(pcValues)
        {
            *pcValues = KeyInfo.Values;
        }
        if(pcbMaxValueNameLen)
        {
            *pcbMaxValueNameLen = KeyInfo.MaxValueNameLen;
        }
    }

    return RtlNtStatusToDosError( Status );
}

/****************************************************************************** Enumerates values of the registry key Returns name and data for one value at a time ******************************************************************************/ DWORD RegKeyEnumValues( IN HKEY hKey, IN DWORD iValue, OUT LPWSTR *pwszValueName, OUT LPBYTE *ppbData) { KEY_VALUE_FULL_INFORMATION viValue, *pviValue; ULONG dwActualLength; NTSTATUS Status = STATUS_SUCCESS; DWORD err = ERROR_SUCCESS; *pwszValueName = NULL; *ppbData = NULL; pviValue = &viValue; Status = NtEnumerateValueKey( hKey, iValue, KeyValueFullInformation, pviValue, sizeof(KEY_VALUE_FULL_INFORMATION), &dwActualLength); if (Status == STATUS_BUFFER_OVERFLOW) { // // Our default buffer of KEY_VALUE_FULL_INFORMATION size didn't quite cut it // Forced to allocate from heap and make call again. // pviValue = (KEY_VALUE_FULL_INFORMATION *) LocalAlloc(LPTR, dwActualLength); if (!pviValue) { return GetLastError(); } Status = NtEnumerateValueKey( hKey, iValue, KeyValueFullInformation, pviValue, dwActualLength, &dwActualLength); } if (NT_SUCCESS(Status)) { *pwszValueName = (LPWSTR)LocalAlloc(LPTR,pviValue->NameLength+sizeof(WCHAR)); if(*pwszValueName) { *ppbData = (LPBYTE)LocalAlloc(LPTR,pviValue->DataLength); if(*ppbData) { CopyMemory(*pwszValueName, pviValue->Name, pviValue->NameLength); (*pwszValueName)[pviValue->NameLength/sizeof(WCHAR)] = 0; CopyMemory(*ppbData, LPBYTE(pviValue)+pviValue->DataOffset, pviValue->DataLength); } else { err = GetLastError(); LocalFree(*pwszValueName); *pwszValueName = NULL; } } else { err = GetLastError(); } } if(pviValue != &viValue) { LocalFree(pviValue); } if(err != ERROR_SUCCESS) { return err; } else { return RtlNtStatusToDosError( Status ); } } IMPLEMENT_SHIM_END