|
|
//+-------------------------------------------------------------------------
//
// Microsoft Windows
//
// Copyright (C) Microsoft Corporation, 2000
//
// File: mmcprotocol.h
//
// Purpose: Creates a temporary pluggable internet protocol, mmc://
//
// History: 14-April-2000 Vivekj added
//--------------------------------------------------------------------------
#include<stdafx.h>
#include<mmcprotocol.h>
#include "tasks.h"
#include "typeinfo.h" // for COleCacheCleanupObserver
// Class id of the mmc protocol handler: {3C5F432A-EF40-4669-9974-9671D4FC2E12}
static const CLSID CLSID_MMCProtocol =
{
    0x3c5f432a, 0xef40, 0x4669,
    { 0x99, 0x74, 0x96, 0x71, 0xd4, 0xfc, 0x2e, 0x12 }
};

// schema name alone, and schema name followed by the ':' delimiter
static const WCHAR szMMC[]       = _W(MMC_PROTOCOL_SCHEMA_NAME);
static const WCHAR szMMCC[]      = _W(MMC_PROTOCOL_SCHEMA_NAME) _W(":");

// relative URL that identifies a pagebreak request
static const WCHAR szPageBreak[] = _W(MMC_PAGEBREAK_RELATIVE_URL);

// variable replaced by ExpandMMCVars
static const WCHAR szMMCRES[]    = L"%mmcres%";

// character placed in front of the taskpad data by ScGetTaskpadXML
static const WCHAR chUNICODE     = 0xfeff;

#ifdef DBG
CTraceTag tagProtocol(_T("MMC iNet Protocol"), _T("MMCProtocol"));
#endif //DBG
/***************************************************************************\
 *
 * FUNCTION:  HasSchema
 *
 * PURPOSE: helper: reports whether the URL carries a schema prefix
 *          (such as "something:" or "http:")
 *
 * PARAMETERS:
 *    LPCWSTR strURL    - URL to examine; NULL is accepted
 *
 * RETURNS:
 *    bool    - true when the first character that is neither white space
 *              nor alphanumeric is ':'; false otherwise (NULL included)
 *
\***************************************************************************/
inline bool HasSchema(LPCWSTR strURL)
{
    if (strURL == NULL)
        return false;

    // step over any run of white space and alphanumeric characters
    LPCWSTR pchCur = strURL;
    for ( ; iswspace(*pchCur) || iswalnum(*pchCur); ++pchCur )
    {
        // nothing to do - the loop header does the scanning
    }

    // a schema is present only if we stopped on the ':' delimiter
    return ( L':' == *pchCur );
}
/***************************************************************************\
 *
 * FUNCTION:  HasMMCSchema
 *
 * PURPOSE: helper: reports whether the URL begins with the mmc schema
 *          prefix (szMMCC), ignoring leading white space and letter case
 *
 * PARAMETERS:
 *    LPCWSTR strURL    - URL to examine; NULL is accepted
 *
 * RETURNS:
 *    bool    - true when the URL starts with the mmc schema prefix
 *
\***************************************************************************/
inline bool HasMMCSchema(LPCWSTR strURL)
{
    if (strURL == NULL)
        return false;

    // leading white space is not significant
    LPCWSTR pchCur = strURL;
    while ( iswspace(*pchCur) )
        ++pchCur;

    // case-insensitive comparison against the schema prefix
    const size_t cchPrefix = wcslen(szMMCC);
    return ( _wcsnicmp( pchCur, szMMCC, cchPrefix ) == 0 );
}
/***************************************************************************\
 *
 * CLASS:  CMMCProtocolRegistrar
 *
 * PURPOSE: registers / unregisters the mmc protocol.
 *          The class is also a COleCacheCleanupObserver: once it has been
 *          added as an observer (see ScRegister) it is told when cached
 *          OLE objects must be released and revokes the protocol then.
 *
\***************************************************************************/
class CMMCProtocolRegistrar : public COleCacheCleanupObserver
{
    bool             m_bRegistered;     // true while the namespace is registered
    IClassFactoryPtr m_spClassFactory;  // factory handed to the internet session

public:
    // starts out unregistered
    CMMCProtocolRegistrar() : m_bRegistered(false)
    {
    }

    // registration / unregistration
    SC ScRegister();
    SC ScUnregister();

