//+------------------------------------------------------------------------- // // Microsoft Windows // // Copyright (C) Microsoft Corporation, 2000 // // File: mmcprotocol.h // // Purpose: Creates a temporary pluggable internet protocol, mmc:// // // History: 14-April-2000 Vivekj added //-------------------------------------------------------------------------- #include #include #include "tasks.h" #include "typeinfo.h" // for COleCacheCleanupObserver // {3C5F432A-EF40-4669-9974-9671D4FC2E12} static const CLSID CLSID_MMCProtocol = { 0x3c5f432a, 0xef40, 0x4669, { 0x99, 0x74, 0x96, 0x71, 0xd4, 0xfc, 0x2e, 0x12 } }; static const WCHAR szMMC[] = _W(MMC_PROTOCOL_SCHEMA_NAME); static const WCHAR szMMCC[] = _W(MMC_PROTOCOL_SCHEMA_NAME) _W(":"); static const WCHAR szPageBreak[] = _W(MMC_PAGEBREAK_RELATIVE_URL); static const WCHAR szMMCRES[] = L"%mmcres%"; static const WCHAR chUNICODE = 0xfeff; #ifdef DBG CTraceTag tagProtocol(_T("MMC iNet Protocol"), _T("MMCProtocol")); #endif //DBG /***************************************************************************\ * * FUNCTION: HasSchema * * PURPOSE: helper: determines if URL contains schema (like "something:" or "http:" ) * * PARAMETERS: * LPCWSTR strURL * * RETURNS: * bool ; true == does contain schema * \***************************************************************************/ inline bool HasSchema(LPCWSTR strURL) { if (strURL == NULL) return false; // skip spaces and schema name while ( iswspace(*strURL) || iswalnum(*strURL) ) strURL++; // valid schema ends with ':' return *strURL == L':'; } /***************************************************************************\ * * FUNCTION: HasMMCSchema * * PURPOSE: helper: determines if URL contains mmc schema ( begins with "mmc:" ) * * PARAMETERS: * LPCWSTR strURL * * RETURNS: * bool ; true == does contain mmc schema * \***************************************************************************/ inline bool HasMMCSchema(LPCWSTR strURL) { if (strURL == NULL) return false; // skip spaces while ( iswspace(*strURL) ) strURL++; return (0 == _wcsnicmp( strURL, szMMCC, wcslen(szMMCC) ) ); } /***************************************************************************\ * * CLASS: CMMCProtocolRegistrar * * PURPOSE: register/ unregisters mmc protocol. * Also class provides cleanup functionality. Because it registers as * COleCacheCleanupObserver, it will receive the event when MMC * is about to uninitialize OLE, and will revoke registered mmc protocol * \***************************************************************************/ class CMMCProtocolRegistrar : public COleCacheCleanupObserver { bool m_bRegistered; IClassFactoryPtr m_spClassFactory; public: // c-tor. CMMCProtocolRegistrar() : m_bRegistered(false) {} // registration / unregistration SC ScRegister(); SC ScUnregister(); // event sensor - unregisters mmc protocol virtual SC ScOnReleaseCachedOleObjects() { DECLARE_SC(sc, TEXT("ScOnReleaseCachedOleObjects")); return sc = ScUnregister(); } }; /***************************************************************************\ * * METHOD: CMMCProtocolRegistrar::ScRegister * * PURPOSE: registers the protocol if required * * PARAMETERS: * * RETURNS: * SC - result code * \***************************************************************************/ SC CMMCProtocolRegistrar::ScRegister() { DECLARE_SC(sc, TEXT("CMMCProtocolRegistrar::ScRegister")); // one time registration only if(m_bRegistered) return sc; // get internet session IInternetSessionPtr spInternetSession; sc = CoInternetGetSession(0, &spInternetSession, 0); if(sc) return sc; // doublecheck sc = ScCheckPointers(spInternetSession, E_FAIL); if(sc) return sc; // ask CComModule for the class factory sc = _Module.GetClassObject(CLSID_MMCProtocol, IID_IClassFactory, (void **)&m_spClassFactory); if(sc) return sc; // register the namespace sc = spInternetSession->RegisterNameSpace(m_spClassFactory, CLSID_MMCProtocol, szMMC, 0, NULL, 0); if(sc) return sc; // start observing cleanup requests - to unregister in time COleCacheCleanupManager::AddOleObserver(this); m_bRegistered = true; // did it. return sc; } /***************************************************************************\ * * METHOD: CMMCProtocolRegistrar::ScUnregister * * PURPOSE: unregisters the protocol if one was registered * * PARAMETERS: * * RETURNS: * SC - result code * \***************************************************************************/ SC CMMCProtocolRegistrar::ScUnregister() { DECLARE_SC(sc, TEXT("CMMCProtocolRegistrar::ScUnregister")); if (!m_bRegistered) return sc; // unregister IInternetSessionPtr spInternetSession; sc = CoInternetGetSession(0, &spInternetSession, 0); if(sc) { sc.Clear(); // no session - no headache } else // need to unregister { // recheck sc = ScCheckPointers(spInternetSession, E_UNEXPECTED); if(sc) return sc; // unregister the namespace sc = spInternetSession->UnregisterNameSpace(m_spClassFactory, szMMC); if(sc) return sc; } m_spClassFactory.Release(); m_bRegistered = false; return sc; } /***************************************************************************\ * * METHOD: CMMCProtocol::ScRegisterProtocol * * PURPOSE: Registers mmc protocol. IE will resove "mmc:..." ULRs to it * * PARAMETERS: * * RETURNS: * SC - result code * \***************************************************************************/ SC CMMCProtocol::ScRegisterProtocol() { DECLARE_SC(sc, TEXT("CMMCProtocol::ScRegisterProtocol")); // registrar (unregisters on cleanup event) - needs to be static static CMMCProtocolRegistrar registrar; // let the registrar do the job return sc = registrar.ScRegister(); } //***************************************************************************** // IInternetProtocolRoot interface //***************************************************************************** /***************************************************************************\ * * METHOD: CMMCProtocol::Start * * PURPOSE: Starts data download thru this protocol * * PARAMETERS: * LPCWSTR szUrl * IInternetProtocolSink *pOIProtSink * IInternetBindInfo *pOIBindInfo * DWORD grfPI * HANDLE_PTR dwReserved * * RETURNS: * SC - result code * \***************************************************************************/ STDMETHODIMP CMMCProtocol::Start(LPCWSTR szUrl, IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo, DWORD grfPI, HANDLE_PTR dwReserved) { DECLARE_SC(sc, TEXT("CMMCProtocol::Start")); // check inputs sc = ScCheckPointers(szUrl, pOIProtSink, pOIBindInfo); if(sc) return sc.ToHr(); // reset position for reading m_uiReadOffs = 0; bool bPageBreakRequest = false; // see if it was a pagebreak requested sc = ScParsePageBreakURL( szUrl, bPageBreakRequest ); if(sc) return sc.ToHr(); if ( bPageBreakRequest ) { // just report success (S_OK/S_FALSE) in case we were just parsing if ( grfPI & PI_PARSE_URL ) return sc.ToHr(); // construct a pagebreak m_strData = L""; sc = pOIProtSink->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, L"text/html"); if (sc) sc.TraceAndClear(); // ignore and continue } else { //if not a pagebreak - then taskpad GUID guidTaskpad = GUID_NULL; sc = ScParseTaskpadURL( szUrl, guidTaskpad ); if(sc) return sc.ToHr(); // report the S_FALSE instead of error in case we were just parsing if ( grfPI & PI_PARSE_URL ) return ( sc.IsError() ? (sc = S_FALSE) : sc ).ToHr(); if (sc) return sc.ToHr(); // load the contents sc = ScGetTaskpadXML( guidTaskpad, m_strData ); if (sc) return sc.ToHr(); sc = pOIProtSink->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, L"text/html"); if (sc) sc.TraceAndClear(); // ignore and continue } const DWORD grfBSCF = BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE; const DWORD dwDataSize = m_strData.length() * sizeof (WCHAR); sc = pOIProtSink->ReportData(grfBSCF, dwDataSize , dwDataSize); if (sc) sc.TraceAndClear(); // ignore and continue sc = pOIProtSink->ReportResult(0, 0, 0); if (sc) sc.TraceAndClear(); // ignore and continue return sc.ToHr(); } STDMETHODIMP CMMCProtocol::Continue(PROTOCOLDATA *pProtocolData) { return E_NOTIMPL; } STDMETHODIMP CMMCProtocol::Abort(HRESULT hrReason, DWORD dwOptions) { return S_OK; } STDMETHODIMP CMMCProtocol::Terminate(DWORD dwOptions) { return S_OK; } STDMETHODIMP CMMCProtocol::LockRequest(DWORD dwOptions) { return S_OK; } STDMETHODIMP CMMCProtocol::UnlockRequest() { return S_OK; } STDMETHODIMP CMMCProtocol::Suspend() { return E_NOTIMPL; } STDMETHODIMP CMMCProtocol::Resume() { return E_NOTIMPL; } STDMETHODIMP CMMCProtocol::Seek(LARGE_INTEGER dlibMove, DWORD dwOrigin, ULARGE_INTEGER *plibNewPosition) { return E_NOTIMPL; } //***************************************************************************** // IInternetProtocol interface //***************************************************************************** /***************************************************************************\ * * METHOD: CMMCProtocol::Read * * PURPOSE: Reads data from the protocol * * PARAMETERS: * void *pv * ULONG cb * ULONG *pcbRead * * RETURNS: * SC - result code * \***************************************************************************/ STDMETHODIMP CMMCProtocol::Read(void *pv, ULONG cb, ULONG *pcbRead) { DECLARE_SC(sc, TEXT("CMMCProtocol::Read")); // parameter check; sc = ScCheckPointers(pv, pcbRead); if(sc) return sc.ToHr(); // init out parameter; *pcbRead = 0; size_t size = ( m_strData.length() ) * sizeof(WCHAR); if ( size <= m_uiReadOffs ) return (sc = S_FALSE).ToHr(); // no more data // calculate the size we'll return *pcbRead = size - m_uiReadOffs; if (size - m_uiReadOffs > cb) *pcbRead = cb; if (*pcbRead) memcpy( pv, reinterpret_cast( m_strData.begin() ) + m_uiReadOffs, *pcbRead ); m_uiReadOffs += *pcbRead; if ( size <= m_uiReadOffs ) return (sc = S_FALSE).ToHr(); // no more data return sc.ToHr(); } //***************************************************************************** // IInternetProtocolInfo interface //***************************************************************************** STDMETHODIMP CMMCProtocol::ParseUrl( LPCWSTR pwzUrl, PARSEACTION ParseAction, DWORD dwParseFlags, LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved) { DECLARE_SC(sc, TEXT("CMMCProtocol::ParseUrl")); if (ParseAction == PARSE_SECURITY_URL) { // get system directory (like "c:\winnt\system32\") std::wstring windir; AppendMMCPath(windir); windir += L'\\'; // we are as secure as windir is - report the url (like "c:\winnt\system32\") *pcchResult = windir.length() + 1; // check if we have enough place for the result and terminating zero if ( cchResult <= windir.length() ) return S_FALSE; // not enough sc = StringCchCopyW(pwzResult,cchResult, windir.c_str()); if(sc) return sc.ToHr(); return (sc = S_OK).ToHr(); } return INET_E_DEFAULT_ACTION; } /***************************************************************************\ * * METHOD: CMMCProtocol::CombineUrl * * PURPOSE: combines base + relative url to resulting url * we do local variable substitution here * * PARAMETERS: * LPCWSTR pwzBaseUrl * LPCWSTR pwzRelativeUrl * DWORD dwCombineFlags * LPWSTR pwzResult * DWORD cchResult * DWORD *pcchResult * DWORD dwReserved * * RETURNS: * SC - result code * \***************************************************************************/ STDMETHODIMP CMMCProtocol::CombineUrl(LPCWSTR pwzBaseUrl, LPCWSTR pwzRelativeUrl, DWORD dwCombineFlags, LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved) { DECLARE_SC(sc, TEXT("CMMCProtocol::CombineUrl")); #ifdef DBG USES_CONVERSION; Trace(tagProtocol, _T("CombineUrl: [%s] + [%s]"), W2CT(pwzBaseUrl), W2CT(pwzRelativeUrl)); #endif //DBG std::wstring temp1; if (HasMMCSchema(pwzBaseUrl)) { // our stuff temp1 = pwzRelativeUrl; ExpandMMCVars(temp1); if ( ! HasSchema( temp1.c_str() ) ) { // combine everything into relative URL temp1.insert( 0, pwzBaseUrl ); } // form 'new' relative address pwzRelativeUrl = temp1.c_str(); // say we are refered from http - let it do the dirty job ;) pwzBaseUrl = L"http://"; } // since we stripped out ourselfs from pwzBaseUrl - it will not recurse back, // but will do original html stuff sc = CoInternetCombineUrl( pwzBaseUrl, pwzRelativeUrl, dwCombineFlags, pwzResult, cchResult, pcchResult, dwReserved ); if (sc) return sc.ToHr(); Trace(tagProtocol, _T("CombineUrl: == [%s]"), W2CT(pwzResult)); return sc.ToHr(); } /***************************************************************************\ * * METHOD: CMMCProtocol::CompareUrl * * PURPOSE: compares URLs if they are the same * * PARAMETERS: * LPCWSTR pwzUrl1 * LPCWSTR pwzUrl2 * DWORD dwCompareFlags * * RETURNS: * SC - result code * \***************************************************************************/ STDMETHODIMP CMMCProtocol::CompareUrl(LPCWSTR pwzUrl1, LPCWSTR pwzUrl2,DWORD dwCompareFlags) { DECLARE_SC(sc, TEXT("CMMCProtocol::CompareUrl")); return INET_E_DEFAULT_ACTION; } /***************************************************************************\ * * METHOD: CMMCProtocol::QueryInfo * * PURPOSE: Queries info about URL * * PARAMETERS: * LPCWSTR pwzUrl * QUERYOPTION QueryOption * DWORD dwQueryFlags * LPVOID pBuffer * DWORD cbBuffer * DWORD *pcbBuf * DWORD dwReserved * * RETURNS: * SC - result code * \***************************************************************************/ STDMETHODIMP CMMCProtocol::QueryInfo( LPCWSTR pwzUrl, QUERYOPTION QueryOption,DWORD dwQueryFlags, LPVOID pBuffer, DWORD cbBuffer, DWORD *pcbBuf, DWORD dwReserved) { DECLARE_SC(sc, TEXT("CMMCProtocol::QueryInfo")); if (QueryOption == QUERY_USES_NETWORK) { if (cbBuffer >= 4) { *(LPDWORD)pBuffer = FALSE; // does not use the network *pcbBuf = 4; return S_OK; } } else if (QueryOption == QUERY_IS_SAFE) { if (cbBuffer >= 4) { *(LPDWORD)pBuffer = TRUE; // only serves trusted content *pcbBuf = 4; return S_OK; } } return INET_E_DEFAULT_ACTION; } /***************************************************************************\ * * METHOD: CMMCProtocol::ScParseTaskpadURL * * PURPOSE: Extracts taskpad guid from URL given to the protocol * * PARAMETERS: * LPCWSTR strURL [in] - URL * GUID& guid [out] - extracted guid * * RETURNS: * SC - result code * \***************************************************************************/ SC CMMCProtocol::ScParseTaskpadURL( LPCWSTR strURL, GUID& guid ) { DECLARE_SC(sc, TEXT("CMMCProtocol::ScParseTaskpadURL")); guid = GUID_NULL; sc = ScCheckPointers(strURL); if (sc) return sc; // taskpad url should be in form "mmc:{guid}" // check for "mmc:" if ( 0 != _wcsnicmp( strURL, szMMCC, wcslen(szMMCC) ) ) return sc = E_FAIL; // skip "mmc:" strURL += wcslen(szMMCC); // get the url sc = CLSIDFromString( const_cast(strURL), &guid ); if (sc) return sc; return sc; } /***************************************************************************\ * * METHOD: CMMCProtocol::ScParsePageBreakURL * * PURPOSE: Checks if URL given to the protocol is a request for a pagebreak * * PARAMETERS: * LPCWSTR strURL [in] - URL * bool& bPageBreak [out] - true it it is a request for pagebreak * * RETURNS: * SC - result code * \***************************************************************************/ SC CMMCProtocol::ScParsePageBreakURL( LPCWSTR strURL, bool& bPageBreak ) { DECLARE_SC(sc, TEXT("CMMCProtocol::ScParsePageBreakURL")); bPageBreak = false; sc = ScCheckPointers(strURL); if (sc) return sc; // pagebreak url should be in form "mmc:pagebreak." // check for "mmc:" if ( 0 != _wcsnicmp( strURL, szMMCC, wcslen(szMMCC) ) ) return sc; // not an error - return value updated // skip "mmc:" strURL += wcslen(szMMCC); // get the url bPageBreak = ( 0 == wcsncmp( strURL, szPageBreak, wcslen(szPageBreak) ) ); return sc; } /***************************************************************************\ * * METHOD: CMMCProtocol::ScGetTaskpadXML * * PURPOSE: given the guid uploads taskpad XML string to the string * * PARAMETERS: * const GUID& guid [in] - taskpad guid * std::wstring& strResultData [out] - taskpad xml string * * RETURNS: * SC - result code * \***************************************************************************/ SC CMMCProtocol::ScGetTaskpadXML( const GUID& guid, std::wstring& strResultData ) { DECLARE_SC(sc, TEXT("CMMCProtocol::ScGetTaskpadXML")); strResultData.erase(); CScopeTree* pScopeTree = CScopeTree::GetScopeTree(); sc = ScCheckPointers(pScopeTree, E_FAIL); if(sc) return sc.ToHr(); CConsoleTaskpadList * pConsoleTaskpadList = pScopeTree->GetConsoleTaskpadList(); sc = ScCheckPointers(pConsoleTaskpadList, E_FAIL); if(sc) return sc.ToHr(); for(CConsoleTaskpadList::iterator iter = pConsoleTaskpadList->begin(); iter!= pConsoleTaskpadList->end(); ++iter) { CConsoleTaskpad &consoleTaskpad = *iter; // check if this is the one we are looking for if ( !IsEqualGUID( guid, consoleTaskpad.GetID() ) ) continue; // convert the taskpad to a string CStr strTaskpadHTML; sc = consoleTaskpad.ScGetHTML(strTaskpadHTML); // create a string version of the taskpad if(sc) return sc.ToHr(); // form the result string USES_CONVERSION; strResultData = chUNICODE; strResultData += T2CW(strTaskpadHTML); return sc; } // not found return sc = E_FAIL; } /***************************************************************************\ * * METHOD: CMMCProtocol::AppendMMCPath * * PURPOSE: helper. Appends the mmcndmgr.dll dir (no file name) to the string * It may append something like: "c:\winnt\system32" * * PARAMETERS: * std::wstring& str [in/out] - string to edit * * RETURNS: * SC - result code * \***************************************************************************/ void CMMCProtocol::AppendMMCPath(std::wstring& str) { TCHAR szModule[_MAX_PATH+10] = { 0 }; DWORD cchSize = countof(szModule); DWORD dwRet = GetModuleFileName(_Module.GetModuleInstance(), szModule, cchSize); if(0==dwRet) return; // NTRAID#NTBUG9-613782-2002/05/02-ronmart-prefast warning 53: Call to 'GetModuleFileNameW' may not zero-terminate string szModule[cchSize - 1] = 0; USES_CONVERSION; LPCWSTR strModule = T2CW(szModule); LPCWSTR dirEnd = wcsrchr( strModule, L'\\' ); if (dirEnd != NULL) str.append(strModule, dirEnd); } /***************************************************************************\ * * METHOD: CMMCProtocol::ExpandMMCVars * * PURPOSE: helper. expands any %mmcres% contained in the string * It expands it to something like "res://c:\winnt\system32\mmcndmgr.dll" * * PARAMETERS: * std::wstring& str [in/out] - string to edit * * RETURNS: * SC - result code * \***************************************************************************/ void CMMCProtocol::ExpandMMCVars(std::wstring& str) { TCHAR szModule[_MAX_PATH+10] = { 0 }; DWORD cchSize = countof(szModule); DWORD dwRet = GetModuleFileName(_Module.GetModuleInstance(), szModule, cchSize); if(0==dwRet) return; // NTRAID#NTBUG9-613782-2002/05/02-ronmart-prefast warning 53: Call to 'GetModuleFileNameW' may not zero-terminate string szModule[cchSize - 1] = 0; USES_CONVERSION; LPCWSTR strModule = T2CW(szModule); std::wstring mmcres = L"res://"; mmcres += strModule; // second - replace the instances int pos; while (std::wstring::npos != (pos = str.find(szMMCRES) ) ) { // make one substitution str.replace( pos, wcslen(szMMCRES), mmcres) ; } }