//+---------------------------------------------------------------------------- // File: dll.cxx // // Synopsis: This file contains the core routines and globals for creating // DLLs // //----------------------------------------------------------------------------- // Includes ------------------------------------------------------------------- #include // Globals -------------------------------------------------------------------- static THREADSTATE * g_pts = NULL; HANDLE g_hinst = NULL; HANDLE g_heap = NULL; DWORD g_tlsThreadState = NULL_TLS; LONG g_cUsage = 0; GINFO g_ginfo = { 0 }; DECLARE_LOCK(DLL); // Prototypes ----------------------------------------------------------------- class CClassFactory : public CComponent, public IClassFactory2 { typedef CComponent parent; public: CClassFactory(CLASSFACTORY * pcf); // IUnknown methods DEFINE_IUNKNOWN_METHODS; // IClassFactory methods STDMETHOD(CreateInstance)(IUnknown * pUnkOuter, REFIID riid, void ** ppvObj); STDMETHOD(LockServer)(BOOL fLock); // IClassFactory2 methods STDMETHOD(GetLicInfo)(LICINFO * pLicInfo); STDMETHOD(RequestLicKey)(DWORD dwReserved, BSTR * pbstrKey); STDMETHOD(CreateInstanceLic)(IUnknown * pUnkOuter, IUnknown * pUnkReserved, REFIID riid, BSTR bstrKey, void ** ppvObj); private: CLASSFACTORY * _pcf; HRESULT PrivateQueryInterface(REFIID riid, void ** ppvObj); }; static HRESULT DllProcessAttach(); static void DllProcessDetach(); static HRESULT DllThreadAttach(); static void DllThreadDetach(THREADSTATE * pts); static void DllProcessPassivate(); static void DllThreadPassivate(); //+---------------------------------------------------------------------------- // Function: DllMain // // Synopsis: // //----------------------------------------------------------------------------- extern "C" BOOL WINAPI DllMain( HINSTANCE hinst, DWORD nReason, void * ) // pvReserved - Unused { HRESULT hr = S_OK; g_hinst = hinst; switch (nReason) { case DLL_PROCESS_ATTACH: hr = DllProcessAttach(); break; case DLL_PROCESS_DETACH: DllProcessDetach(); break; case DLL_THREAD_DETACH: { THREADSTATE * pts = (THREADSTATE *)TlsGetValue(g_tlsThreadState); DllThreadDetach(pts); } break; } return !hr; } //+---------------------------------------------------------------------------- // Function: DllGetClassObject // // Synopsis: // // NOTE: This code limits class objects to supporting IUnknown and IClassFactory // //----------------------------------------------------------------------------- STDAPI DllGetClassObject( REFCLSID rclsid, REFIID riid, void ** ppv) { CLASSFACTORY * pcf; HRESULT hr; hr = EnsureThreadState(); if (hr) return hr; if (!ppv) return E_INVALIDARG; *ppv = NULL; if (riid != IID_IClassFactory && riid != IID_IClassFactory2) return E_NOINTERFACE; for (pcf=g_acf; pcf->pclsid; pcf++) { if (*(pcf->pclsid) == rclsid) break; } if (!pcf) return CLASS_E_CLASSNOTAVAILABLE; if (riid == IID_IClassFactory2 && !pcf->pfnLicense) return E_NOINTERFACE; CClassFactory * pCF = new CClassFactory(pcf); if (!pCF) return E_OUTOFMEMORY; *ppv = (void *)(IClassFactory2 *)pCF; return S_OK; } //+---------------------------------------------------------------------------- // Function: DllCanUnloadNow // // Synopsis: // //----------------------------------------------------------------------------- STDAPI DllCanUnloadNow() { return ((g_cUsage==0) ? S_OK : S_FALSE); } //+---------------------------------------------------------------------------- // Function: DllProcessAttach // // Synopsis: // //----------------------------------------------------------------------------- HRESULT DllProcessAttach() { PFN_PATTACH * ppfnPAttach; HRESULT hr = S_OK; g_tlsThreadState = TlsAlloc(); if (g_tlsThreadState == NULL_TLS) { return GetWin32Hresult(); } INIT_LOCK(DLL); g_heap = GetProcessHeap(); for (ppfnPAttach=g_apfnPAttach; *ppfnPAttach; ppfnPAttach++) { hr = (**ppfnPAttach)(); if (hr) goto Error; } Cleanup: return hr; Error: DllProcessDetach(); goto Cleanup; } //+---------------------------------------------------------------------------- // Function: DllProcessDetach // // Synopsis: // //----------------------------------------------------------------------------- void DllProcessDetach() { THREADSTATE * pts; PFN_PDETACH * ppfnPDetach; Implies(g_pts, g_tlsThreadState != NULL_TLS); while (g_pts) { pts = g_pts; Verify(TlsSetValue(g_tlsThreadState, pts)); DllThreadDetach(pts); Assert(!TlsGetValue(g_tlsThreadState)); Assert(g_pts != pts); } for (ppfnPDetach=g_apfnPDetach; *ppfnPDetach; ppfnPDetach++) (**ppfnPDetach)(); DEINIT_LOCK(DLL); if (g_tlsThreadState != NULL_TLS) { TlsFree(g_tlsThreadState); } } //+---------------------------------------------------------------------------- // Function: DllThreadAttach // // Synopsis: // //----------------------------------------------------------------------------- HRESULT DllThreadAttach() { THREADSTATE * pts; PFN_TATTACH * ppfnTAttach; HRESULT hr; LOCK(DLL); Assert(g_tlsThreadState != NULL_TLS); Assert(!::TlsGetValue(g_tlsThreadState)); hr = AllocateThreadState(&pts); if (hr) goto Error; Assert(pts); pts->dll.idThread = GetCurrentThreadId(); Verify(TlsSetValue(g_tlsThreadState, pts)); Verify(SUCCEEDED(::CoGetMalloc(1, &pts->dll.pmalloc))); for (ppfnTAttach=g_apfnTAttach; *ppfnTAttach; ppfnTAttach++) { hr = (**ppfnTAttach)(pts); if (hr) goto Error; } pts->ptsNext = g_pts; g_pts = pts; Cleanup: return hr; Error: DllThreadDetach(pts); goto Cleanup; } //+---------------------------------------------------------------------------- // Function: DllThreadDetach // // Synopsis: // // NOTE: Under Win95, DllThreadDetach may be called to clear memory on a // thread which did not allocate the memory. // //----------------------------------------------------------------------------- void DllThreadDetach( THREADSTATE * pts) { THREADSTATE ** ppts; PFN_TDETACH * ppfnTDetach; LOCK(DLL); if (!pts) return; Assert(!pts->dll.cUsage); Assert(pts == (THREADSTATE *)TlsGetValue(g_tlsThreadState)); for (ppfnTDetach=g_apfnTDetach; *ppfnTDetach; ppfnTDetach++) (**ppfnTDetach)(pts); ::SRelease(pts->dll.pmalloc); ::TlsSetValue(g_tlsThreadState, NULL); for (ppts=&g_pts; *ppts && *ppts != pts; ppts=&((*ppts)->ptsNext)); if (*ppts) { *ppts = pts->ptsNext; } delete pts; } //+---------------------------------------------------------------------------- // Function: DllProcessPassivate // // Synopsis: // //----------------------------------------------------------------------------- void DllProcessPassivate() { PFN_PPASSIVATE * ppfnPPassivate; LOCK(DLL); Assert(!g_cUsage); // BUGBUG: What are the respective roles of process/thread passivation? // BUGBUG: This is an unsafe add into g_cUsage...fix this! g_cUsage += REF_GUARD; for (ppfnPPassivate=g_apfnPPassivate; *ppfnPPassivate; ppfnPPassivate++) (**ppfnPPassivate)(); g_cUsage -= REF_GUARD; } //+---------------------------------------------------------------------------- // Function: DllThreadPassivate // // Synopsis: // //----------------------------------------------------------------------------- void DllThreadPassivate() { THREADSTATE * pts = GetThreadState(); PFN_TPASSIVATE * ppfnTPassivate; Assert(!pts->dll.cUsage); pts->dll.cUsage += REF_GUARD; for (ppfnTPassivate=g_apfnTPassivate; *ppfnTPassivate; ppfnTPassivate++) (**ppfnTPassivate)(pts); pts->dll.cUsage -= REF_GUARD; } //+---------------------------------------------------------------------------- // Function: CClassFactory // // Synopsis: // //----------------------------------------------------------------------------- CClassFactory::CClassFactory( CLASSFACTORY * pcf) : CComponent(NULL) { Assert(pcf); Assert(pcf->pfnFactory); _pcf = pcf; } //+---------------------------------------------------------------------------- // Function: CreateInstance // // Synopsis: // //----------------------------------------------------------------------------- STDMETHODIMP CClassFactory::CreateInstance( IUnknown * pUnkOuter, REFIID riid, void ** ppvObj) { if (!ppvObj) return E_INVALIDARG; *ppvObj = NULL; // BUGBUG: What error should be returned? if (pUnkOuter && riid != IID_IUnknown) return E_INVALIDARG; // BUGBUG: Should the factory just create the object and let this // code perform the appropriate QI? // BUGBUG: This code should automatically handle aggregation Assert(_pcf); Assert(_pcf->pfnFactory); return _pcf->pfnFactory(pUnkOuter, riid, ppvObj); } //+---------------------------------------------------------------------------- // Function: LockServer // // Synopsis: // //----------------------------------------------------------------------------- STDMETHODIMP CClassFactory::LockServer( BOOL fLock) { if (fLock) { AddRef(); IncrementThreadUsage(); } else { DecrementThreadUsage(); Release(); } return S_OK; } //+---------------------------------------------------------------------------- // Function: GetLicInfo // // Synopsis: // //----------------------------------------------------------------------------- STDMETHODIMP CClassFactory::GetLicInfo( LICINFO * pLicInfo) { Assert(_pcf->pfnLicense); return _pcf->pfnLicense(LICREQUEST_INFO, pLicInfo); } //+---------------------------------------------------------------------------- // Function: RequestLicKey // // Synopsis: // //----------------------------------------------------------------------------- STDMETHODIMP CClassFactory::RequestLicKey( DWORD , // dwReserved BSTR * pbstrKey) { Assert(_pcf->pfnLicense); return _pcf->pfnLicense(LICREQUEST_OBTAIN, pbstrKey); } //+---------------------------------------------------------------------------- // Function: CreateInstanceLic // // Synopsis: // //----------------------------------------------------------------------------- STDMETHODIMP CClassFactory::CreateInstanceLic( IUnknown * pUnkOuter, IUnknown * , // pUnkReserved REFIID riid, BSTR bstrKey, void ** ppvObj) { Assert(_pcf->pfnLicense); if (!ppvObj) return E_INVALIDARG; *ppvObj = NULL; if (_pcf->pfnLicense(LICREQUEST_VALIDATE, bstrKey) != S_OK) { return CLASS_E_NOTLICENSED; } return CreateInstance(pUnkOuter, riid, ppvObj); } //+---------------------------------------------------------------------------- // Function: PrivateQueryInterface // // Synopsis: // //----------------------------------------------------------------------------- HRESULT CClassFactory::PrivateQueryInterface( REFIID riid, void ** ppvObj) { if (riid == IID_IClassFactory) { *ppvObj = (void *)(IClassFactory *)this; } else if (riid == IID_IClassFactory2) { if (_pcf->pfnLicense) { *ppvObj = (void *)(IClassFactory2 *)this; } else { return E_NOINTERFACE; } } else { return parent::PrivateQueryInterface(riid, ppvObj); } return S_OK; } //+---------------------------------------------------------------------------- // Function: GetWin32Hresult // // Synopsis: Return an HRESULT derived from the current Win32 error // //----------------------------------------------------------------------------- HRESULT GetWin32Hresult() { return HRESULT_FROM_WIN32(GetLastError()); } //+---------------------------------------------------------------------------- // Function: EnsureThreadState // // Synopsis: // //----------------------------------------------------------------------------- HRESULT EnsureThreadState() { extern DWORD g_tlsThreadState; Assert(g_tlsThreadState != NULL_TLS); if (!TlsGetValue(g_tlsThreadState)) return DllThreadAttach(); return S_OK; } //+---------------------------------------------------------------------------- // Function: IncrementProcessUsage // // Synopsis: // //----------------------------------------------------------------------------- void IncrementProcessUsage() { #ifdef _DEBUG Verify(InterlockedIncrement(&g_cUsage) > 0); #else InterlockedIncrement(&g_cUsage); #endif } //+---------------------------------------------------------------------------- // Function: DecrementProcessUsage // // Synopsis: // //----------------------------------------------------------------------------- void DecrementProcessUsage() { #if DBG==1 if( 0 == g_cUsage ) { DebugBreak(); // ref counting problem } #endif if (!InterlockedDecrement(&g_cUsage)) { DllProcessPassivate(); } } //+---------------------------------------------------------------------------- // Function: IncrementThreadUsage // // Synopsis: // //----------------------------------------------------------------------------- void IncrementThreadUsage() { #ifdef _DEBUG Verify(++TLS(dll.cUsage) > 0); #else ++TLS(dll.cUsage); #endif IncrementProcessUsage(); } //+---------------------------------------------------------------------------- // Function: DecrementThreadUsage // // Synopsis: // //----------------------------------------------------------------------------- void DecrementThreadUsage() { THREADSTATE * pts = GetThreadState(); if(pts) { pts->dll.cUsage--; Assert(pts->dll.cUsage >= 0); if (!pts->dll.cUsage) { DllThreadPassivate(); } } DecrementProcessUsage(); }