    // cleanup event handler - simply unregisters the mmc protocol
    virtual SC ScOnReleaseCachedOleObjects()
    {
        DECLARE_SC(sc, TEXT("ScOnReleaseCachedOleObjects"));

        sc = ScUnregister();
        return sc;
    }
};
/***************************************************************************\
 *
 * METHOD:  CMMCProtocolRegistrar::ScRegister
 *
 * PURPOSE: registers the mmc protocol with the internet session, unless
 *          it has been registered already. On success the registrar also
 *          starts observing OLE cleanup requests so that the protocol can
 *          be unregistered in time.
 *
 * PARAMETERS:
 *
 * RETURNS:
 *    SC    - result code
 *
\***************************************************************************/
SC CMMCProtocolRegistrar::ScRegister()
{
    DECLARE_SC(sc, TEXT("CMMCProtocolRegistrar::ScRegister"));

    // one time registration only
    if (m_bRegistered)
        return sc;

    // get internet session
    IInternetSessionPtr spInternetSession;
    sc = CoInternetGetSession(0, &spInternetSession, 0);
    if (sc)
        return sc;

    // doublecheck
    sc = ScCheckPointers(spInternetSession, E_FAIL);
    if (sc)
        return sc;

    // ask CComModule for the class factory.
    // A local pointer is used on purpose: ScUnregister releases
    // m_spClassFactory only when m_bRegistered is set, so the member must
    // not end up holding a reference if the registration below fails.
    IClassFactoryPtr spClassFactory;
    sc = _Module.GetClassObject(CLSID_MMCProtocol, IID_IClassFactory, (void **)&spClassFactory);
    if (sc)
        return sc;

    // register the namespace
    sc = spInternetSession->RegisterNameSpace(spClassFactory, CLSID_MMCProtocol, szMMC, 0, NULL, 0);
    if (sc)
        return sc; // the local factory reference is dropped on return

    // registration succeeded - keep the factory for UnregisterNameSpace
    m_spClassFactory = spClassFactory;

    // start observing cleanup requests - to unregister in time
    COleCacheCleanupManager::AddOleObserver(this);

    m_bRegistered = true; // did it.

    return sc;
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocolRegistrar::ScUnregister
 *
 * PURPOSE: revokes the mmc protocol registration, if there is one.
 *          When the internet session cannot be obtained the namespace is
 *          not unregistered, but the class factory is still released and
 *          the registrar is marked as unregistered.
 *
 * PARAMETERS:
 *
 * RETURNS:
 *    SC    - result code
 *
\***************************************************************************/
SC CMMCProtocolRegistrar::ScUnregister()
{
    DECLARE_SC(sc, TEXT("CMMCProtocolRegistrar::ScUnregister"));

    // nothing was registered - nothing to undo
    if (!m_bRegistered)
        return sc;

    IInternetSessionPtr spSession;
    sc = CoInternetGetSession(0, &spSession, 0);
    if (sc)
    {
        // no session - no headache
        sc.Clear();
    }
    else
    {
        // we do have a session - make sure the pointer is usable
        sc = ScCheckPointers(spSession, E_UNEXPECTED);
        if (sc)
            return sc;

        // revoke the namespace; on failure the registrar state is kept
        sc = spSession->UnregisterNameSpace(m_spClassFactory, szMMC);
        if (sc)
            return sc;
    }

    // drop the factory and remember that we are not registered any more
    m_spClassFactory.Release();
    m_bRegistered = false;

    return sc;
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::ScRegisterProtocol
 *
 * PURPOSE: Registers the mmc protocol, so that "mmc:..." URLs are
 *          resolved by it. The work is delegated to a function-static
 *          CMMCProtocolRegistrar.
 *
 * PARAMETERS:
 *
 * RETURNS:
 *    SC    - result code
 *
\***************************************************************************/
SC CMMCProtocol::ScRegisterProtocol()
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::ScRegisterProtocol"));

    // the registrar unregisters on the cleanup event, so it has to
    // outlive this call - hence static
    static CMMCProtocolRegistrar s_registrar;

    // the registrar does the actual job
    sc = s_registrar.ScRegister();
    return sc;
}
//*****************************************************************************
// IInternetProtocolRoot interface
//*****************************************************************************
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::Start
 *
 * PURPOSE: Starts data download thru this protocol.
 *          The URL is either a pagebreak request (answered with a tiny
 *          HTML document) or a taskpad request (answered with the data
 *          produced by ScGetTaskpadXML). The data is stored in m_strData
 *          and handed out later by Read.
 *
 * PARAMETERS:
 *    LPCWSTR szUrl                       - URL to serve
 *    IInternetProtocolSink *pOIProtSink  - sink notified about mime type,
 *                                          data availability and result
 *    IInternetBindInfo *pOIBindInfo      - only checked for NULL here
 *    DWORD grfPI                         - only PI_PARSE_URL is examined
 *    HANDLE_PTR dwReserved               - not used
 *
 * RETURNS:
 *    HRESULT - result code. When PI_PARSE_URL is set no data is prepared:
 *              the result is a success code if the URL could be parsed
 *              and S_FALSE if it could not.
 *
\***************************************************************************/
STDMETHODIMP CMMCProtocol::Start(LPCWSTR szUrl, IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo, DWORD grfPI, HANDLE_PTR dwReserved)
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::Start"));

    // check inputs
    sc = ScCheckPointers(szUrl, pOIProtSink, pOIBindInfo);
    if (sc)
        return sc.ToHr();

    // reset position for reading
    m_uiReadOffs = 0;

    bool bPageBreakRequest = false;

    // see if it was a pagebreak requested
    sc = ScParsePageBreakURL( szUrl, bPageBreakRequest );
    if (sc)
        return sc.ToHr();

    if ( bPageBreakRequest )
    {
        // just report success (S_OK/S_FALSE) in case we were just parsing
        if ( grfPI & PI_PARSE_URL )
            return sc.ToHr();

        // construct a pagebreak
        m_strData = L"<HTML/>";

        sc = pOIProtSink->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, L"text/html");
        if (sc)
            sc.TraceAndClear(); // ignore and continue
    }
    else
    {
        // if not a pagebreak - then taskpad
        GUID guidTaskpad = GUID_NULL;
        sc = ScParseTaskpadURL( szUrl, guidTaskpad );

        // NOTE: the parse result must not be returned before the
        // PI_PARSE_URL test below - doing so would make the translation
        // of a parse error into S_FALSE unreachable.

        // report the S_FALSE instead of error in case we were just parsing
        if ( grfPI & PI_PARSE_URL )
            return ( sc.IsError() ? (sc = S_FALSE) : sc ).ToHr();

        // a real request - a parse failure is an error
        if (sc)
            return sc.ToHr();

        // load the contents
        sc = ScGetTaskpadXML( guidTaskpad, m_strData );
        if (sc)
            return sc.ToHr();

        sc = pOIProtSink->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, L"text/html");
        if (sc)
            sc.TraceAndClear(); // ignore and continue
    }

    // everything is available at once: report the full size in bytes
    const DWORD grfBSCF    = BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE;
    const DWORD dwDataSize = m_strData.length() * sizeof (WCHAR);
    sc = pOIProtSink->ReportData(grfBSCF, dwDataSize , dwDataSize);
    if (sc)
        sc.TraceAndClear(); // ignore and continue

    sc = pOIProtSink->ReportResult(0, 0, 0);
    if (sc)
        sc.TraceAndClear(); // ignore and continue

    return sc.ToHr();
}
// The remaining IInternetProtocolRoot methods need no real work in this
// protocol: they either report success or state they are not implemented.

