Leaked source code of Windows Server 2003
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

642 lines
15 KiB

//+----------------------------------------------------------------------------
// File: dll.cxx
//
// Synopsis: This file contains the core routines and globals for creating
// DLLs
//
//-----------------------------------------------------------------------------
// Includes -------------------------------------------------------------------
#include <core.hxx>
// Globals --------------------------------------------------------------------
// Head of the singly linked list (chained through ptsNext) of thread states
// published by DllThreadAttach; DllProcessDetach drains it on unload.
static THREADSTATE * g_pts = NULL;
// Module handle; DllMain stores it on every notification.
HANDLE g_hinst = NULL;
// Process heap handle, captured in DllProcessAttach.
HANDLE g_heap = NULL;
// TLS slot holding the calling thread's THREADSTATE. NULL_TLS until
// DllProcessAttach allocates it.
DWORD g_tlsThreadState = NULL_TLS;
// Process-wide usage count, updated with Interlocked* calls.
// DllCanUnloadNow returns S_OK only when it is zero, and DecrementProcessUsage
// passivates the process when it drops to zero.
LONG g_cUsage = 0;
// NOTE(review): not referenced in this file; presumably shared global info
// used elsewhere in the project - confirm.
GINFO g_ginfo = { 0 };
// The "DLL" lock taken via LOCK(DLL) by the thread attach/detach and process
// passivation routines.
DECLARE_LOCK(DLL);
// Prototypes -----------------------------------------------------------------
// Class factory returned by DllGetClassObject (a new instance per call).
// It wraps one entry of the g_acf table: object creation is delegated to the
// entry's pfnFactory, and the IClassFactory2 licensing methods to its
// pfnLicense. IClassFactory2 is only exposed when pfnLicense is non-NULL.
class CClassFactory : public CComponent,
                      public IClassFactory2
{
    typedef CComponent parent;

public:
    // pcf must point at a table entry with a non-NULL pfnFactory (asserted).
    CClassFactory(CLASSFACTORY * pcf);

    // IUnknown methods
    DEFINE_IUNKNOWN_METHODS;

    // IClassFactory methods
    STDMETHOD(CreateInstance)(IUnknown * pUnkOuter, REFIID riid, void ** ppvObj);
    STDMETHOD(LockServer)(BOOL fLock);

    // IClassFactory2 methods
    STDMETHOD(GetLicInfo)(LICINFO * pLicInfo);
    STDMETHOD(RequestLicKey)(DWORD dwReserved, BSTR * pbstrKey);
    STDMETHOD(CreateInstanceLic)(IUnknown * pUnkOuter,
                                 IUnknown * pUnkReserved,
                                 REFIID riid, BSTR bstrKey,
                                 void ** ppvObj);

private:
    // Entry in the g_acf table that this factory serves; never deleted here.
    CLASSFACTORY * _pcf;

    // Answers IClassFactory / IClassFactory2; everything else goes to the
    // CComponent base.
    HRESULT PrivateQueryInterface(REFIID riid, void ** ppvObj);
};
// File-local lifetime routines.
// - DllMain drives process attach/detach and thread detach.
// - Thread attach happens on demand through EnsureThreadState.
// - The passivate routines run when the process or thread usage count
//   drops to zero.
static HRESULT DllProcessAttach();
static void DllProcessDetach();
static HRESULT DllThreadAttach();
static void DllThreadDetach(THREADSTATE * pts);
static void DllProcessPassivate();
static void DllThreadPassivate();
//+----------------------------------------------------------------------------
// Function: DllMain
//
// Synopsis: DLL entry point. Records the module handle, runs process
//           attach/detach, and tears down the calling thread's state on
//           thread detach. DLL_THREAD_ATTACH is intentionally ignored:
//           thread state is created on demand by EnsureThreadState.
//
// Returns:  TRUE unless DllProcessAttach returned anything other than S_OK.
//           This matches the file's "nonzero HRESULT means failure" checks.
//
//-----------------------------------------------------------------------------
extern "C" BOOL WINAPI
DllMain(
    HINSTANCE hinst,
    DWORD nReason,
    void * ) // pvReserved - Unused
{
    g_hinst = hinst;

    HRESULT hrAttach = S_OK;

    if (nReason == DLL_PROCESS_ATTACH)
    {
        hrAttach = DllProcessAttach();
    }
    else if (nReason == DLL_PROCESS_DETACH)
    {
        DllProcessDetach();
    }
    else if (nReason == DLL_THREAD_DETACH)
    {
        // The slot holds NULL when this thread never created a thread state;
        // DllThreadDetach ignores that case.
        DllThreadDetach((THREADSTATE *)TlsGetValue(g_tlsThreadState));
    }

    return (hrAttach == S_OK) ? TRUE : FALSE;
}
//+----------------------------------------------------------------------------
// Function: DllGetClassObject
//
// Synopsis: COM entry point. Looks up rclsid in the g_acf table and returns
//           a new CClassFactory for it.
//
// Returns:  S_OK with *ppv set to the factory; E_INVALIDARG if ppv is NULL;
//           E_NOINTERFACE for any riid other than IClassFactory, or for
//           IClassFactory2 when the entry has no license callback;
//           CLASS_E_CLASSNOTAVAILABLE if the CLSID is not in the table;
//           E_OUTOFMEMORY; or the failure from EnsureThreadState.
//
// NOTE: This code limits class objects to supporting IUnknown and IClassFactory
//
//-----------------------------------------------------------------------------
STDAPI
DllGetClassObject(
    REFCLSID rclsid,
    REFIID riid,
    void ** ppv)
{
    CLASSFACTORY * pcf;
    HRESULT hr;

    hr = EnsureThreadState();
    if (hr)
        return hr;

    if (!ppv)
        return E_INVALIDARG;
    *ppv = NULL;

    if (riid != IID_IClassFactory &&
        riid != IID_IClassFactory2)
        return E_NOINTERFACE;

    // g_acf is terminated by an entry whose pclsid is NULL.
    for (pcf=g_acf; pcf->pclsid; pcf++)
    {
        if (*(pcf->pclsid) == rclsid)
            break;
    }

    // When nothing matches, the loop stops on the sentinel entry rather than
    // producing a NULL pointer. Test the entry itself: testing pcf let an
    // unknown CLSID through with the sentinel's NULL pfnFactory.
    if (!pcf->pclsid)
        return CLASS_E_CLASSNOTAVAILABLE;

    if (riid == IID_IClassFactory2 && !pcf->pfnLicense)
        return E_NOINTERFACE;

    CClassFactory * pCF = new CClassFactory(pcf);
    if (!pCF)
        return E_OUTOFMEMORY;

    *ppv = (void *)(IClassFactory2 *)pCF;
    return S_OK;
}
//+----------------------------------------------------------------------------
// Function: DllCanUnloadNow
//
// Synopsis: COM entry point. Reports whether the DLL may be unloaded, based
//           only on the process-wide usage count g_cUsage.
//
// Returns:  S_OK when g_cUsage is zero, S_FALSE otherwise.
//
//-----------------------------------------------------------------------------
STDAPI
DllCanUnloadNow()
{
    if (g_cUsage != 0)
        return S_FALSE;

    return S_OK;
}
//+----------------------------------------------------------------------------
// Function: DllProcessAttach
//
// Synopsis: Process-wide initialization. Allocates the thread-state TLS
//           slot, initializes the DLL lock, caches the process heap, then
//           runs each registered process-attach routine in table order.
//           Stops at the first routine that returns a nonzero HRESULT.
//
// Returns:  S_OK, or the failing HRESULT. If TlsAlloc fails, nothing else is
//           initialized. If an attach routine fails, DllProcessDetach is run
//           before returning.
//
//           NOTE(review): when DllMain returns FALSE for a LoadLibrary-
//           initiated attach, the loader also delivers DLL_PROCESS_DETACH.
//           DllProcessDetach may therefore run twice; confirm it tolerates
//           a second call.
//
//-----------------------------------------------------------------------------
HRESULT
DllProcessAttach()
{
    g_tlsThreadState = TlsAlloc();
    if (g_tlsThreadState == NULL_TLS)
        return GetWin32Hresult();

    INIT_LOCK(DLL);

    g_heap = GetProcessHeap();

    HRESULT hrResult = S_OK;
    for (PFN_PATTACH * ppfnCur = g_apfnPAttach; *ppfnCur; ++ppfnCur)
    {
        hrResult = (**ppfnCur)();
        if (hrResult)
        {
            // Roll back whatever was set up so far.
            DllProcessDetach();
            break;
        }
    }

    return hrResult;
}
//+----------------------------------------------------------------------------
// Function: DllProcessDetach
//
// Synopsis: Undo DllProcessAttach. Detaches every thread state still on the
//           global list, runs the registered process-detach routines,
//           releases the DLL lock, and frees the TLS slot.
//
//           Safe to call more than once. DllProcessAttach calls it on
//           failure, and per the DllMain documentation the loader then also
//           delivers DLL_PROCESS_DETACH for a LoadLibrary-initiated attach
//           that returned FALSE. The TLS slot doubles as the "initialized"
//           flag: DllProcessAttach runs INIT_LOCK and the attach routines
//           only after TlsAlloc succeeds, and the slot is reset below once
//           it has been freed.
//
//-----------------------------------------------------------------------------
void
DllProcessDetach()
{
    THREADSTATE * pts;
    PFN_PDETACH * ppfnPDetach;

    Implies(g_pts, g_tlsThreadState != NULL_TLS);

    // Never initialized (TlsAlloc failed) or already torn down: there is
    // nothing to undo. DEINIT_LOCK and TlsFree must not run on a lock or
    // slot that is not (or no longer) ours.
    if (g_tlsThreadState == NULL_TLS)
        return;

    // Tear down every remaining thread state from this thread. Each state is
    // installed in this thread's TLS slot first because DllThreadDetach
    // asserts it is the current slot value.
    while (g_pts)
    {
        pts = g_pts;
        Verify(TlsSetValue(g_tlsThreadState, pts));
        DllThreadDetach(pts);
        Assert(!TlsGetValue(g_tlsThreadState));
        Assert(g_pts != pts);
    }

    for (ppfnPDetach=g_apfnPDetach; *ppfnPDetach; ppfnPDetach++)
        (**ppfnPDetach)();

    DEINIT_LOCK(DLL);

    TlsFree(g_tlsThreadState);

    // Mark the slot as released. A repeated call then returns early instead
    // of freeing an index that may since have been reallocated elsewhere.
    g_tlsThreadState = NULL_TLS;
}
//+----------------------------------------------------------------------------
// Function: DllThreadAttach
//
// Synopsis: Create and register the calling thread's THREADSTATE, under the
//           DLL lock. The state is stored in the TLS slot, gets a COM task
//           allocator, and is passed to every registered thread-attach
//           routine. Only on full success is it linked onto g_pts.
//
// Returns:  S_OK, or the first failing HRESULT. On failure, any partially
//           built state is torn down via DllThreadDetach.
//
//-----------------------------------------------------------------------------
HRESULT
DllThreadAttach()
{
    // Start at NULL so the Error path always hands DllThreadDetach a valid
    // pointer (NULL is ignored there), even if AllocateThreadState fails.
    // Previously an uninitialized pointer could reach DllThreadDetach and be
    // deleted, unless AllocateThreadState happens to clear it on failure
    // (not visible here).
    THREADSTATE * pts = NULL;
    PFN_TATTACH * ppfnTAttach;
    HRESULT hr;

    LOCK(DLL);

    Assert(g_tlsThreadState != NULL_TLS);
    Assert(!::TlsGetValue(g_tlsThreadState));

    hr = AllocateThreadState(&pts);
    if (hr)
        goto Error;
    Assert(pts);

    pts->dll.idThread = GetCurrentThreadId();
    Verify(TlsSetValue(g_tlsThreadState, pts));
    Verify(SUCCEEDED(::CoGetMalloc(1, &pts->dll.pmalloc)));

    for (ppfnTAttach=g_apfnTAttach; *ppfnTAttach; ppfnTAttach++)
    {
        hr = (**ppfnTAttach)(pts);
        if (hr)
            goto Error;
    }

    // Publish the state on the global list only after every attach succeeded.
    pts->ptsNext = g_pts;
    g_pts = pts;

Cleanup:
    return hr;

Error:
    DllThreadDetach(pts);
    goto Cleanup;
}
//+----------------------------------------------------------------------------
// Function: DllThreadDetach
//
// Synopsis: Destroy a THREADSTATE under the DLL lock.
//           - Runs every registered thread-detach routine on it.
//           - Releases its task allocator and clears the calling thread's
//             TLS slot.
//           - Unlinks it from g_pts (if it was published there) and deletes
//             it.
//           A NULL pts is ignored. The caller must have pts installed in the
//           current thread's TLS slot (asserted).
//
// NOTE: Under Win95, DllThreadDetach may be called to clear memory on a
// thread which did not allocate the memory.
//
//-----------------------------------------------------------------------------
void
DllThreadDetach(
    THREADSTATE * pts)
{
    LOCK(DLL);

    if (pts == NULL)
        return;

    Assert(!pts->dll.cUsage);
    Assert(pts == (THREADSTATE *)TlsGetValue(g_tlsThreadState));

    PFN_TDETACH * ppfnCur = g_apfnTDetach;
    while (*ppfnCur)
    {
        (**ppfnCur)(pts);
        ++ppfnCur;
    }

    ::SRelease(pts->dll.pmalloc);
    ::TlsSetValue(g_tlsThreadState, NULL);

    // Find the link that points at pts (if any) and splice pts out.
    THREADSTATE ** pptsLink = &g_pts;
    while (*pptsLink && *pptsLink != pts)
        pptsLink = &((*pptsLink)->ptsNext);

    if (*pptsLink)
        *pptsLink = pts->ptsNext;

    delete pts;
}
//+----------------------------------------------------------------------------
// Function: DllProcessPassivate
//
// Synopsis:
//
//-----------------------------------------------------------------------------
void
DllProcessPassivate()
{
PFN_PPASSIVATE * ppfnPPassivate;
LOCK(DLL);
Assert(!g_cUsage);
// BUGBUG: What are the respective roles of process/thread passivation?
// BUGBUG: This is an unsafe add into g_cUsage...fix this!
g_cUsage += REF_GUARD;
for (ppfnPPassivate=g_apfnPPassivate; *ppfnPPassivate; ppfnPPassivate++)
(**ppfnPPassivate)();
g_cUsage -= REF_GUARD;
}
//+----------------------------------------------------------------------------
// Function: DllThreadPassivate
//
// Synopsis:
//
//-----------------------------------------------------------------------------
void
DllThreadPassivate()
{
THREADSTATE * pts = GetThreadState();
PFN_TPASSIVATE * ppfnTPassivate;
Assert(!pts->dll.cUsage);
pts->dll.cUsage += REF_GUARD;
for (ppfnTPassivate=g_apfnTPassivate; *ppfnTPassivate; ppfnTPassivate++)
(**ppfnTPassivate)(pts);
pts->dll.cUsage -= REF_GUARD;
}
//+----------------------------------------------------------------------------
// Function: CClassFactory
//
// Synopsis: Bind the factory to one entry of the class-factory table. The
//           entry is referenced, not copied, and must supply pfnFactory.
//
//-----------------------------------------------------------------------------
CClassFactory::CClassFactory(
    CLASSFACTORY * pcf)
    : CComponent(NULL),
      _pcf(pcf)
{
    Assert(pcf);
    Assert(pcf->pfnFactory);
}
//+----------------------------------------------------------------------------
// Function: CreateInstance
//
// Synopsis: IClassFactory::CreateInstance. Validates the arguments and then
//           hands creation (including the QI for riid) to the table entry's
//           pfnFactory.
//
// Returns:  E_INVALIDARG if ppvObj is NULL, or if an outer unknown is
//           supplied with a riid other than IID_IUnknown. Otherwise whatever
//           pfnFactory returns.
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CClassFactory::CreateInstance(
    IUnknown * pUnkOuter,
    REFIID riid,
    void ** ppvObj)
{
    if (ppvObj == NULL)
        return E_INVALIDARG;

    *ppvObj = NULL;

    // An aggregating caller may only ask for IUnknown.
    // BUGBUG: What error should be returned?
    const bool fAggregating = (pUnkOuter != NULL);
    if (fAggregating && riid != IID_IUnknown)
        return E_INVALIDARG;

    // BUGBUG: Should the factory just create the object and let this
    // code perform the appropriate QI?
    // BUGBUG: This code should automatically handle aggregation
    Assert(_pcf);
    Assert(_pcf->pfnFactory);

    return _pcf->pfnFactory(pUnkOuter, riid, ppvObj);
}
//+----------------------------------------------------------------------------
// Function: LockServer
//
// Synopsis: IClassFactory::LockServer.
//           - Locking takes a reference on this factory plus one thread
//             (and process) usage.
//           - Unlocking releases them in the reverse order.
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CClassFactory::LockServer(
    BOOL fLock)
{
    if (!fLock)
    {
        DecrementThreadUsage();
        Release();
        return S_OK;
    }

    AddRef();
    IncrementThreadUsage();
    return S_OK;
}
//+----------------------------------------------------------------------------
// Function: GetLicInfo
//
// Synopsis: IClassFactory2::GetLicInfo. Forwards to the table entry's license
//           callback with LICREQUEST_INFO and returns its result.
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CClassFactory::GetLicInfo(
    LICINFO * pLicInfo)
{
    // Only reachable via IClassFactory2, which is exposed only when
    // pfnLicense is set.
    Assert(_pcf->pfnLicense);

    HRESULT hrInfo = _pcf->pfnLicense(LICREQUEST_INFO, pLicInfo);
    return hrInfo;
}
//+----------------------------------------------------------------------------
// Function: RequestLicKey
//
// Synopsis: IClassFactory2::RequestLicKey. Forwards to the table entry's
//           license callback with LICREQUEST_OBTAIN and returns its result.
//           The reserved argument is ignored.
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CClassFactory::RequestLicKey(
    DWORD , // dwReserved
    BSTR * pbstrKey)
{
    // Only reachable via IClassFactory2, which is exposed only when
    // pfnLicense is set.
    Assert(_pcf->pfnLicense);

    HRESULT hrKey = _pcf->pfnLicense(LICREQUEST_OBTAIN, pbstrKey);
    return hrKey;
}
//+----------------------------------------------------------------------------
// Function: CreateInstanceLic
//
// Synopsis: IClassFactory2::CreateInstanceLic. Asks the license callback to
//           validate bstrKey and, if it returns exactly S_OK, creates the
//           object via CreateInstance. pUnkReserved is ignored.
//
// Returns:  E_INVALIDARG if ppvObj is NULL; CLASS_E_NOTLICENSED if the key is
//           rejected; otherwise the result of CreateInstance.
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CClassFactory::CreateInstanceLic(
    IUnknown * pUnkOuter,
    IUnknown * , // pUnkReserved
    REFIID riid,
    BSTR bstrKey,
    void ** ppvObj)
{
    Assert(_pcf->pfnLicense);

    if (ppvObj == NULL)
        return E_INVALIDARG;
    *ppvObj = NULL;

    const bool fLicensed =
        (_pcf->pfnLicense(LICREQUEST_VALIDATE, bstrKey) == S_OK);
    if (!fLicensed)
        return CLASS_E_NOTLICENSED;

    return CreateInstance(pUnkOuter, riid, ppvObj);
}
//+----------------------------------------------------------------------------
// Function: PrivateQueryInterface
//
// Synopsis: Interface lookup for the factory.
//           - IClassFactory is always available.
//           - IClassFactory2 is available only when the table entry has a
//             license callback.
//           - Every other IID is delegated to the CComponent base.
//
//-----------------------------------------------------------------------------
HRESULT
CClassFactory::PrivateQueryInterface(
    REFIID riid,
    void ** ppvObj)
{
    if (riid == IID_IClassFactory)
    {
        *ppvObj = (void *)(IClassFactory *)this;
        return S_OK;
    }

    if (riid != IID_IClassFactory2)
        return parent::PrivateQueryInterface(riid, ppvObj);

    if (!_pcf->pfnLicense)
        return E_NOINTERFACE;

    *ppvObj = (void *)(IClassFactory2 *)this;
    return S_OK;
}
//+----------------------------------------------------------------------------
// Function: GetWin32Hresult
//
// Synopsis: Return an HRESULT derived from the current Win32 error, for use
//           on failure paths (e.g. TlsAlloc failing in DllProcessAttach).
//
// Returns:  HRESULT_FROM_WIN32(GetLastError()). If no Win32 error code is
//           recorded (ERROR_SUCCESS), returns E_FAIL instead.
//           HRESULT_FROM_WIN32(ERROR_SUCCESS) is S_OK, which would make the
//           caller report a failure as success.
//
//-----------------------------------------------------------------------------
HRESULT
GetWin32Hresult()
{
    DWORD dwError = GetLastError();

    if (dwError == ERROR_SUCCESS)
        return E_FAIL;

    return HRESULT_FROM_WIN32(dwError);
}
//+----------------------------------------------------------------------------
// Function: EnsureThreadState
//
// Synopsis: Make sure the calling thread has a THREADSTATE. One is created
//           via DllThreadAttach the first time the thread's TLS slot is
//           found empty.
//
// Returns:  S_OK if a state already exists; otherwise the result of
//           DllThreadAttach.
//
//-----------------------------------------------------------------------------
HRESULT
EnsureThreadState()
{
    extern DWORD g_tlsThreadState;
    Assert(g_tlsThreadState != NULL_TLS);

    return TlsGetValue(g_tlsThreadState) ? S_OK : DllThreadAttach();
}
//+----------------------------------------------------------------------------
// Function: IncrementProcessUsage
//
// Synopsis: Atomically add one to the process-wide usage count g_cUsage.
//           Debug builds verify that the count stays positive.
//
//-----------------------------------------------------------------------------
void
IncrementProcessUsage()
{
#ifndef _DEBUG
    InterlockedIncrement(&g_cUsage);
#else
    Verify(InterlockedIncrement(&g_cUsage) > 0);
#endif
}
//+----------------------------------------------------------------------------
// Function: DecrementProcessUsage
//
// Synopsis: Atomically remove one from g_cUsage. The caller whose decrement
//           brings the count to zero runs DllProcessPassivate. When DBG==1,
//           decrementing an already-zero count breaks into the debugger.
//
//-----------------------------------------------------------------------------
void
DecrementProcessUsage()
{
#if DBG==1
    if (g_cUsage == 0)
        DebugBreak(); // ref counting problem
#endif

    if (InterlockedDecrement(&g_cUsage) == 0)
        DllProcessPassivate();
}
//+----------------------------------------------------------------------------
// Function: IncrementThreadUsage
//
// Synopsis: Take one usage on the calling thread's state and one on the
//           process.
//
//           NOTE(review): TLS(dll.cUsage) presumably names the calling
//           thread's THREADSTATE field. Unlike DecrementThreadUsage, there is
//           no NULL check here - confirm the thread state always exists.
//
//-----------------------------------------------------------------------------
void
IncrementThreadUsage()
{
#ifndef _DEBUG
    ++TLS(dll.cUsage);
#else
    Verify(++TLS(dll.cUsage) > 0);
#endif

    IncrementProcessUsage();
}
//+----------------------------------------------------------------------------
// Function: DecrementThreadUsage
//
// Synopsis:
//
//-----------------------------------------------------------------------------
void
DecrementThreadUsage()
{
THREADSTATE * pts = GetThreadState();
if(pts)
{
pts->dll.cUsage--;
Assert(pts->dll.cUsage >= 0);
if (!pts->dll.cUsage)
{
DllThreadPassivate();
}
}
DecrementProcessUsage();
}