Source code of Windows XP (NT5)
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

450 lines
10 KiB

  1. ///////////////////////////////////////////////////////////////////////////////
  2. // CCOMBaseFactory
  3. // Base class for reusing a single class factory for all components in a DLL
#include "fact.h"

#include <new>

#include "unk.h"
#include "regsvr.h"
#include "dbg.h"
  8. struct OUTPROCINFO
  9. {
  10. // Reserved (used only for COM Exe server)
  11. IClassFactory* _pfact;
  12. DWORD _dwRegister;
  13. };
  14. LONG CCOMBaseFactory::_cServerLocks = 0;
  15. LONG CCOMBaseFactory::_cComponents = 0;
  16. HMODULE CCOMBaseFactory::_hModule = NULL;
  17. CRITICAL_SECTION CCOMBaseFactory::_cs = {0};
  18. OUTPROCINFO* CCOMBaseFactory::_popinfo = NULL;
  19. DWORD CCOMBaseFactory::_dwThreadID = 0;
  20. BOOL CCOMBaseFactory::_fCritSectInit = FALSE;
  21. ///////////////////////////////////////////////////////////////////////////////
  22. // IUnknown implementation
  23. STDMETHODIMP CCOMBaseFactory::QueryInterface(REFIID iid, void** ppv)
  24. {
  25. IUnknown* punk = NULL;
  26. HRESULT hres = S_OK;
  27. if ((iid == IID_IUnknown) || (iid == IID_IClassFactory))
  28. {
  29. punk = this;
  30. punk->AddRef();
  31. }
  32. else
  33. {
  34. hres = E_NOINTERFACE;
  35. }
  36. *ppv = punk;
  37. return hres;
  38. }
  39. STDMETHODIMP_(ULONG) CCOMBaseFactory::AddRef()
  40. {
  41. return ::InterlockedIncrement((LONG*)&_cRef);
  42. }
  43. STDMETHODIMP_(ULONG) CCOMBaseFactory::Release()
  44. {
  45. ULONG cRef = ::InterlockedDecrement((LONG*)&_cRef);
  46. if (!cRef)
  47. {
  48. delete this;
  49. }
  50. return cRef;
  51. }
  52. ///////////////////////////////////////////////////////////////////////////////
  53. // IFactory implementation
  54. STDMETHODIMP CCOMBaseFactory::CreateInstance(IUnknown* pUnknownOuter,
  55. REFIID riid, void** ppv)
  56. {
  57. HRESULT hres = CLASS_E_NOAGGREGATION;
  58. // We don't support aggregation at all for now
  59. if (!pUnknownOuter)
  60. {
  61. // Aggregate only if the requested IID is IID_IUnknown.
  62. if ((pUnknownOuter != NULL) && (riid != IID_IUnknown))
  63. {
  64. hres = CLASS_E_NOAGGREGATION;
  65. }
  66. else
  67. {
  68. // Create the component.
  69. IUnknown* punkNew;
  70. hres = _pFactoryData->CreateInstance(
  71. CCOMBaseFactory::_COMFactoryCB, pUnknownOuter, &punkNew);
  72. if (SUCCEEDED(hres))
  73. {
  74. _COMFactoryCB(TRUE);
  75. // Get the requested interface.
  76. // hres = pNewComponent->NondelegatingQueryInterface(iid, ppv);
  77. hres = punkNew->QueryInterface(riid, ppv);
  78. // Release the reference held by the class factory.
  79. // pNewComponent->NondelegatingRelease();
  80. punkNew->Release();
  81. }
  82. }
  83. }
  84. return hres;
  85. }
  86. STDMETHODIMP CCOMBaseFactory::LockServer(BOOL fLock)
  87. {
  88. return _LockServer(fLock);
  89. }
  90. ///////////////////////////////////////////////////////////////////////////////
  91. // Install/Unintall
  92. //static
  93. HRESULT CCOMBaseFactory::_RegisterAll()
  94. {
  95. for (DWORD dw = 0; dw < _cDLLFactoryData; ++dw)
  96. {
  97. RegisterServer(_hModule,
  98. *(_pDLLFactoryData[dw]._pCLSID),
  99. _pDLLFactoryData[dw]._pszRegistryName,
  100. _pDLLFactoryData[dw]._pszVerIndProgID,
  101. _pDLLFactoryData[dw]._pszProgID,
  102. _pDLLFactoryData[dw]._dwThreadingModel,
  103. _pDLLFactoryData[dw].IsInprocServer(),
  104. _pDLLFactoryData[dw].IsLocalServer(),
  105. _pDLLFactoryData[dw].IsLocalService(),
  106. _pDLLFactoryData[dw]._pszLocalService,
  107. _pDLLFactoryData[dw]._pAppID);
  108. }
  109. return S_OK;
  110. }
  111. //static
  112. HRESULT CCOMBaseFactory::_UnregisterAll()
  113. {
  114. for (DWORD dw = 0; dw < _cDLLFactoryData; ++dw)
  115. {
  116. UnregisterServer(*(_pDLLFactoryData[dw]._pCLSID),
  117. _pDLLFactoryData[dw]._pszVerIndProgID,
  118. _pDLLFactoryData[dw]._pszProgID);
  119. }
  120. return S_OK;
  121. }
  122. ///////////////////////////////////////////////////////////////////////////////
  123. // CCOMBaseFactory implementation
  124. CCOMBaseFactory::CCOMBaseFactory(const CFactoryData* pFactoryData) : _cRef(1),
  125. _pFactoryData(pFactoryData)
  126. {}
  127. //static
  128. BOOL CCOMBaseFactory::_IsLocked()
  129. {
  130. // Always need to be called from within Critical Section
  131. return (_cServerLocks > 0);
  132. }
  133. //static
  134. HRESULT CCOMBaseFactory::_CanUnloadNow()
  135. {
  136. HRESULT hres = S_OK;
  137. // Always need to be called from within Critical Section
  138. if (_IsLocked())
  139. {
  140. hres = S_FALSE;
  141. }
  142. else
  143. {
  144. if (_cComponents)
  145. {
  146. hres = S_FALSE;
  147. }
  148. }
  149. return hres;
  150. }
  151. //static
  152. HRESULT CCOMBaseFactory::_CheckForUnload()
  153. {
  154. // Always need to be called from within Critical Section
  155. if (S_OK == _CanUnloadNow())
  156. {
  157. ::PostThreadMessage(_dwThreadID, WM_QUIT, 0, 0);
  158. }
  159. return S_OK;
  160. }
  161. //static
  162. HRESULT CCOMBaseFactory::_LockServer(BOOL fLock)
  163. {
  164. HRESULT hres = S_OK;
  165. EnterCriticalSection(&_cs);
  166. if (fLock)
  167. {
  168. ++_cServerLocks;
  169. }
  170. else
  171. {
  172. --_cServerLocks;
  173. hres = _CheckForUnload();
  174. }
  175. LeaveCriticalSection(&_cs);
  176. return hres;
  177. }
  178. //static
  179. void CCOMBaseFactory::_COMFactoryCB(BOOL fIncrement)
  180. {
  181. EnterCriticalSection(&_cs);
  182. if (fIncrement)
  183. {
  184. ++_cComponents;
  185. }
  186. else
  187. {
  188. --_cComponents;
  189. _CheckForUnload();
  190. }
  191. LeaveCriticalSection(&_cs);
  192. }
  193. ///////////////////////////////////////////////////////////////////////////////
  194. //
  195. // static
  196. HRESULT CCOMBaseFactory::_GetClassObject(REFCLSID rclsid, REFIID riid,
  197. void** ppv)
  198. {
  199. HRESULT hres = S_OK;
  200. ASSERT(_fCritSectInit);
  201. if ((riid != IID_IUnknown) && (riid != IID_IClassFactory))
  202. {
  203. hres = E_NOINTERFACE;
  204. }
  205. else
  206. {
  207. hres = CLASS_E_CLASSNOTAVAILABLE;
  208. // Traverse the array of data looking for this class ID.
  209. for (DWORD dw = 0; dw < _cDLLFactoryData; ++dw)
  210. {
  211. const CFactoryData* pData = &_pDLLFactoryData[dw];
  212. if (pData->IsClassID(rclsid) && pData->IsInprocServer())
  213. {
  214. // Found the ClassID in the array of components we can
  215. // create. So create a class factory for this component.
  216. // Pass the CDLLFactoryData structure to the class factory
  217. // so that it knows what kind of components to create.
  218. *ppv = (IUnknown*) new CCOMBaseFactory(pData);
  219. if (*ppv == NULL)
  220. {
  221. hres = E_OUTOFMEMORY;
  222. }
  223. else
  224. {
  225. hres = S_OK;
  226. }
  227. break;
  228. }
  229. }
  230. }
  231. return hres;
  232. }
  233. //static
  234. BOOL CCOMBaseFactory::_ProcessConsoleCmdLineParams(int argc, wchar_t* argv[],
  235. BOOL* pfRun, BOOL* pfEmbedded)
  236. {
  237. _dwThreadID = GetCurrentThreadId();
  238. if (argc > 1)
  239. {
  240. if (!lstrcmpi(argv[1], TEXT("-i")) ||
  241. !lstrcmpi(argv[1], TEXT("/i")))
  242. {
  243. CCOMBaseFactory::_RegisterAll();
  244. *pfRun = FALSE;
  245. }
  246. else
  247. {
  248. if (!lstrcmpi(argv[1], TEXT("-u")) ||
  249. !lstrcmpi(argv[1], TEXT("/u")))
  250. {
  251. CCOMBaseFactory::_UnregisterAll();
  252. *pfRun = FALSE;
  253. }
  254. else
  255. {
  256. if (!lstrcmpi(argv[1], TEXT("-Embedding")) ||
  257. !lstrcmpi(argv[1], TEXT("/Embedding")))
  258. {
  259. *pfRun = TRUE;
  260. *pfEmbedded = TRUE;
  261. }
  262. }
  263. }
  264. }
  265. else
  266. {
  267. *pfEmbedded = FALSE;
  268. *pfRun = TRUE;
  269. }
  270. return TRUE;
  271. }
  272. //static
  273. BOOL CCOMBaseFactory::_RegisterFactories(BOOL fEmbedded)
  274. {
  275. HRESULT hres = S_OK;
  276. if (!_fCritSectInit)
  277. {
  278. InitializeCriticalSection(&CCOMBaseFactory::_cs);
  279. _fCritSectInit = TRUE;
  280. }
  281. if (!fEmbedded)
  282. {
  283. hres = _LockServer(TRUE);
  284. }
  285. _popinfo = (OUTPROCINFO*)LocalAlloc(LPTR, sizeof(OUTPROCINFO) * _cDLLFactoryData);
  286. if (_popinfo)
  287. {
  288. for (DWORD dw = 0; SUCCEEDED(hres) && (dw < _cDLLFactoryData); ++dw)
  289. {
  290. const CFactoryData* pData = &_pDLLFactoryData[dw];
  291. if (pData->IsLocalServer() || pData->IsLocalService())
  292. {
  293. _popinfo[dw]._pfact = NULL ;
  294. _popinfo[dw]._dwRegister = NULL ;
  295. IClassFactory* pfact = new CCOMBaseFactory(pData);
  296. if (pfact)
  297. {
  298. DWORD dwRegister;
  299. hres = ::CoRegisterClassObject(*pData->_pCLSID,
  300. static_cast<IUnknown*>(pfact), pData->_dwClsContext,
  301. pData->_dwFlags, &dwRegister);
  302. if (SUCCEEDED(hres))
  303. {
  304. _popinfo[dw]._pfact = pfact;
  305. _popinfo[dw]._dwRegister = dwRegister;
  306. }
  307. else
  308. {
  309. pfact->Release();
  310. }
  311. }
  312. else
  313. {
  314. hres = E_OUTOFMEMORY;
  315. }
  316. }
  317. }
  318. }
  319. else
  320. {
  321. hres = E_OUTOFMEMORY;
  322. }
  323. return SUCCEEDED(hres);
  324. }
  325. //static
  326. BOOL CCOMBaseFactory::_SuspendFactories()
  327. {
  328. return SUCCEEDED(::CoSuspendClassObjects());
  329. }
  330. //static
  331. BOOL CCOMBaseFactory::_ResumeFactories()
  332. {
  333. return SUCCEEDED(::CoResumeClassObjects());
  334. }
  335. //static
  336. BOOL CCOMBaseFactory::_UnregisterFactories(BOOL fEmbedded)
  337. {
  338. HRESULT hres = S_OK;
  339. ASSERT(_popinfo);
  340. for (DWORD dw = 0; dw < _cDLLFactoryData; ++dw)
  341. {
  342. if (_popinfo[dw]._pfact)
  343. {
  344. _popinfo[dw]._pfact->Release();
  345. HRESULT hresTmp = ::CoRevokeClassObject(_popinfo[dw]._dwRegister);
  346. if (FAILED(hresTmp) && (S_OK == hres))
  347. {
  348. hres = hresTmp;
  349. }
  350. }
  351. }
  352. if (!fEmbedded)
  353. {
  354. HRESULT hresTmp = _LockServer(FALSE);
  355. if (FAILED(hresTmp) && (S_OK == hres))
  356. {
  357. hres = hresTmp;
  358. }
  359. }
  360. return SUCCEEDED(hres);
  361. }
  362. //static
  363. void CCOMBaseFactory::_WaitForAllClientsToGo()
  364. {
  365. MSG msg;
  366. while (::GetMessage(&msg, 0, 0, 0))
  367. {
  368. ::DispatchMessage(&msg);
  369. }
  370. }