// Continue - not implemented
STDMETHODIMP CMMCProtocol::Continue(PROTOCOLDATA *pProtocolData)
{
    return E_NOTIMPL;
}

// Abort - nothing to abort, reports success
STDMETHODIMP CMMCProtocol::Abort(HRESULT hrReason, DWORD dwOptions)
{
    return S_OK;
}

// Terminate - nothing to do, reports success
STDMETHODIMP CMMCProtocol::Terminate(DWORD dwOptions)
{
    return S_OK;
}

// LockRequest - nothing to do, reports success
STDMETHODIMP CMMCProtocol::LockRequest(DWORD dwOptions)
{
    return S_OK;
}

// UnlockRequest - nothing to do, reports success
STDMETHODIMP CMMCProtocol::UnlockRequest()
{
    return S_OK;
}

// Suspend - not implemented
STDMETHODIMP CMMCProtocol::Suspend()
{
    return E_NOTIMPL;
}

// Resume - not implemented
STDMETHODIMP CMMCProtocol::Resume()
{
    return E_NOTIMPL;
}
// Seek - not implemented; none of the arguments is examined
STDMETHODIMP CMMCProtocol::Seek(LARGE_INTEGER dlibMove, DWORD dwOrigin, ULARGE_INTEGER *plibNewPosition)
{
    return E_NOTIMPL;
}
//*****************************************************************************
// IInternetProtocol interface
//*****************************************************************************
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::Read
 *
 * PURPOSE: Reads data from the protocol: copies the next portion of
 *          m_strData (as raw bytes) into the caller's buffer and advances
 *          the read offset.
 *
 * PARAMETERS:
 *    void *pv          - [out] buffer receiving the data
 *    ULONG cb          - [in]  size of the buffer, in bytes
 *    ULONG *pcbRead    - [out] number of bytes actually copied
 *
 * RETURNS:
 *    HRESULT - S_FALSE when no data is left after this call (also when
 *              there was nothing to read at all), S_OK when more data
 *              remains, error code for NULL arguments
 *
