/*
 * Source code of Windows XP (NT5)
 * You cannot select more than 25 topics. Topics must start with a letter or
 * number, can include dashes ('-') and can be up to 35 characters long.
 * 649 lines, 20 KiB
 */

#include <windows.h>
#include <shlwapi.h>
#include <tchar.h>
#include <winhttp.h>
#include "iucommon.h"
#include "logging.h"
#include "download.h"
#include "dlutil.h"
#include "malloc.h"
#include "wusafefn.h"
///////////////////////////////////////////////////////////////////////////////
//
// Function-pointer types for the advapi32 exports that AmIPrivileged()
// resolves at runtime with GetProcAddress. Each signature must match the
// documented Win32 prototype exactly.
typedef BOOL (WINAPI *pfn_OpenProcessToken)(HANDLE, DWORD, PHANDLE);
typedef BOOL (WINAPI *pfn_OpenThreadToken)(HANDLE, DWORD, BOOL, PHANDLE);
typedef BOOL (WINAPI *pfn_SetThreadToken)(PHANDLE, HANDLE);
typedef BOOL (WINAPI *pfn_GetTokenInformation)(HANDLE, TOKEN_INFORMATION_CLASS, LPVOID, DWORD, PDWORD);
typedef BOOL (WINAPI *pfn_IsValidSid)(PSID);
// The last parameter receives the newly allocated SID, so it is an
// out-pointer (PSID *), not a PSID.
typedef BOOL (WINAPI *pfn_AllocateAndInitializeSid)(PSID_IDENTIFIER_AUTHORITY, BYTE, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, DWORD, PSID *);
typedef BOOL (WINAPI *pfn_EqualSid)(PSID, PSID);
typedef PVOID (WINAPI *pfn_FreeSid)(PSID);
// Registry key and value read by CheckDebugRegKey() (debug builds) to
// override the allowed download transport.
const TCHAR c_szRPWU[] = _T("Software\\Microsoft\\Windows\\CurrentVersion\\WindowsUpdate");
const TCHAR c_szRVTransport[] = _T("DownloadTransport");
// DLL exporting the token/SID APIs typed above.
const TCHAR c_szAdvapi32[] = _T("advapi32.dll");
// ***************************************************************************
// Determines whether the current process runs as LocalSystem, LocalService
// or NetworkService. The user SID from the process token is compared against
// the SIDs built from SECURITY_LOCAL_SYSTEM_RID, SECURITY_LOCAL_SERVICE_RID
// and SECURITY_NETWORK_SERVICE_RID under SECURITY_NT_AUTHORITY.
//
// The advapi32 entry points are resolved at runtime with GetProcAddress.
//
// Returns TRUE on a SID match. Returns FALSE when no SID matches and also on
// any failure (advapi32 not loadable, missing export, token/SID API error),
// so "could not tell" is treated as "not privileged".
static
BOOL AmIPrivileged(void)
{
    LOG_Block("AmIPrivileged()");

    pfn_AllocateAndInitializeSid pfnAllocateAndInitializeSid = NULL;
    pfn_GetTokenInformation      pfnGetTokenInformation = NULL;
    pfn_OpenProcessToken         pfnOpenProcessToken = NULL;
    pfn_OpenThreadToken          pfnOpenThreadToken = NULL;
    pfn_SetThreadToken           pfnSetThreadToken = NULL;
    pfn_IsValidSid               pfnIsValidSid = NULL;
    pfn_EqualSid                 pfnEqualSid = NULL;
    pfn_FreeSid                  pfnFreeSid = NULL;
    HMODULE hmod = NULL;
    SID_IDENTIFIER_AUTHORITY siaNT = SECURITY_NT_AUTHORITY;
    TOKEN_USER *ptu = NULL;
    HANDLE hToken = NULL, hTokenImp = NULL;
    DWORD cb = 0, cbGot = 0, i;
    PSID psid = NULL;
    BOOL fRet = FALSE;
    DWORD rgRIDs[] = { SECURITY_LOCAL_SYSTEM_RID,
                       SECURITY_LOCAL_SERVICE_RID,
                       SECURITY_NETWORK_SERVICE_RID };

    hmod = LoadLibraryFromSystemDir(c_szAdvapi32);
    if (hmod == NULL)
        goto done;

    pfnAllocateAndInitializeSid = (pfn_AllocateAndInitializeSid)GetProcAddress(hmod, "AllocateAndInitializeSid");
    pfnGetTokenInformation = (pfn_GetTokenInformation)GetProcAddress(hmod, "GetTokenInformation");
    pfnOpenProcessToken = (pfn_OpenProcessToken)GetProcAddress(hmod, "OpenProcessToken");
    pfnOpenThreadToken = (pfn_OpenThreadToken)GetProcAddress(hmod, "OpenThreadToken");
    pfnSetThreadToken = (pfn_SetThreadToken)GetProcAddress(hmod, "SetThreadToken");
    pfnIsValidSid = (pfn_IsValidSid)GetProcAddress(hmod, "IsValidSid");
    pfnEqualSid = (pfn_EqualSid)GetProcAddress(hmod, "EqualSid");
    pfnFreeSid = (pfn_FreeSid)GetProcAddress(hmod, "FreeSid");
    if (pfnAllocateAndInitializeSid == NULL ||
        pfnGetTokenInformation == NULL ||
        pfnOpenProcessToken == NULL ||
        pfnOpenThreadToken == NULL ||
        pfnSetThreadToken == NULL ||
        pfnIsValidSid == NULL ||
        pfnEqualSid == NULL ||
        pfnFreeSid == NULL)
    {
        SetLastError(ERROR_PROC_NOT_FOUND);
        goto done;
    }

    // need the process token
    fRet = (*pfnOpenProcessToken)(GetCurrentProcess(), TOKEN_READ, &hToken);
    if (fRet == FALSE)
    {
        // Access to the process token can be denied while the thread holds
        // an impersonation token (presumably the common case here). Save the
        // thread token, revert to self, retry, then restore the saved token.
        if (GetLastError() == ERROR_ACCESS_DENIED)
        {
            fRet = (*pfnOpenThreadToken)(GetCurrentThread(),
                                         TOKEN_READ | TOKEN_IMPERSONATE,
                                         TRUE, &hTokenImp);
            if (fRet == FALSE)
                goto done;

            // Only retry the open if reverting to self actually worked.
            fRet = (*pfnSetThreadToken)(NULL, NULL);
            if (fRet != FALSE)
                fRet = (*pfnOpenProcessToken)(GetCurrentProcess(), TOKEN_READ,
                                              &hToken);

            // Always put the impersonation token back, whatever happened above.
            if ((*pfnSetThreadToken)(NULL, hTokenImp) == FALSE)
                fRet = FALSE;
        }
        if (fRet == FALSE)
            goto done;
    }

    // Query the buffer size needed for the user SID. This call is expected
    // to fail with ERROR_INSUFFICIENT_BUFFER; success or any other error
    // means cb cannot be trusted.
    fRet = (*pfnGetTokenInformation)(hToken, TokenUser, NULL, 0, &cb);
    if (fRet != FALSE || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
    {
        fRet = FALSE;
        goto done;
    }

    ptu = (TOKEN_USER *)HeapAlloc(GetProcessHeap(), 0, cb);
    if (ptu == NULL)
    {
        SetLastError(ERROR_OUTOFMEMORY);
        fRet = FALSE;
        goto done;
    }

    fRet = (*pfnGetTokenInformation)(hToken, TokenUser, (LPVOID)ptu, cb,
                                     &cbGot);
    if (fRet == FALSE)
        goto done;

    fRet = (*pfnIsValidSid)(ptu->User.Sid);
    if (fRet == FALSE)
        goto done;

    // loop thru & check against the SIDs we are interested in
    for (i = 0; i < ARRAYSIZE(rgRIDs); i++)
    {
        fRet = (*pfnAllocateAndInitializeSid)(&siaNT, 1, rgRIDs[i], 0, 0, 0,
                                              0, 0, 0, 0, &psid);
        if (fRet == FALSE)
            goto done;

        fRet = (*pfnIsValidSid)(psid);
        if (fRet == FALSE)
            goto done;

        // if we get a SID match, then return TRUE
        fRet = (*pfnEqualSid)(psid, ptu->User.Sid);
        (*pfnFreeSid)(psid);
        psid = NULL;
        if (fRet)
        {
            fRet = TRUE;    // normalize any non-zero BOOL to TRUE
            goto done;
        }
    }

    // only way to get here is to fail all the SID checks above, so we are
    // not privileged.
    fRet = FALSE;

done:
    // Any impersonation token was already restored above; this only
    // releases resources acquired along the way.
    if (ptu != NULL)
        HeapFree(GetProcessHeap(), 0, ptu);
    if (hToken != NULL)
        CloseHandle(hToken);
    if (hTokenImp != NULL)
        CloseHandle(hTokenImp);
    if (psid != NULL && pfnFreeSid != NULL)
        (*pfnFreeSid)(psid);
    if (hmod != NULL)
        FreeLibrary(hmod);

    return fRet;
}
#if defined(DEBUG) || defined(DBG)
// **************************************************************************
// Debug-build override for the download transport, read from the DWORD
// value c_szRVTransport under HKLM\c_szRPWU:
//   0 -> *pdwAllowed = 0
//   1 -> *pdwAllowed = WUDF_ALLOWWINHTTPONLY
//   2 -> *pdwAllowed = WUDF_ALLOWWININETONLY
// Any other value, or a value that is not REG_DWORD, is logged and ignored.
// Returns TRUE only if *pdwAllowed was overwritten; otherwise *pdwAllowed is
// left exactly as the caller set it.
static
BOOL CheckDebugRegKey(DWORD *pdwAllowed)
{
    LOG_Block("CheckDebugRegKey()");
    HKEY  hkeyWU = NULL;
    DWORD dwRegType = 0;
    DWORD dwSetting = 0;
    DWORD cbSetting = sizeof(dwSetting);
    BOOL  fOverridden = FALSE;

    if (RegOpenKeyEx(HKEY_LOCAL_MACHINE, c_szRPWU, 0, KEY_READ, &hkeyWU) != ERROR_SUCCESS)
        return FALSE;

    if (RegQueryValueEx(hkeyWU, c_szRVTransport, 0, &dwRegType,
                        (LPBYTE)&dwSetting, &cbSetting) == ERROR_SUCCESS)
    {
        // Data of any other registry type is treated as an out-of-range value.
        DWORD dwSelector = (dwRegType == REG_DWORD) ? dwSetting : 3;

        fOverridden = TRUE;
        if (dwSelector == 0)
        {
            *pdwAllowed = 0;
        }
        else if (dwSelector == 1)
        {
            *pdwAllowed = WUDF_ALLOWWINHTTPONLY;
        }
        else if (dwSelector == 2)
        {
            *pdwAllowed = WUDF_ALLOWWININETONLY;
        }
        else
        {
            LOG_Internet(_T("Bad reg value in DownloadTransport. Ignoring."));
            fOverridden = FALSE;
        }
    }

    RegCloseKey(hkeyWU);
    return fOverridden;
}
#endif
// **************************************************************************
// Computes which download transport(s) this process may use.
// Only the WUDF_TRANSPORTMASK bits of dwFlagsInitial are adjusted; all other
// bits are passed through unchanged.
//  - ANSI builds: always WUDF_ALLOWWININETONLY.
//  - Unicode builds: the requested transport bits are kept if WinInet-only
//    was already requested or AmIPrivileged() succeeds; otherwise WinInet
//    only is forced. Debug builds may then be overridden by CheckDebugRegKey().
DWORD GetAllowedDownloadTransport(DWORD dwFlagsInitial)
{
    const DWORD dwPassThrough = dwFlagsInitial & ~WUDF_TRANSPORTMASK;
    DWORD dwTransport;

#if defined(UNICODE)
    dwTransport = dwFlagsInitial & WUDF_TRANSPORTMASK;

    // The privilege check is skipped when WinInet was already requested.
    if (!(dwTransport & WUDF_ALLOWWININETONLY) && !AmIPrivileged())
        dwTransport = WUDF_ALLOWWININETONLY;

#if defined(DEBUG) || defined(DBG)
    CheckDebugRegKey(&dwTransport);
#endif // defined(DEBUG) || defined(DBG)

#else // defined(UNICODE)
    // ANSI builds are restricted to WinInet
    dwTransport = WUDF_ALLOWWININETONLY;
#endif // defined(UNICODE)

    return dwTransport | dwPassThrough;
}
///////////////////////////////////////////////////////////////////////////////
//
// **************************************************************************
// Compares an already-open local file against the server's copy.
//
//  ftServerTime      server timestamp; compared with the local file's
//                    *creation* time
//  dwServerFileSize  server file size in bytes
//  hFile             open handle to the local file, or INVALID_HANDLE_VALUE
//
// Returns TRUE ("different, download again") if hFile is invalid, the local
// size cannot be read or differs, the creation time cannot be read, or the
// times differ. Returns FALSE only when both size and creation time match.
static inline
BOOL IsServerFileDifferentWorker(FILETIME &ftServerTime,
                                 DWORD dwServerFileSize, HANDLE hFile)
{
    LOG_Block("IsServerFileDifferentWorker()");
    FILETIME ftCreateTime;
    DWORD cbLocalFile, cbLocalFileHigh = 0;
    // By default, always return TRUE so we can download a new file.
    BOOL fRet = TRUE;

    // if we don't have a valid file handle, just return TRUE to download a
    // new copy
    if (hFile == INVALID_HANDLE_VALUE)
        goto done;

    cbLocalFile = GetFileSize(hFile, &cbLocalFileHigh);
    // Check for failure before anything (e.g. logging) can overwrite the
    // last-error value.
    if (cbLocalFile == INVALID_FILE_SIZE && GetLastError() != NO_ERROR)
        goto done;

    LOG_Internet(_T("IsServerFileNewer: Local size: %u. Remote size: %u"),
                 cbLocalFile, dwServerFileSize);

    // The server size is a DWORD, so a local file of 4 GB or more (non-zero
    // high part) cannot match even if the low 32 bits happen to be equal.
    if (cbLocalFileHigh != 0 || cbLocalFile != dwServerFileSize)
        goto done;

    if (GetFileTime(hFile, &ftCreateTime, NULL, NULL))
    {
        // %08x keeps the leading zeros of each half so the concatenated
        // 64-bit value is printed correctly.
        LOG_Internet(_T("IsServerFileNewer: Local time: %08x%08x. Remote time: %08x%08x."),
                     ftCreateTime.dwHighDateTime, ftCreateTime.dwLowDateTime,
                     ftServerTime.dwHighDateTime, ftServerTime.dwLowDateTime);

        // if the local file has a different timestamp, then return TRUE.
        fRet = (CompareFileTime(&ftCreateTime, &ftServerTime) != 0);
    }

done:
    return fRet;
}
// **************************************************************************
// Wide-character entry point: opens wszLocalFile read-only and compares it
// with the server copy via IsServerFileDifferentWorker().
// Returns TRUE if the file cannot be opened (so a fresh copy is downloaded)
// or if it differs; FALSE if size and creation time both match.
BOOL IsServerFileDifferentW(FILETIME &ftServerTime, DWORD dwServerFileSize,
                            LPCWSTR wszLocalFile)
{
    LOG_Block("IsServerFileDifferentW()");
    BOOL fDifferent = TRUE;    // an unopenable file always needs a download
    HANDLE hLocal = CreateFileW(wszLocalFile, GENERIC_READ, FILE_SHARE_READ, NULL,
                                OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);

    if (hLocal != INVALID_HANDLE_VALUE)
    {
        fDifferent = IsServerFileDifferentWorker(ftServerTime, dwServerFileSize, hLocal);
        CloseHandle(hLocal);
    }
    else
    {
        LOG_Internet(_T("IsServerFileDifferent: %ls does not exist."), wszLocalFile);
    }

    return fDifferent;
}
// **************************************************************************
// ANSI entry point: opens szLocalFile read-only and compares it with the
// server copy via IsServerFileDifferentWorker().
// Returns TRUE if the file cannot be opened (so a fresh copy is downloaded)
// or if it differs; FALSE if size and creation time both match.
BOOL IsServerFileDifferentA(FILETIME &ftServerTime, DWORD dwServerFileSize,
                            LPCSTR szLocalFile)
{
    LOG_Block("IsServerFileDifferentA()");
    HANDLE hFile = INVALID_HANDLE_VALUE;

    // if we have an error opening the file, just return TRUE to download a
    // new copy
    hFile = CreateFileA(szLocalFile, GENERIC_READ, FILE_SHARE_READ, NULL,
                        OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
    if (hFile == INVALID_HANDLE_VALUE)
    {
        // szLocalFile is always a narrow string. Inside _T() a plain %s means
        // a wide string in UNICODE builds, so %hs is used instead.
        // NOTE(review): assumes LOG_Internet uses CRT/wsprintf-style
        // formatting, as the W variant's use of %ls already does.
        LOG_Internet(_T("IsServerFileDifferent: %hs does not exist."), szLocalFile);
        return TRUE;
    }
    else
    {
        BOOL fRet;
        fRet = IsServerFileDifferentWorker(ftServerTime, dwServerFileSize, hFile);
        CloseHandle(hFile);
        return fRet;
    }
}
// **************************************************************************
// Polls the caller's quit events without blocking.
//
// phEvents / nEventCount: the events to check; NULL or 0 means "none".
// Returns TRUE if it is OK to keep going, FALSE if the caller should quit now.
BOOL HandleEvents(HANDLE *phEvents, UINT nEventCount)
{
    LOG_Block("HandleEvents()");

    // Nothing to watch: always OK to continue.
    if (phEvents == NULL || nEventCount == 0)
        return TRUE;

    // A zero timeout only samples the signaled state. Any result other than
    // WAIT_TIMEOUT (a signaled object, an abandoned mutex, or WAIT_FAILED)
    // is treated as a request to quit.
    BOOL fContinue = (WaitForMultipleObjects(nEventCount, phEvents, FALSE, 0) == WAIT_TIMEOUT);
    if (!fContinue)
        LOG_Internet(_T("HandleEvents: A quit event was signaled. Aborting..."));

    return fContinue;
}
///////////////////////////////////////////////////////////////////////////////
//
// **************************************************************************
// Streams the response for hRequest into hFile, reading cbBuffer bytes at a
// time through pfnRead.
//
//  pfnRead           reader used for every chunk (per the comment below,
//                    presumably a WinHttpReadData-style wrapper)
//  cbFile            total size, forwarded to the progress callback
//  rghEvents/cEvents quit events polled after every chunk (HandleEvents)
//  fpnCallback       optional progress callback; a request value of 4 means
//                    "abort" (see comment at the check)
//  pcbDownloaded     optional; receives the byte count only on normal
//                    completion
//
// Returns S_OK on completion, E_ABORT if cancelled by the callback or a quit
// event, E_OUTOFMEMORY, or the HRESULT of a failed read/write.
HRESULT PerformDownloadToFile(pfn_ReadDataFromSite pfnRead,
                              HINTERNET hRequest,
                              HANDLE hFile, DWORD cbFile,
                              DWORD cbBuffer,
                              HANDLE *rghEvents, DWORD cEvents,
                              PFNDownloadCallback fpnCallback, LPVOID pCallbackData,
                              DWORD *pcbDownloaded)
{
    LOG_Block("PerformDownloadToFile()");
    HRESULT hr = S_OK;
    PBYTE pbBuffer = NULL;
    DWORD cbDownloaded = 0, cbRead = 0, cbWritten = 0;
    LONG lCallbackRequest = 0;

    pbBuffer = (PBYTE)HeapAlloc(GetProcessHeap(), 0, cbBuffer);
    if (pbBuffer == NULL)
    {
        hr = E_OUTOFMEMORY;
        LOG_ErrorMsg(hr);
        goto done;
    }

    // Download the File
    for (;;)
    {
        // Reset before every read. If the reader fails without setting an
        // error code, the chunk is treated as end-of-data instead of reusing
        // an uninitialized count or the previous chunk's count.
        cbRead = 0;
        if ((*pfnRead)(hRequest, pbBuffer, cbBuffer, &cbRead) == FALSE)
        {
            hr = HRESULT_FROM_WIN32(GetLastError());
            if (FAILED(hr))
            {
                LOG_ErrorMsg(hr);
                goto done;
            }
        }

        if (cbRead == 0)
        {
            BYTE bTemp[32];

            // Make one final call to WinHttpReadData to commit the file to
            // Cache. (the download is not complete otherwise)
            (*pfnRead)(hRequest, bTemp, ARRAYSIZE(bTemp), &cbRead);
            break;
        }

        cbDownloaded += cbRead;

        if (fpnCallback != NULL)
        {
            fpnCallback(pCallbackData, DOWNLOAD_STATUS_OK, cbFile, cbRead, NULL,
                        &lCallbackRequest);
            if (lCallbackRequest == 4)
            {
                // QuitEvent was Signaled.. abort requested. We will do
                // another callback and pass the Abort State back
                fpnCallback(pCallbackData, DOWNLOAD_STATUS_ABORTED, cbFile, cbRead, NULL, NULL);
                hr = E_ABORT; // set return result to abort.
                goto done;
            }
        }

        if (WriteFile(hFile, pbBuffer, cbRead, &cbWritten, NULL) == FALSE)
        {
            hr = HRESULT_FROM_WIN32(GetLastError());
            LOG_ErrorMsg(hr);
            goto done;
        }

        // A short write would leave a silently truncated file on disk.
        if (cbWritten != cbRead)
        {
            hr = HRESULT_FROM_WIN32(ERROR_WRITE_FAULT);
            LOG_ErrorMsg(hr);
            goto done;
        }

        if (HandleEvents(rghEvents, cEvents) == FALSE)
        {
            // we need to quit the download clean up, send abort event and clean up what we've downloaded
            if (fpnCallback != NULL)
                fpnCallback(pCallbackData, DOWNLOAD_STATUS_ABORTED, cbFile, cbRead, NULL, NULL);
            hr = E_ABORT; // set return result to abort.
            goto done;
        }
    }

    if (pcbDownloaded != NULL)
        *pcbDownloaded = cbDownloaded;

done:
    SafeHeapFree(pbBuffer);
    return hr;
}
///////////////////////////////////////////////////////////////////////////////
//
// Process-wide cache filled lazily by GetOSVersionInfo(): the OS version
// plus the locale that LangNeutralStrCmpNIA() passes to CompareStringA.
struct MY_OSVERSIONINFOEX
{
OSVERSIONINFOEX osvi;       // result of GetVersionEx
LCID lcidCompare;           // LOCALE_INVARIANT on OS 5.1+, en-US otherwise
};
static MY_OSVERSIONINFOEX g_myosvi;
// TRUE once g_myosvi has been populated.
// NOTE(review): not synchronized; concurrent first calls may both initialize.
static BOOL g_fInit = FALSE;
// **************************************************************************
// Loads the current OS version info on first use and returns a pointer to
// the cached copy. It also selects the locale used by LangNeutralStrCmpNIA():
// LOCALE_INVARIANT on version 5.1 (Windows XP) and later, en-US before that.
// If the extended structure is not supported, only the basic
// OSVERSIONINFO fields are filled in and the extended fields stay zero.
// NOTE(review): the lazy initialization is not synchronized.
const OSVERSIONINFOEX* GetOSVersionInfo(void)
{
    if (g_fInit == FALSE)
    {
        OSVERSIONINFOEX* pOSVI = &g_myosvi.osvi;

        ZeroMemory(pOSVI, sizeof(*pOSVI));
        pOSVI->dwOSVersionInfoSize = sizeof(*pOSVI);
        if (GetVersionEx((OSVERSIONINFO*)pOSVI) == FALSE)
        {
            // Older systems reject the OSVERSIONINFOEX size; retry with the
            // basic structure so major/minor/build are still reported
            // instead of an all-zero version.
            ZeroMemory(pOSVI, sizeof(*pOSVI));
            pOSVI->dwOSVersionInfoSize = sizeof(OSVERSIONINFO);
            GetVersionEx((OSVERSIONINFO*)pOSVI);
        }

        // LOCALE_INVARIANT is only used on 5.1 (WinXP) and later.
        if ((pOSVI->dwMajorVersion > 5) ||
            (pOSVI->dwMajorVersion == 5 && pOSVI->dwMinorVersion >= 1))
            g_myosvi.lcidCompare = LOCALE_INVARIANT;
        else
            g_myosvi.lcidCompare = MAKELCID(MAKELANGID(LANG_ENGLISH, SUBLANG_ENGLISH_US), SORT_DEFAULT);

        g_fInit = TRUE;
    }

    return &g_myosvi.osvi;
}
// **************************************************************************
// Case-insensitive comparison of two ANSI strings using the language-neutral
// locale chosen by GetOSVersionInfo(). Either length may be -1 if that
// string is NUL-terminated.
// Returns -1, 0 or 1 (less, equal, greater). If CompareStringA fails (it
// returns 0), the result is -2.
int LangNeutralStrCmpNIA(LPCSTR psz1, int cch1, LPCSTR psz2, int cch2)
{
    // Make sure g_myosvi.lcidCompare has been chosen.
    if (!g_fInit)
        GetOSVersionInfo();

    // CompareStringA yields CSTR_LESS_THAN/CSTR_EQUAL/CSTR_GREATER_THAN
    // (1/2/3); subtracting CSTR_EQUAL maps these to -1/0/1.
    return CompareStringA(g_myosvi.lcidCompare, NORM_IGNORECASE,
                          psz1, cch1, psz2, cch2) - CSTR_EQUAL;
}
// **************************************************************************
// Finds the first instance of pszSearchFor in pszSearchIn, case-insensitive.
// Returns an index into pszSearchIn if found, or -1 if not. An empty
// pszSearchFor matches at index 0.
// You can pass -1 (any negative value) for either or both of the lengths to
// mean "NUL-terminated".
int LangNeutralStrStrNIA(LPCSTR pszSearchIn, int cchSearchIn,
                         LPCSTR pszSearchFor, int cchSearchFor)
{
    char chLower, chUpper;

    if (cchSearchIn < 0)
        cchSearchIn = lstrlenA(pszSearchIn);
    if (cchSearchFor < 0)
        cchSearchFor = lstrlenA(pszSearchFor);

    // An empty pattern matches at the start. Without this, the
    // "cchSearchFor - 1" lengths below would be -1, which CompareStringA
    // treats as "NUL-terminated", and counted input could be over-read.
    if (cchSearchFor == 0)
        return 0;

    // Note: since this is lang-neutral, we can assume no DBCS search chars.
    // Pass the byte unsigned so a char >= 0x80 is not sign-extended inside
    // MAKEINTRESOURCEA.
    chLower = (char)(UINT_PTR)CharLowerA(MAKEINTRESOURCEA((BYTE)*pszSearchFor));
    chUpper = (char)(UINT_PTR)CharUpperA(MAKEINTRESOURCEA((BYTE)*pszSearchFor));

    // Note: since search-for is lang-neutral, we can ignore any DBCS chars
    // in search-in
    for (int ichIn = 0; ichIn <= cchSearchIn - cchSearchFor; ichIn++)
    {
        char ch = pszSearchIn[ichIn];
        if (ch != chLower && ch != chUpper)
            continue;

        // First character already matched; compare the remainder, if any.
        if (cchSearchFor == 1 ||
            LangNeutralStrCmpNIA(pszSearchIn + ichIn + 1, cchSearchFor - 1,
                                 pszSearchFor + 1, cchSearchFor - 1) == 0)
        {
            return ichIn;
        }
    }

    return -1;
}
// **************************************************************************
// Opens the given file and looks for "<html" (case-insensitive) within the
// first 200 bytes. If any control character other than tab, CR or LF
// appears before "<html", the file is assumed to *not* be HTML. Bytes >= 0x80
// (e.g. a UTF-8 byte-order mark or ANSI accented text) are not treated as
// binary.
// Returns S_OK if so, S_FALSE if not (including an empty file), or an error
// HRESULT if the file couldn't be opened or mapped.
HRESULT IsFileHtml(LPCTSTR pszFileName)
{
    LOG_Block("IsFileHtml()");
    const DWORD c_cbMaxScan = 200;  // only the start of the file is examined
    HRESULT hr = S_FALSE;
    LPCSTR pszFile;
    HANDLE hFile = INVALID_HANDLE_VALUE;
    HANDLE hMapping = NULL;
    LPVOID pvMem = NULL;
    DWORD cbFile;
    int ichHtml;    // declared up front so the gotos below don't skip its initialization

    hFile = CreateFile(pszFileName, GENERIC_READ,
                       FILE_SHARE_READ | FILE_SHARE_WRITE,
                       NULL, OPEN_EXISTING, FILE_FLAG_SEQUENTIAL_SCAN, NULL);
    if (hFile == INVALID_HANDLE_VALUE)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    cbFile = GetFileSize(hFile, NULL);
    if (cbFile == 0)
        goto done;

    if (cbFile > c_cbMaxScan)
        cbFile = c_cbMaxScan;

    hMapping = CreateFileMapping(hFile, NULL, PAGE_READONLY, 0, cbFile, NULL);
    if (hMapping == NULL)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    pvMem = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, cbFile);
    if (pvMem == NULL)
    {
        hr = HRESULT_FROM_WIN32(GetLastError());
        LOG_ErrorMsg(hr);
        goto done;
    }

    pszFile = (LPCSTR)pvMem;

    ichHtml = LangNeutralStrStrNIA(pszFile, (int)cbFile, "<html", 5);
    if (ichHtml != -1)
    {
        // Looks like html...
        hr = S_OK;

        // Just make sure there aren't any binary chars before the <HTML> tag
        for (int ich = 0; ich < ichHtml; ich++)
        {
            // Read the byte unsigned: with a signed char, bytes >= 0x80
            // would be negative and wrongly pass the "< 32" test.
            BYTE ch = (BYTE)pszFile[ich];
            if (ch < 32 && ch != '\t' && ch != '\r' && ch != '\n')
            {
                // Found a binary character (before <HTML>)
                hr = S_FALSE;
                break;
            }
        }
    }

done:
    if (pvMem != NULL)
        UnmapViewOfFile(pvMem);
    if (hMapping != NULL)
        CloseHandle(hMapping);
    if (hFile != INVALID_HANDLE_VALUE)
        CloseHandle(hFile);

    return hr;
}