/*
 * clsfact.cpp - IClassFactory implementation.
 */


/* Headers
 **********/

#include "project.hpp"
#pragma hdrstop

#include "clsfact.h"
#include "ftps.hpp"
#include "inetcpl.h"
#include "inetps.hpp"


/* Types
 ********/

// Object constructor callback.  A ClassFactory receives one of these in its
// constructor and calls it from ClassFactory::CreateInstance() to build a new
// object; the OBJECTDESTROYEDPROC argument is handed on to the new object.

typedef PIUnknown (*NEWOBJECTPROC)(OBJECTDESTROYEDPROC);
DECLARE_STANDARD_TYPES(NEWOBJECTPROC);

// description of class supported by DllGetClassObject()
// (one entry of the s_cclscnstr lookup table)

typedef struct classconstructor
{
   // class identifier this entry answers to
   PCCLSID pcclsid;

   // function that creates a new object of that class
   NEWOBJECTPROC NewObject;
}
CLASSCONSTRUCTOR;
DECLARE_STANDARD_TYPES(CLASSCONSTRUCTOR);


/* Classes
 **********/

// object class factory
//
// One ClassFactory is created per DllGetClassObject() call.  It remembers the
// NEWOBJECTPROC for the requested CLSID and uses it in CreateInstance() to
// build new objects.  Reference counting is delegated to the RefCount base,
// which is given the OBJECTDESTROYEDPROC passed to the constructor.

class ClassFactory : public RefCount,
                     public IClassFactory
{
private:
   // constructor for the objects this factory creates
   NEWOBJECTPROC m_NewObject;

public:
   ClassFactory(NEWOBJECTPROC NewObject, OBJECTDESTROYEDPROC ObjectDestroyed);
   ~ClassFactory(void);

   // IClassFactory methods

   HRESULT STDMETHODCALLTYPE CreateInstance(PIUnknown piunkOuter, REFIID riid, PVOID *ppvObject);
   HRESULT STDMETHODCALLTYPE LockServer(BOOL bLock);

   // IUnknown methods

   ULONG STDMETHODCALLTYPE AddRef(void);
   ULONG STDMETHODCALLTYPE Release(void);
   HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, PVOID *ppvObj);

   // friends

#ifdef DEBUG

   // debug-only validator; needs access to m_NewObject
   friend BOOL IsValidPCClassFactory(const ClassFactory *pcurlcf);

#endif

};
DECLARE_STANDARD_TYPES(ClassFactory);


/* Module Prototypes
 ********************/

// NEWOBJECTPROC implementations, one per entry in s_cclscnstr.  Each returns
// the new object's IUnknown; ClassFactory::CreateInstance() treats a NULL
// return as E_OUTOFMEMORY.
PRIVATE_CODE PIUnknown NewInternetShortcut(OBJECTDESTROYEDPROC ObjectDestroyed);
PRIVATE_CODE PIUnknown NewMIMEHook(OBJECTDESTROYEDPROC ObjectDestroyed);
PRIVATE_CODE PIUnknown NewInternet(OBJECTDESTROYEDPROC ObjectDestroyed);


/* Module Constants
 *******************/

#pragma data_seg(DATA_SEG_READ_ONLY)

// CLSID -> object constructor table searched by GetClassConstructor().
// Every class DllGetClassObject() can serve must have exactly one entry here.

PRIVATE_DATA CCLASSCONSTRUCTOR s_cclscnstr[] =
{
   { &CLSID_InternetShortcut,             &NewInternetShortcut },
   { &CLSID_MIMEFileTypesPropSheetHook,   &NewMIMEHook },
   { &CLSID_Internet,                     &NewInternet },
};

#pragma data_seg()


/* Module Variables
 *******************/

#pragma data_seg(DATA_SEG_PER_INSTANCE)