\***************************************************************************/
STDMETHODIMP CMMCProtocol::Read(void *pv, ULONG cb, ULONG *pcbRead)
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::Read"));

    // parameter check;
    sc = ScCheckPointers(pv, pcbRead);
    if (sc)
        return sc.ToHr();

    // init out parameter;
    *pcbRead = 0;

    // total amount of data, in bytes
    const size_t cbTotal = m_strData.length() * sizeof(WCHAR);

    if ( cbTotal <= m_uiReadOffs )
        return (sc = S_FALSE).ToHr(); // no more data

    // calculate the size we'll return: whatever is left, capped by cb
    const size_t cbLeft = cbTotal - m_uiReadOffs;
    *pcbRead = ( cbLeft > cb ) ? cb : static_cast<ULONG>(cbLeft);

    if (*pcbRead)
    {
        // data() always yields a plain character pointer, whereas begin()
        // yields an iterator that is not guaranteed to be castable
        const BYTE *pbData = reinterpret_cast<const BYTE*>( m_strData.data() );
        memcpy( pv, pbData + m_uiReadOffs, *pcbRead );
    }

    m_uiReadOffs += *pcbRead;

    if ( cbTotal <= m_uiReadOffs )
        return (sc = S_FALSE).ToHr(); // no more data

    return sc.ToHr();
}
//*****************************************************************************
// IInternetProtocolInfo interface
//*****************************************************************************
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::ParseUrl
 *
 * PURPOSE: Handles PARSE_SECURITY_URL by reporting the directory appended
 *          by AppendMMCPath, followed by a backslash. Every other parse
 *          action is left to the default implementation.
 *
 * PARAMETERS:
 *    LPCWSTR pwzUrl            - not examined
 *    PARSEACTION ParseAction   - requested action
 *    DWORD dwParseFlags        - not examined
 *    LPWSTR pwzResult          - [out] buffer for the result
 *    DWORD cchResult           - [in]  size of the buffer, in characters
 *    DWORD *pcchResult         - [out] characters required, including the
 *                                      terminating zero
 *    DWORD dwReserved          - not used
 *
 * RETURNS:
 *    HRESULT - S_OK on success, S_FALSE when the buffer is too small
 *              (*pcchResult is still set), INET_E_DEFAULT_ACTION for
 *              actions other than PARSE_SECURITY_URL, error code otherwise
 *
