#include #include #include #include #include "iucommon.h" #include "logging.h" #include "download.h" #include "dlutil.h" #include "malloc.h" #include "wusafefn.h" /////////////////////////////////////////////////////////////////////////////// // typedef BOOL (WINAPI *pfn_OpenProcessToken)(HANDLE, DWORD, PHANDLE); typedef BOOL (WINAPI *pfn_OpenThreadToken)(HANDLE, DWORD, BOOL, PHANDLE); typedef BOOL (WINAPI *pfn_SetThreadToken)(PHANDLE, HANDLE); typedef BOOL (WINAPI *pfn_GetTokenInformation)(HANDLE, TOKEN_INFORMATION_CLASS, LPVOID, DWORD, PDWORD); typedef BOOL (WINAPI *pfn_IsValidSid)(PSID); typedef BOOL (WINAPI *pfn_AllocateAndInitializeSid)(PSID_IDENTIFIER_AUTHORITY, BYTE, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, PSID); typedef BOOL (WINAPI *pfn_EqualSid)(PSID, PSID); typedef PVOID (WINAPI *pfn_FreeSid)(PSID); const TCHAR c_szRPWU[] = _T("Software\\Microsoft\\Windows\\CurrentVersion\\WindowsUpdate"); const TCHAR c_szRVTransport[] = _T("DownloadTransport"); const TCHAR c_szAdvapi32[] = _T("advapi32.dll"); // *************************************************************************** static BOOL AmIPrivileged(void) { LOG_Block("AmINotPrivileged()"); pfn_AllocateAndInitializeSid pfnAllocateAndInitializeSid = NULL; pfn_GetTokenInformation pfnGetTokenInformation = NULL; pfn_OpenProcessToken pfnOpenProcessToken = NULL; pfn_OpenThreadToken pfnOpenThreadToken = NULL; pfn_SetThreadToken pfnSetThreadToken = NULL; pfn_IsValidSid pfnIsValidSid = NULL; pfn_EqualSid pfnEqualSid = NULL; pfn_FreeSid pfnFreeSid = NULL; HMODULE hmod = NULL; SID_IDENTIFIER_AUTHORITY siaNT = SECURITY_NT_AUTHORITY; TOKEN_USER *ptu = NULL; HANDLE hToken = NULL, hTokenImp = NULL; DWORD cb, cbGot, i; PSID psid = NULL; BOOL fRet = FALSE; DWORD rgRIDs[3] = { SECURITY_LOCAL_SYSTEM_RID, SECURITY_LOCAL_SERVICE_RID, SECURITY_NETWORK_SERVICE_RID }; hmod = LoadLibraryFromSystemDir(c_szAdvapi32); if (hmod == NULL) goto done; pfnAllocateAndInitializeSid = 
(pfn_AllocateAndInitializeSid)GetProcAddress(hmod, "AllocateAndInitializeSid"); pfnGetTokenInformation = (pfn_GetTokenInformation)GetProcAddress(hmod, "GetTokenInformation"); pfnOpenProcessToken = (pfn_OpenProcessToken)GetProcAddress(hmod, "OpenProcessToken"); pfnOpenThreadToken = (pfn_OpenThreadToken)GetProcAddress(hmod, "OpenThreadToken"); pfnSetThreadToken = (pfn_SetThreadToken)GetProcAddress(hmod, "SetThreadToken"); pfnIsValidSid = (pfn_IsValidSid)GetProcAddress(hmod, "IsValidSid"); pfnEqualSid = (pfn_EqualSid)GetProcAddress(hmod, "EqualSid"); pfnFreeSid = (pfn_FreeSid)GetProcAddress(hmod, "FreeSid"); if (pfnAllocateAndInitializeSid == NULL || pfnGetTokenInformation == NULL || pfnOpenProcessToken == NULL || pfnOpenThreadToken == NULL || pfnSetThreadToken == NULL || pfnIsValidSid == NULL || pfnEqualSid == NULL || pfnFreeSid == NULL) { SetLastError(ERROR_PROC_NOT_FOUND); goto done; } // need the process token fRet = (*pfnOpenProcessToken)(GetCurrentProcess(), TOKEN_READ, &hToken); if (fRet == FALSE) { if (GetLastError() == ERROR_ACCESS_DENIED) { fRet = (*pfnOpenThreadToken)(GetCurrentThread(), TOKEN_READ | TOKEN_IMPERSONATE, TRUE, &hTokenImp); if (fRet == FALSE) goto done; fRet = (*pfnSetThreadToken)(NULL, NULL); fRet = (*pfnOpenProcessToken)(GetCurrentProcess(), TOKEN_READ, &hToken); if ((*pfnSetThreadToken)(NULL, hTokenImp) == FALSE) fRet = FALSE; } if (fRet == FALSE) goto done; } // need the SID from the token fRet = (*pfnGetTokenInformation)(hToken, TokenUser, NULL, 0, &cb); if (fRet != FALSE && GetLastError() != ERROR_INSUFFICIENT_BUFFER) { fRet = FALSE; goto done; } ptu = (TOKEN_USER *)HeapAlloc(GetProcessHeap(), 0, cb); if (ptu == NULL) { SetLastError(ERROR_OUTOFMEMORY); fRet = FALSE; goto done; } fRet = (*pfnGetTokenInformation)(hToken, TokenUser, (LPVOID)ptu, cb, &cbGot); if (fRet == FALSE) goto done; fRet = (*pfnIsValidSid)(ptu->User.Sid); if (fRet == FALSE) goto done; // loop thru & check against the SIDs we are interested in for (i = 0; i < 3; i++) { 
fRet = (*pfnAllocateAndInitializeSid)(&siaNT, 1, rgRIDs[i], 0, 0, 0, 0, 0, 0, 0, &psid); if (fRet == FALSE) goto done; fRet = (*pfnIsValidSid)(psid); if (fRet == FALSE) goto done; // if we get a SID match, then return TRUE fRet = (*pfnEqualSid)(psid, ptu->User.Sid); (*pfnFreeSid)(psid); psid = NULL; if (fRet) { fRet = TRUE; goto done; } } // only way to get here is to fail all the SID checks above. So we ain't // privileged. Yeehaw. fRet = FALSE; done: // if we had an impersonation token on the thread, put it back in place. if (ptu != NULL) HeapFree(GetProcessHeap(), 0, ptu); if (hToken != NULL) CloseHandle(hToken); if (hTokenImp != NULL) CloseHandle(hTokenImp); if (psid != NULL && pfnFreeSid != NULL) (*pfnFreeSid)(psid); if (hmod != NULL) FreeLibrary(hmod); return fRet; } #if defined(DEBUG) || defined(DBG) // ************************************************************************** static BOOL CheckDebugRegKey(DWORD *pdwAllowed) { LOG_Block("CheckDebugRegKey()"); DWORD dw, dwType, dwValue, cb; HKEY hkey = NULL; BOOL fRet = FALSE; // explictly do not initialize *pdwAllowed. We only want it overwritten // if the reg key is properly set dw = RegOpenKeyEx(HKEY_LOCAL_MACHINE, c_szRPWU, 0, KEY_READ, &hkey); if (dw != ERROR_SUCCESS) goto done; cb = sizeof(dwValue); dw = RegQueryValueEx(hkey, c_szRVTransport, 0, &dwType, (LPBYTE)&dwValue, &cb); if (dw != ERROR_SUCCESS) goto done; // set this to 3 so we'll fall down into the error case below if (dwType != REG_DWORD) dwValue = 3; fRet = TRUE; switch(dwValue) { case 0: *pdwAllowed = 0; break; case 1: *pdwAllowed = WUDF_ALLOWWINHTTPONLY; break; case 2: *pdwAllowed = WUDF_ALLOWWININETONLY; break; default: LOG_Internet(_T("Bad reg value in DownloadTransport. 
Ignoring.")); fRet = FALSE; break; } done: if (hkey != NULL) RegCloseKey(hkey); return fRet; } #endif // ************************************************************************** DWORD GetAllowedDownloadTransport(DWORD dwFlagsInitial) { DWORD dwFlags = (dwFlagsInitial & WUDF_TRANSPORTMASK); #if defined(UNICODE) // don't bother checking if we're local system if we're already using // wininet if ((dwFlags & WUDF_ALLOWWININETONLY) == 0) { if (AmIPrivileged() == FALSE) dwFlags = WUDF_ALLOWWININETONLY; } #if defined(DEBUG) || defined(DBG) CheckDebugRegKey(&dwFlags); #endif // defined(DEBUG) || defined(DBG) #else // defined(UNICODE) // only allow wininet on ANSI dwFlags = WUDF_ALLOWWININETONLY; #endif // defined(UNICODE) return (dwFlags | (dwFlagsInitial & ~WUDF_TRANSPORTMASK)); } /////////////////////////////////////////////////////////////////////////////// // // ************************************************************************** static inline BOOL IsServerFileDifferentWorker(FILETIME &ftServerTime, DWORD dwServerFileSize, HANDLE hFile) { LOG_Block("IsServerFileNewerWorker()"); FILETIME ftCreateTime; DWORD cbLocalFile; // By default, always return TRUE so we can download a new file.. BOOL fRet = TRUE; // if we don't have a valid file handle, just return TRUE to download a // new copy if (hFile == INVALID_HANDLE_VALUE) goto done; cbLocalFile = GetFileSize(hFile, NULL); LOG_Internet(_T("IsServerFileNewer: Local size: %d. Remote size: %d"), cbLocalFile, dwServerFileSize); // if the sizes are not equal, then return TRUE if (cbLocalFile != dwServerFileSize) goto done; if (GetFileTime(hFile, &ftCreateTime, NULL, NULL)) { LOG_Internet(_T("IsServerFileNewer: Local time: %x%0x. Remote time: %x%0x."), ftCreateTime.dwHighDateTime, ftCreateTime.dwLowDateTime, ftServerTime.dwHighDateTime, ftServerTime.dwLowDateTime); // if the local file has a different timestamp, then return TRUE. 
fRet = (CompareFileTime(&ftCreateTime, &ftServerTime) != 0); } done: return fRet; } // ************************************************************************** BOOL IsServerFileDifferentW(FILETIME &ftServerTime, DWORD dwServerFileSize, LPCWSTR wszLocalFile) { LOG_Block("IsServerFileDifferentW()"); HANDLE hFile = INVALID_HANDLE_VALUE; BOOL fRet = TRUE; // if we have an error opening the file, just return TRUE to download a // new copy hFile = CreateFileW(wszLocalFile, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); if (hFile == INVALID_HANDLE_VALUE) { LOG_Internet(_T("IsServerFileDifferent: %ls does not exist."), wszLocalFile); return TRUE; } else { fRet = IsServerFileDifferentWorker(ftServerTime, dwServerFileSize, hFile); CloseHandle(hFile); return fRet; } } // ************************************************************************** BOOL IsServerFileDifferentA(FILETIME &ftServerTime, DWORD dwServerFileSize, LPCSTR szLocalFile) { LOG_Block("IsServerFileDifferentA()"); HANDLE hFile = INVALID_HANDLE_VALUE; // if we have an error opening the file, just return TRUE to download a // new copy hFile = CreateFileA(szLocalFile, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL); if (hFile == INVALID_HANDLE_VALUE) { LOG_Internet(_T("IsServerFileDifferent: %s does not exist."), szLocalFile); return TRUE; } else { BOOL fRet; fRet = IsServerFileDifferentWorker(ftServerTime, dwServerFileSize, hFile); CloseHandle(hFile); return fRet; } } // ************************************************************************** // helper function to handle quit events // // return TRUE if okay to continue // return FALSE if we should quit now! BOOL HandleEvents(HANDLE *phEvents, UINT nEventCount) { LOG_Block("HandleEvents()"); DWORD dwWait; // is there any events to handle? 
if (phEvents == NULL || nEventCount == 0) return TRUE; // we only want to check the signaled status, so don't bother waiting dwWait = WaitForMultipleObjects(nEventCount, phEvents, FALSE, 0); if (dwWait == WAIT_TIMEOUT) { return TRUE; } else { LOG_Internet(_T("HandleEvents: A quit event was signaled. Aborting...")); return FALSE; } } /////////////////////////////////////////////////////////////////////////////// // // ************************************************************************** HRESULT PerformDownloadToFile(pfn_ReadDataFromSite pfnRead, HINTERNET hRequest, HANDLE hFile, DWORD cbFile, DWORD cbBuffer, HANDLE *rghEvents, DWORD cEvents, PFNDownloadCallback fpnCallback, LPVOID pCallbackData, DWORD *pcbDownloaded) { LOG_Block("PerformDownloadToFile()"); HRESULT hr = S_OK; PBYTE pbBuffer = NULL; DWORD cbDownloaded = 0, cbRead, cbWritten; LONG lCallbackRequest = 0; pbBuffer = (PBYTE)HeapAlloc(GetProcessHeap(), 0, cbBuffer); if (pbBuffer == NULL) { hr = E_OUTOFMEMORY; LOG_ErrorMsg(hr); goto done; } // Download the File for(;;) { if ((*pfnRead)(hRequest, pbBuffer, cbBuffer, &cbRead) == FALSE) { hr = HRESULT_FROM_WIN32(GetLastError()); if (FAILED(hr)) { LOG_ErrorMsg(hr); goto done; } } if (cbRead == 0) { BYTE bTemp[32]; // Make one final call to WinHttpReadData to commit the file to // Cache. (the download is not complete otherwise) (*pfnRead)(hRequest, bTemp, ARRAYSIZE(bTemp), &cbRead); break; } cbDownloaded += cbRead; if (fpnCallback != NULL) { fpnCallback(pCallbackData, DOWNLOAD_STATUS_OK, cbFile, cbRead, NULL, &lCallbackRequest); if (lCallbackRequest == 4) { // QuitEvent was Signaled.. abort requested. We will do // another callback and pass the Abort State back fpnCallback(pCallbackData, DOWNLOAD_STATUS_ABORTED, cbFile, cbRead, NULL, NULL); hr = E_ABORT; // set return result to abort. 
goto done; } } if (WriteFile(hFile, pbBuffer, cbRead, &cbWritten, NULL) == FALSE) { hr = HRESULT_FROM_WIN32(GetLastError()); LOG_ErrorMsg(hr); goto done; } if (HandleEvents(rghEvents, cEvents) == FALSE) { // we need to quit the download clean up, send abort event and clean up what we've downloaded if (fpnCallback != NULL) fpnCallback(pCallbackData, DOWNLOAD_STATUS_ABORTED, cbFile, cbRead, NULL, NULL); hr = E_ABORT; // set return result to abort. goto done; } } if (pcbDownloaded != NULL) *pcbDownloaded = cbDownloaded; done: SafeHeapFree(pbBuffer); return hr; } /////////////////////////////////////////////////////////////////////////////// // struct MY_OSVERSIONINFOEX { OSVERSIONINFOEX osvi; LCID lcidCompare; }; static MY_OSVERSIONINFOEX g_myosvi; static BOOL g_fInit = FALSE; // ************************************************************************** // Loads the current OS version info if needed, and returns a pointer to // a cached copy of it. const OSVERSIONINFOEX* GetOSVersionInfo(void) { if (g_fInit == FALSE) { OSVERSIONINFOEX* pOSVI = &g_myosvi.osvi; g_myosvi.osvi.dwOSVersionInfoSize = sizeof(g_myosvi.osvi); GetVersionEx((OSVERSIONINFO*)&g_myosvi.osvi); // WinXP-specific stuff if ((pOSVI->dwMajorVersion > 5) || (pOSVI->dwMajorVersion == 5 && pOSVI->dwMinorVersion >= 1)) g_myosvi.lcidCompare = LOCALE_INVARIANT; else g_myosvi.lcidCompare = MAKELCID(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US), SORT_DEFAULT); g_fInit = TRUE; } return &g_myosvi.osvi; } // ************************************************************************** // String lengths can be -1 if the strings are null-terminated. 
// Case-insensitive comparison of psz1 and psz2 using the locale cached in
// g_myosvi.lcidCompare (filled in by GetOSVersionInfo() on first use).
// Returns -1, 0 or 1 for less than, equal to, or greater than.
// NOTE(review): if CompareStringA fails it presumably returns 0, which this
// turns into -2; callers that only test "== 0" will just see "not equal".
int LangNeutralStrCmpNIA(LPCSTR psz1, int cch1, LPCSTR psz2, int cch2)
{
    if (g_fInit == FALSE)
        GetOSVersionInfo();

    int nCompare = CompareStringA(g_myosvi.lcidCompare, NORM_IGNORECASE, psz1, cch1, psz2, cch2);
    return (nCompare - 2); // convert from (1, 2, 3) to (-1, 0, 1)
}

// **************************************************************************
// Finds the first instance of pszSearchFor in pszSearchIn, case-insensitive.
// Returns an index into pszSearchIn if found, or -1 if not.
// You can pass -1 for either or both of the lengths (the string is then
// measured with lstrlenA).
//
// Candidate positions are those whose byte equals the lower- or upper-case
// form of the first search character; the remaining cchSearchFor - 1
// characters are then compared with LangNeutralStrCmpNIA.
// NOTE(review): an empty search string is not handled: cchSearchFor - 1
// becomes -1 (meaning "null-terminated" to LangNeutralStrCmpNIA) and
// pszSearchFor + 1 points past the terminator, so the comparison can read
// out of bounds.  Confirm callers never pass an empty pszSearchFor.
int LangNeutralStrStrNIA(LPCSTR pszSearchIn, int cchSearchIn, LPCSTR pszSearchFor, int cchSearchFor)
{
    char chLower, chUpper;

    if (cchSearchIn == -1)
        cchSearchIn = lstrlenA(pszSearchIn);
    if (cchSearchFor == -1)
        cchSearchFor = lstrlenA(pszSearchFor);

    // Note: since this is lang-neutral, we can assume no DBCS search chars
    chLower = (char)CharLowerA(MAKEINTRESOURCEA(*pszSearchFor));
    chUpper = (char)CharUpperA(MAKEINTRESOURCEA(*pszSearchFor));

    // Note: since search-for is lang-neutral, we can ignore any DBCS chars
    // in search-in
    for (int ichIn = 0; ichIn <= cchSearchIn - cchSearchFor; ichIn++)
    {
        if (pszSearchIn[ichIn] == chLower || pszSearchIn[ichIn] == chUpper)
        {
            // first character matched; compare the rest
            if (LangNeutralStrCmpNIA(pszSearchIn + ichIn + 1, cchSearchFor - 1,
                                     pszSearchFor + 1, cchSearchFor - 1) == 0)
            {
                return ichIn;
            }
        }
    }

    return -1;
}

// **************************************************************************
// NOTE(review): this function was damaged when the file was extracted.
// Text between a '<' and the following '>' appears to have been removed
// (presumably an "<html"-style search string), taking with it the rest of
// this header comment, the function signature and its opening statements.
// The surviving "200) cbFile = 200;" looks like the tail of a statement
// that caps cbFile at 200 bytes.  The fragments are kept verbatim below;
// restore the missing text from version control before building.
// Opens the given file and looks for " 200) cbFile = 200;
    hMapping = CreateFileMapping(hFile, NULL, PAGE_READONLY, 0, cbFile, NULL);
    if (hMapping == NULL)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    // map only the first cbFile bytes, read-only
    pvMem = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, cbFile);
    if (pvMem == NULL)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    pszFile = (LPCSTR)pvMem;
    // NOTE(review): the remainder of the next statement (search string and
    // length) and the opening of the block that is closed just before
    // "done:" were lost in extraction; only the fragment `" tag` survives.
    int ichHtml = LangNeutralStrStrNIA(pszFile, cbFile, " tag
        // scan every byte that precedes the match at index ichHtml
        for (int ich = 0; ich < ichHtml; ich++)
        {
            char ch = pszFile[ich];
            // control characters other than tab, CR and LF count as binary.
            // NOTE(review): if char is signed (the MSVC default), bytes
            // 0x80-0xFF are negative and are also treated as binary here —
            // confirm that is intended.
            if (ch < 32 && ch != '\t' &&
                ch != '\r' && ch != '\n')
            {
                // Found a binary character before index ichHtml
                hr = S_FALSE;
                break;
            }
        }
    }

done:
    if (pvMem != NULL)
        UnmapViewOfFile(pvMem);
    if (hMapping != NULL)
        CloseHandle(hMapping);
    if (hFile != INVALID_HANDLE_VALUE)
        CloseHandle(hFile);

    return hr;
}