// DLL reference count == number of class factories handed out by
//                        DllGetClassObject() +
//                        number of objects created by
//                        ClassFactory::CreateInstance() +
//                        outstanding LockServer(TRUE) calls
//
// Each reference is dropped through DLLRelease(), either directly by
// LockServer(FALSE) or via the DLLObjectDestroyed() callback when an object
// goes away.  DllCanUnloadNow() refuses to unload while it is nonzero.
//
// NOTE(review): updated with plain ++/-- rather than interlocked operations;
// this is only safe if all callers are serialized -- confirm the threading
// model these classes are registered with.

PRIVATE_DATA ULONG s_ulcDLLRef   = 0;

#pragma data_seg()


/***************************** Private Functions *****************************/


/*
 * GetClassConstructor() - Looks up the object constructor for a CLSID.
 *
 * rclsid     - class identifier requested by the caller
 * pNewObject - receives the matching NEWOBJECTPROC, or NULL if none matches
 *
 * Returns S_OK if rclsid appears in s_cclscnstr, otherwise
 * CLASS_E_CLASSNOTAVAILABLE.  The whole table is always scanned, so if a
 * CLSID were ever listed twice the last entry would win.
 */
PRIVATE_CODE HRESULT GetClassConstructor(REFCLSID rclsid,
                                         PNEWOBJECTPROC pNewObject)
{
   const CLASSCONSTRUCTOR *pccEntry;
   const CLASSCONSTRUCTOR *pccEnd = s_cclscnstr + ARRAY_ELEMENTS(s_cclscnstr);
   NEWOBJECTPROC NewObjectFound = NULL;
   HRESULT hrResult = CLASS_E_CLASSNOTAVAILABLE;

   ASSERT(IsValidREFCLSID(rclsid));
   ASSERT(IS_VALID_WRITE_PTR(pNewObject, NEWOBJECTPROC));

   for (pccEntry = s_cclscnstr; pccEntry < pccEnd; pccEntry++)
   {
      if (rclsid == *(pccEntry->pcclsid))
      {
         NewObjectFound = pccEntry->NewObject;
         hrResult = S_OK;
      }
   }

   *pNewObject = NewObjectFound;

   ASSERT((hrResult == S_OK &&
           IS_VALID_CODE_PTR(*pNewObject, NEWOBJECTPROC)) ||
          (hrResult == CLASS_E_CLASSNOTAVAILABLE &&
           ! *pNewObject));

   return(hrResult);
}


/*
 * DLLObjectDestroyed() - OBJECTDESTROYEDPROC passed to every class factory
 * made by DllGetClassObject() and every object made by
 * ClassFactory::CreateInstance().  It drops the DLL reference that was taken
 * when that object was created.
 */
PRIVATE_CODE void STDMETHODCALLTYPE DLLObjectDestroyed(void)
{
   TRACE_OUT(("DLLObjectDestroyed(): Object destroyed."));

   // The remaining count is not needed here.
   (void)DLLRelease();
}


/*
 * NewInternetShortcut() - NEWOBJECTPROC for CLSID_InternetShortcut.
 *
 * ObjectDestroyed - optional callback handed to the new object's constructor.
 *
 * Returns the new object's IUnknown, or NULL if allocation failed.
 * (Assumes this project's operator new returns NULL rather than throwing;
 * CreateInstance() maps NULL to E_OUTOFMEMORY.)
 */
PRIVATE_CODE PIUnknown NewInternetShortcut(OBJECTDESTROYEDPROC ObjectDestroyed)
{
   ASSERT(! ObjectDestroyed ||
          IS_VALID_CODE_PTR(ObjectDestroyed, OBJECTDESTROYEDPROC));

   TRACE_OUT(("NewInternetShortcut(): Creating a new InternetShortcut."));

   // Standard "new T(args)" form.  The old "new(T(args))" spelling is not
   // valid ISO C++ and is only accepted by permissive compilers.  The cast
   // goes through a single interface so the conversion to IUnknown is
   // unambiguous (presumably InternetShortcut reaches IUnknown through more
   // than one base).
   return((PIUnknown)(PIUniformResourceLocator)new InternetShortcut(ObjectDestroyed));
}