\***************************************************************************/
STDMETHODIMP CMMCProtocol::ParseUrl( LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved)
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::ParseUrl"));

    // only the security url is ours to answer
    if (ParseAction != PARSE_SECURITY_URL)
        return INET_E_DEFAULT_ACTION;

    // the required size is always reported - the pointer must be valid
    sc = ScCheckPointers(pcchResult);
    if (sc)
        return sc.ToHr();

    // get the directory reported by AppendMMCPath, with a trailing
    // backslash (like "c:\winnt\system32\")
    std::wstring windir;
    AppendMMCPath(windir);
    windir += L'\\';

    // we are as secure as that directory is - report it as the url
    *pcchResult = static_cast<DWORD>( windir.length() + 1 );

    // check if we have enough place for the result and terminating zero
    if ( cchResult <= windir.length() )
        return S_FALSE; // not enough

    // the buffer is only needed (and only checked) when we actually copy,
    // so that a pure size query keeps working
    sc = ScCheckPointers(pwzResult);
    if (sc)
        return sc.ToHr();

    sc = StringCchCopyW(pwzResult, cchResult, windir.c_str());
    if (sc)
        return sc.ToHr();

    return (sc = S_OK).ToHr();
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::CombineUrl
 *
 * PURPOSE: combines base + relative url into the resulting url.
 *          For an mmc base url the relative url first gets its local
 *          variables substituted (ExpandMMCVars) and, if it still has no
 *          schema, is prefixed with the base url. The actual combining is
 *          then delegated to CoInternetCombineUrl.
 *
 * PARAMETERS:
 *    LPCWSTR pwzBaseUrl        - base url
 *    LPCWSTR pwzRelativeUrl    - relative url
 *    DWORD dwCombineFlags      - passed on to CoInternetCombineUrl
 *    LPWSTR pwzResult          - [out] passed on to CoInternetCombineUrl
 *    DWORD cchResult           - passed on to CoInternetCombineUrl
 *    DWORD *pcchResult         - [out] passed on to CoInternetCombineUrl
 *    DWORD dwReserved          - passed on to CoInternetCombineUrl
 *
 * RETURNS:
 *    HRESULT - result code
 *
\***************************************************************************/
STDMETHODIMP CMMCProtocol::CombineUrl(LPCWSTR pwzBaseUrl, LPCWSTR pwzRelativeUrl, DWORD dwCombineFlags, LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved)
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::CombineUrl"));

#ifdef DBG
    USES_CONVERSION;
    Trace(tagProtocol, _T("CombineUrl: [%s] + [%s]"), W2CT(pwzBaseUrl), W2CT(pwzRelativeUrl));
#endif //DBG

    // has to live until CoInternetCombineUrl returns: pwzRelativeUrl may
    // be redirected to its buffer below
    std::wstring strExpanded;

    if (HasMMCSchema(pwzBaseUrl))
    {
        // our stuff - substitute the variables first
        strExpanded = pwzRelativeUrl;
        ExpandMMCVars(strExpanded);

        // still schema-less? then it is relative to the base url
        if ( ! HasSchema( strExpanded.c_str() ) )
            strExpanded.insert( 0, pwzBaseUrl );

        // this is the 'new' relative address
        pwzRelativeUrl = strExpanded.c_str();

        // pretend we are referred from http - let it do the dirty job ;)
        pwzBaseUrl = L"http://";
    }

    // since we stripped ourselves out of pwzBaseUrl, the call will not
    // recurse back here but will do the original html stuff
    sc = CoInternetCombineUrl( pwzBaseUrl, pwzRelativeUrl, dwCombineFlags, pwzResult, cchResult, pcchResult, dwReserved );
    if (sc)
        return sc.ToHr();

    Trace(tagProtocol, _T("CombineUrl: == [%s]"), W2CT(pwzResult));

    return sc.ToHr();
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::CompareUrl
 *
 * PURPOSE: url comparison is not customized by this protocol; the
 *          request is always passed on to the default action
 *
 * PARAMETERS:
 *    LPCWSTR pwzUrl1           - not examined
 *    LPCWSTR pwzUrl2           - not examined
 *    DWORD dwCompareFlags      - not examined
 *
 * RETURNS:
 *    HRESULT - always INET_E_DEFAULT_ACTION
 *
\***************************************************************************/
STDMETHODIMP CMMCProtocol::CompareUrl(LPCWSTR pwzUrl1, LPCWSTR pwzUrl2,DWORD dwCompareFlags)
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::CompareUrl"));

    return INET_E_DEFAULT_ACTION;
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::QueryInfo
 *
 * PURPOSE: Queries info about URL. Two options are answered here, each
 *          with a 4-byte value: QUERY_USES_NETWORK (FALSE) and
 *          QUERY_IS_SAFE (TRUE). Everything else - including a buffer
 *          smaller than 4 bytes - goes to the default action.
 *
 * PARAMETERS:
 *    LPCWSTR pwzUrl            - not examined
 *    QUERYOPTION QueryOption   - what is being asked
 *    DWORD dwQueryFlags        - not examined
 *    LPVOID pBuffer            - [out] receives the answer
 *    DWORD cbBuffer            - [in]  size of the buffer, in bytes
 *    DWORD *pcbBuf             - [out] number of bytes written
 *    DWORD dwReserved          - not used
 *
 * RETURNS:
 *    HRESULT - S_OK when answered, INET_E_DEFAULT_ACTION otherwise
 *