/*
 * NewMIMEHook() - NEWOBJECTPROC for CLSID_MIMEFileTypesPropSheetHook.
 *
 * ObjectDestroyed - optional callback handed to the new object's constructor.
 *
 * Returns the new object's IUnknown, or NULL if allocation failed.
 * (Assumes this project's operator new returns NULL rather than throwing;
 * CreateInstance() maps NULL to E_OUTOFMEMORY.)
 */
PRIVATE_CODE PIUnknown NewMIMEHook(OBJECTDESTROYEDPROC ObjectDestroyed)
{
   ASSERT(! ObjectDestroyed ||
          IS_VALID_CODE_PTR(ObjectDestroyed, OBJECTDESTROYEDPROC));

   TRACE_OUT(("NewMIMEHook(): Creating a new MIMEHook."));

   // Standard "new T(args)" form.  "new(T(args))" is not valid ISO C++.
   // The cast goes through IShellPropSheetExt so the conversion to IUnknown
   // is unambiguous.
   return((PIUnknown)(PIShellPropSheetExt)new MIMEHook(ObjectDestroyed));
}


/*
 * NewInternet() - NEWOBJECTPROC for CLSID_Internet.
 *
 * ObjectDestroyed - optional callback handed to the new object's constructor.
 *
 * Returns the new object's IUnknown, or NULL if allocation failed.
 * (Assumes this project's operator new returns NULL rather than throwing;
 * CreateInstance() maps NULL to E_OUTOFMEMORY.)
 */
PRIVATE_CODE PIUnknown NewInternet(OBJECTDESTROYEDPROC ObjectDestroyed)
{
   ASSERT(! ObjectDestroyed ||
          IS_VALID_CODE_PTR(ObjectDestroyed, OBJECTDESTROYEDPROC));

   TRACE_OUT(("NewInternet(): Creating a new Internet."));

   // Standard "new T(args)" form.  "new(T(args))" is not valid ISO C++.
   // The cast goes through IShellPropSheetExt so the conversion to IUnknown
   // is unambiguous.
   return((PIUnknown)(PIShellPropSheetExt)new Internet(ObjectDestroyed));
}


#ifdef DEBUG

/*
 * IsValidPCClassFactory() - Debug-only sanity check of a ClassFactory.
 * It checks the pointer, the stored constructor, the RefCount base and the
 * IClassFactory interface, in that order, and stops at the first failure.
 */
PRIVATE_CODE BOOL IsValidPCClassFactory(PCClassFactory pccf)
{
   // Check the pointer itself before touching any member through it.
   if (! IS_VALID_READ_PTR(pccf, CClassFactory))
      return(FALSE);

   if (! IS_VALID_CODE_PTR(pccf->m_NewObject, NEWOBJECTPROC))
      return(FALSE);

   if (! IS_VALID_STRUCT_PTR((PCRefCount)pccf, CRefCount))
      return(FALSE);

   if (! IS_VALID_INTERFACE_PTR((PCIClassFactory)pccf, IClassFactory))
      return(FALSE);

   return(TRUE);
}

#endif


/****************************** Public Functions *****************************/


/*
 * DLLAddRef() - Takes one DLL reference (see s_ulcDLLRef).
 *
 * Returns the reference count after the increment.
 */
PUBLIC_CODE ULONG DLLAddRef(void)
{
   ASSERT(s_ulcDLLRef < ULONG_MAX);

   const ULONG ulcNewRef = ++s_ulcDLLRef;

   TRACE_OUT(("DLLAddRef(): DLL reference count is now %lu.",
              ulcNewRef));

   return(ulcNewRef);
}


/*
 * DLLRelease() - Drops one DLL reference (see s_ulcDLLRef).
 *
 * Returns the reference count after the decrement.  An unbalanced call while
 * the count is already zero leaves it at zero instead of wrapping around;
 * EVAL() presumably reports that case in debug builds.
 */
PUBLIC_CODE ULONG DLLRelease(void)
{
   ULONG ulcNewRef;

   if (EVAL(s_ulcDLLRef > 0))
      --s_ulcDLLRef;

   ulcNewRef = s_ulcDLLRef;

   TRACE_OUT(("DLLRelease(): DLL reference count is now %lu.",
              ulcNewRef));

   return(ulcNewRef);
}


/*
 * GetDLLRefCountPtr() - Returns the address of s_ulcDLLRef.
 * Writes made through this pointer bypass DLLAddRef()/DLLRelease() and
 * their tracing.
 */
PUBLIC_CODE PULONG GetDLLRefCountPtr(void)
{
   PULONG pulcRef = &s_ulcDLLRef;

   return(pulcRef);
}


/********************************** Methods **********************************/


/*
 * ClassFactory::ClassFactory() - Builds a factory for one class.
 *
 * NewObject       - constructor used later by CreateInstance()
 * ObjectDestroyed - callback forwarded to the RefCount base
 */
ClassFactory::ClassFactory(NEWOBJECTPROC NewObject,
                           OBJECTDESTROYEDPROC ObjectDestroyed) :
   RefCount(ObjectDestroyed),
   m_NewObject(NewObject)
{
   DebugEntry(ClassFactory::ClassFactory);

   // Only the argument can be checked until the object is fully built.
   ASSERT(IS_VALID_CODE_PTR(NewObject, NEWOBJECTPROC));

   // Now the whole object is expected to pass validation.
   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

   DebugExitVOID(ClassFactory::ClassFactory);
}


/*
 * ClassFactory::~ClassFactory() - Tears down a factory.  The stored
 * constructor pointer is cleared, presumably so that a dangling factory
 * fails IsValidPCClassFactory() in debug builds.
 */
ClassFactory::~ClassFactory(void)
{
   DebugEntry(ClassFactory::~ClassFactory);

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

   m_NewObject = NULL;

   // The object is no longer valid from here on, so it is not re-checked.

   DebugExitVOID(ClassFactory::~ClassFactory);
}


/*
 * ClassFactory::AddRef() - IUnknown::AddRef, delegated to RefCount.
 *
 * Returns the new reference count reported by RefCount::AddRef().
 */
ULONG STDMETHODCALLTYPE ClassFactory::AddRef(void)
{
   DebugEntry(ClassFactory::AddRef);

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

   const ULONG ulcNewRef = RefCount::AddRef();

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

   DebugExitULONG(ClassFactory::AddRef, ulcNewRef);

   return(ulcNewRef);
}


/*
 * ClassFactory::Release() - IUnknown::Release, delegated to RefCount.
 *
 * Returns the reference count reported by RefCount::Release().  "this" is
 * not validated afterwards because the release may have destroyed it.
 */
ULONG STDMETHODCALLTYPE ClassFactory::Release(void)
{
   DebugEntry(ClassFactory::Release);

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

   const ULONG ulcNewRef = RefCount::Release();

   DebugExitULONG(ClassFactory::Release, ulcNewRef);

   return(ulcNewRef);
}


/*
 * ClassFactory::QueryInterface() - IUnknown::QueryInterface.
 *
 * riid      - requested interface; IID_IClassFactory and IID_IUnknown are
 *             supported
 * ppvObject - receives the interface pointer, or NULL on failure
 *
 * Returns S_OK (after taking a reference) or E_NOINTERFACE.
 */