\***************************************************************************/
STDMETHODIMP CMMCProtocol::QueryInfo( LPCWSTR pwzUrl, QUERYOPTION QueryOption,DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer, DWORD *pcbBuf, DWORD dwReserved)
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::QueryInfo"));

    switch (QueryOption)
    {
    case QUERY_USES_NETWORK:
        if (cbBuffer >= 4)
        {
            // does not use the network
            *(LPDWORD)pBuffer = FALSE;
            *pcbBuf = 4;
            return S_OK;
        }
        break;

    case QUERY_IS_SAFE:
        if (cbBuffer >= 4)
        {
            // only serves trusted content
            *(LPDWORD)pBuffer = TRUE;
            *pcbBuf = 4;
            return S_OK;
        }
        break;

    default:
        break;
    }

    return INET_E_DEFAULT_ACTION;
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::ScParseTaskpadURL
 *
 * PURPOSE: Extracts the taskpad guid from the URL given to the protocol.
 *          The URL is expected to look like "mmc:{guid}".
 *
 * PARAMETERS:
 *    LPCWSTR strURL    [in]  - URL
 *    GUID& guid        [out] - extracted guid; GUID_NULL is stored first
 *
 * RETURNS:
 *    SC    - result code; E_FAIL when the URL lacks the mmc prefix
 *
\***************************************************************************/
SC CMMCProtocol::ScParseTaskpadURL( LPCWSTR strURL, GUID& guid )
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::ScParseTaskpadURL"));

    guid = GUID_NULL;

    sc = ScCheckPointers(strURL);
    if (sc)
        return sc;

    // the url has to start with "mmc:" (letter case does not matter)
    const size_t cchPrefix = wcslen(szMMCC);
    if ( _wcsnicmp( strURL, szMMCC, cchPrefix ) != 0 )
        return sc = E_FAIL;

    // what follows the prefix must be the guid in its string form
    LPCWSTR strGuid = strURL + cchPrefix;
    sc = CLSIDFromString( const_cast<LPWSTR>(strGuid), &guid );

    // success or failure - the caller gets the conversion result
    return sc;
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::ScParsePageBreakURL
 *
 * PURPOSE: Checks whether the URL given to the protocol is a request for
 *          a pagebreak, i.e. the mmc prefix followed by szPageBreak.
 *
 * PARAMETERS:
 *    LPCWSTR strURL    [in]  - URL
 *    bool& bPageBreak  [out] - true if it is a request for a pagebreak
 *
 * RETURNS:
 *    SC    - result code; a URL that is not a pagebreak is not an error
 *
\***************************************************************************/
SC CMMCProtocol::ScParsePageBreakURL( LPCWSTR strURL, bool& bPageBreak )
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::ScParsePageBreakURL"));

    bPageBreak = false;

    sc = ScCheckPointers(strURL);
    if (sc)
        return sc;

    // without the "mmc:" prefix (any letter case) it is not a pagebreak;
    // that is not an error - bPageBreak already holds the answer
    const size_t cchPrefix = wcslen(szMMCC);
    if ( _wcsnicmp( strURL, szMMCC, cchPrefix ) != 0 )
        return sc;

    // the part after the prefix must begin with the pagebreak url
    // (this comparison is case sensitive)
    LPCWSTR strRelative = strURL + cchPrefix;
    bPageBreak = ( wcsncmp( strRelative, szPageBreak, wcslen(szPageBreak) ) == 0 );

    return sc;
}
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::ScGetTaskpadXML
 *
 * PURPOSE: looks the taskpad up by guid in the scope tree's console
 *          taskpad list and stores its string form (as produced by
 *          CConsoleTaskpad::ScGetHTML), prefixed with chUNICODE, in the
 *          result string
 *
 * PARAMETERS:
 *    const GUID& guid                  [in]  - taskpad guid
 *    std::wstring& strResultData       [out] - taskpad string; emptied first
 *
 * RETURNS:
 *    SC    - result code; E_FAIL when no taskpad has the given guid
 *