HRESULT STDMETHODCALLTYPE ClassFactory::QueryInterface(REFIID riid,
                                                       PVOID *ppvObject)
{
   HRESULT hr = E_NOINTERFACE;

   DebugEntry(ClassFactory::QueryInterface);

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
   ASSERT(IsValidREFIID(riid));
   ASSERT(IS_VALID_WRITE_PTR(ppvObject, PVOID));

   // Default to "no interface"; overwritten below on a match.
   *ppvObject = NULL;

   if (riid == IID_IClassFactory)
   {
      *ppvObject = (PIClassFactory)this;
      ASSERT(IS_VALID_INTERFACE_PTR((PIClassFactory)*ppvObject, IClassFactory));
      TRACE_OUT(("ClassFactory::QueryInterface(): Returning IClassFactory."));
      hr = S_OK;
   }
   else if (riid == IID_IUnknown)
   {
      *ppvObject = (PIUnknown)this;
      ASSERT(IS_VALID_INTERFACE_PTR((PIUnknown)*ppvObject, IUnknown));
      TRACE_OUT(("ClassFactory::QueryInterface(): Returning IUnknown."));
      hr = S_OK;
   }
   else
      TRACE_OUT(("ClassFactory::QueryInterface(): Called on unknown interface."));

   // COM rule: a successful QueryInterface returns an AddRef'd pointer.
   if (SUCCEEDED(hr))
      AddRef();

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
   ASSERT(FAILED(hr) ||
          IS_VALID_INTERFACE_PTR(*ppvObject, INTERFACE));

   DebugExitHRESULT(ClassFactory::QueryInterface, hr);

   return(hr);
}


/*
 * ClassFactory::CreateInstance() - IClassFactory::CreateInstance.
 *
 * piunkOuter - must be NULL; aggregation is not supported
 * riid       - interface requested on the new object
 * ppvObject  - receives the interface pointer, or NULL on failure
 *
 * Returns CLASS_E_NOAGGREGATION, E_OUTOFMEMORY (constructor returned NULL),
 * or whatever the new object's QueryInterface() returns.  Each object that
 * is built takes one DLL reference, which DLLObjectDestroyed() gives back.
 */
HRESULT STDMETHODCALLTYPE ClassFactory::CreateInstance(PIUnknown piunkOuter,
                                                       REFIID riid,
                                                       PVOID *ppvObject)
{
   HRESULT hr;

   DebugEntry(ClassFactory::CreateInstance);

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
   ASSERT(! piunkOuter ||
          IS_VALID_INTERFACE_PTR(piunkOuter, IUnknown));
   ASSERT(IsValidREFIID(riid));
   ASSERT(IS_VALID_WRITE_PTR(ppvObject, PVOID));

   *ppvObject = NULL;

   if (piunkOuter)
   {
      WARNING_OUT(("ClassFactory::CreateInstance(): Aggregation not supported."));
      hr = CLASS_E_NOAGGREGATION;
   }
   else
   {
      PIUnknown piunkNew = m_NewObject(&DLLObjectDestroyed);

      if (! piunkNew)
         hr = E_OUTOFMEMORY;
      else
      {
         DLLAddRef();

         hr = piunkNew->QueryInterface(riid, ppvObject);

         // Drop the construction reference.  If QueryInterface() failed this
         // is the last reference, so the object is destroyed here.
         piunkNew->Release();
      }
   }

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));
   ASSERT(FAILED(hr) ||
          IS_VALID_INTERFACE_PTR(*ppvObject, INTERFACE));

   DebugExitHRESULT(ClassFactory::CreateInstance, hr);

   return(hr);
}


/*
 * ClassFactory::LockServer() - IClassFactory::LockServer.
 *
 * bLock - nonzero takes a DLL reference, zero releases one
 *
 * Always returns S_OK.
 */
HRESULT STDMETHODCALLTYPE ClassFactory::LockServer(BOOL bLock)
{
   HRESULT hr = S_OK;

   DebugEntry(ClassFactory::LockServer);

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

   // Any nonzero value counts as "lock"; bLock is not required to be TRUE.
   if (! bLock)
      DLLRelease();
   else
      DLLAddRef();

   ASSERT(IS_VALID_STRUCT_PTR(this, CClassFactory));

   DebugExitHRESULT(ClassFactory::LockServer, hr);

   return(hr);
}


/***************************** Exported Functions ****************************/


/*
 * DllGetClassObject() - COM in-process server entry point.
 *
 * rclsid    - class whose factory is wanted; must appear in s_cclscnstr
 * riid      - IID_IClassFactory or IID_IUnknown
 * ppvObject - receives the new class factory, or NULL on failure
 *
 * Returns S_OK, CLASS_E_CLASSNOTAVAILABLE (unknown class), E_NOINTERFACE
 * (unsupported riid), or E_OUTOFMEMORY.  Each factory handed out takes one
 * DLL reference, which DLLObjectDestroyed() gives back when it is destroyed.
 */
STDAPI DllGetClassObject(REFCLSID rclsid, REFIID riid, PVOID *ppvObject)
{
   HRESULT hr;
   NEWOBJECTPROC NewObject;

   DebugEntry(DllGetClassObject);

   ASSERT(IsValidREFCLSID(rclsid));
   ASSERT(IsValidREFIID(riid));
   ASSERT(IS_VALID_WRITE_PTR(ppvObject, PVOID));

   *ppvObject = NULL;

   hr = GetClassConstructor(rclsid, &NewObject);

   if (hr == S_OK)
   {
      // Check the interface before allocating anything.
      if (riid == IID_IUnknown ||
          riid == IID_IClassFactory)
      {
         PClassFactory pcf;

         // Standard "new T(args)" form.  "new(T(args))" is not valid ISO
         // C++.  A NULL result is treated as out of memory, which assumes
         // this project's operator new returns NULL rather than throwing.
         pcf = new ClassFactory(NewObject, &DLLObjectDestroyed);

         if (pcf)
         {
            if (riid == IID_IClassFactory)
            {
               *ppvObject = (PIClassFactory)pcf;
               ASSERT(IS_VALID_INTERFACE_PTR((PIClassFactory)*ppvObject, IClassFactory));
               TRACE_OUT(("DllGetClassObject(): Returning IClassFactory."));
            }
            else
            {
               ASSERT(riid == IID_IUnknown);
               *ppvObject = (PIUnknown)pcf;
               ASSERT(IS_VALID_INTERFACE_PTR((PIUnknown)*ppvObject, IUnknown));
               TRACE_OUT(("DllGetClassObject(): Returning IUnknown."));
            }

            // hr is still S_OK from GetClassConstructor().
            DLLAddRef();

            TRACE_OUT(("DllGetClassObject(): Created a new class factory."));
         }
         else
            hr = E_OUTOFMEMORY;
      }
      else
      {
         WARNING_OUT(("DllGetClassObject(): Called on unknown interface."));
         hr = E_NOINTERFACE;
      }
   }
   else
      WARNING_OUT(("DllGetClassObject(): Called on unknown class."));

   ASSERT(FAILED(hr) ||
          IS_VALID_INTERFACE_PTR(*ppvObject, INTERFACE));

   DebugExitHRESULT(DllGetClassObject, hr);

   return(hr);
}


/*
 * DllCanUnloadNow() - COM in-process server unload query.
 *
 * Returns S_FALSE while this module holds any DLL references.  Otherwise it
 * returns the result of InternetCPLCanUnloadNow().
 */
STDAPI DllCanUnloadNow(void)
{
   HRESULT hr;

   DebugEntry(DllCanUnloadNow);

   // InternetCPLCanUnloadNow() is consulted only when our own count is zero.
   hr = (s_ulcDLLRef > 0) ? S_FALSE : InternetCPLCanUnloadNow();

   TRACE_OUT(("DllCanUnloadNow(): DLL reference count is %lu.",
              s_ulcDLLRef));

   DebugExitHRESULT(DllCanUnloadNow, hr);

   return(hr);
}