\***************************************************************************/
SC CMMCProtocol::ScGetTaskpadXML( const GUID& guid, std::wstring& strResultData )
{
    DECLARE_SC(sc, TEXT("CMMCProtocol::ScGetTaskpadXML"));

    strResultData.erase();

    // the taskpads hang off the scope tree
    CScopeTree* pScopeTree = CScopeTree::GetScopeTree();
    sc = ScCheckPointers(pScopeTree, E_FAIL);
    if (sc)
        return sc.ToHr();

    CConsoleTaskpadList * pTaskpads = pScopeTree->GetConsoleTaskpadList();
    sc = ScCheckPointers(pTaskpads, E_FAIL);
    if (sc)
        return sc.ToHr();

    for (CConsoleTaskpadList::iterator it = pTaskpads->begin(); it != pTaskpads->end(); ++it)
    {
        CConsoleTaskpad &taskpad = *it;

        if ( IsEqualGUID( guid, taskpad.GetID() ) )
        {
            // found it - create a string version of the taskpad
            CStr strTaskpadHTML;
            sc = taskpad.ScGetHTML(strTaskpadHTML);
            if (sc)
                return sc.ToHr();

            // form the result string
            USES_CONVERSION;
            strResultData  = chUNICODE;
            strResultData += T2CW(strTaskpadHTML);

            return sc;
        }
    }

    // not found
    return sc = E_FAIL;
}
/***************************************************************************\
* * METHOD: CMMCProtocol::AppendMMCPath * * PURPOSE: helper. Appends the mmcndmgr.dll dir (no file name) to the string * It may append something like: "c:\winnt\system32" * * PARAMETERS: * std::wstring& str [in/out] - string to edit * * RETURNS: * SC - result code * \***************************************************************************/ void CMMCProtocol::AppendMMCPath(std::wstring& str) { TCHAR szModule[_MAX_PATH+10] = { 0 }; DWORD cchSize = countof(szModule); DWORD dwRet = GetModuleFileName(_Module.GetModuleInstance(), szModule, cchSize); if(0==dwRet) return;
// NTRAID#NTBUG9-613782-2002/05/02-ronmart-prefast warning 53: Call to 'GetModuleFileNameW' may not zero-terminate string
szModule[cchSize - 1] = 0;
USES_CONVERSION; LPCWSTR strModule = T2CW(szModule);
LPCWSTR dirEnd = wcsrchr( strModule, L'\\' ); if (dirEnd != NULL) str.append(strModule, dirEnd); }
/***************************************************************************\
 *
 * METHOD:  CMMCProtocol::ExpandMMCVars
 *
 * PURPOSE: helper. Expands every %mmcres% contained in the string to
 *          "res://" followed by the full file name of this module,
 *          something like "res://c:\winnt\system32\mmcndmgr.dll"
 *
 * PARAMETERS:
 *    std::wstring& str [in/out] - string to edit; left untouched when the
 *                                 module file name cannot be obtained
 *
 * RETURNS:
 *    nothing
 *
\***************************************************************************/
void CMMCProtocol::ExpandMMCVars(std::wstring& str)
{
    TCHAR szModule[_MAX_PATH+10] = { 0 };
    DWORD cchSize = countof(szModule);
    DWORD dwRet = GetModuleFileName(_Module.GetModuleInstance(), szModule, cchSize);
    if (0 == dwRet)
        return;

    // NTRAID#NTBUG9-613782-2002/05/02-ronmart-prefast warning 53: Call to 'GetModuleFileNameW' may not zero-terminate string
    szModule[cchSize - 1] = 0;

    USES_CONVERSION;
    LPCWSTR strModule = T2CW(szModule);

    // first - build the replacement
    std::wstring mmcres = L"res://";
    mmcres += strModule;

    // second - replace the instances.
    // The position is kept as size_type (an int cannot hold every value
    // find() returns) and the search resumes right after the text just
    // inserted: the replacement is never rescanned, so the loop ends
    // even if the module path itself happens to contain "%mmcres%".
    const std::wstring::size_type cchVar = wcslen(szMMCRES);
    std::wstring::size_type pos = 0;
    while ( std::wstring::npos != ( pos = str.find(szMMCRES, pos) ) )
    {
        // make one substitution
        str.replace( pos, cchVar, mmcres );
        pos += mmcres.length();
    }
}
|