Leaked source code of windows server 2003
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

642 lines
15 KiB

//+----------------------------------------------------------------------------
// File: dll.cxx
//
// Synopsis: This file contains the core routines and globals for creating
// DLLs: the DllMain entry point, the COM class-object exports, and the
// process/thread attach, detach, passivation and usage-count helpers.
//
//-----------------------------------------------------------------------------
  8. // Includes -------------------------------------------------------------------
  9. #include <core.hxx>
  10. // Globals --------------------------------------------------------------------
  11. static THREADSTATE * g_pts = NULL;
  12. HANDLE g_hinst = NULL;
  13. HANDLE g_heap = NULL;
  14. DWORD g_tlsThreadState = NULL_TLS;
  15. LONG g_cUsage = 0;
  16. GINFO g_ginfo = { 0 };
  17. DECLARE_LOCK(DLL);
  18. // Prototypes -----------------------------------------------------------------
  19. class CClassFactory : public CComponent,
  20. public IClassFactory2
  21. {
  22. typedef CComponent parent;
  23. public:
  24. CClassFactory(CLASSFACTORY * pcf);
  25. // IUnknown methods
  26. DEFINE_IUNKNOWN_METHODS;
  27. // IClassFactory methods
  28. STDMETHOD(CreateInstance)(IUnknown * pUnkOuter, REFIID riid, void ** ppvObj);
  29. STDMETHOD(LockServer)(BOOL fLock);
  30. // IClassFactory2 methods
  31. STDMETHOD(GetLicInfo)(LICINFO * pLicInfo);
  32. STDMETHOD(RequestLicKey)(DWORD dwReserved, BSTR * pbstrKey);
  33. STDMETHOD(CreateInstanceLic)(IUnknown * pUnkOuter,
  34. IUnknown * pUnkReserved,
  35. REFIID riid, BSTR bstrKey,
  36. void ** ppvObj);
  37. private:
  38. CLASSFACTORY * _pcf;
  39. HRESULT PrivateQueryInterface(REFIID riid, void ** ppvObj);
  40. };
  41. static HRESULT DllProcessAttach();
  42. static void DllProcessDetach();
  43. static HRESULT DllThreadAttach();
  44. static void DllThreadDetach(THREADSTATE * pts);
  45. static void DllProcessPassivate();
  46. static void DllThreadPassivate();
  47. //+----------------------------------------------------------------------------
  48. // Function: DllMain
  49. //
  50. // Synopsis:
  51. //
  52. //-----------------------------------------------------------------------------
  53. extern "C" BOOL WINAPI
  54. DllMain(
  55. HINSTANCE hinst,
  56. DWORD nReason,
  57. void * ) // pvReserved - Unused
  58. {
  59. HRESULT hr = S_OK;
  60. g_hinst = hinst;
  61. switch (nReason)
  62. {
  63. case DLL_PROCESS_ATTACH:
  64. hr = DllProcessAttach();
  65. break;
  66. case DLL_PROCESS_DETACH:
  67. DllProcessDetach();
  68. break;
  69. case DLL_THREAD_DETACH:
  70. {
  71. THREADSTATE * pts = (THREADSTATE *)TlsGetValue(g_tlsThreadState);
  72. DllThreadDetach(pts);
  73. }
  74. break;
  75. }
  76. return !hr;
  77. }
  78. //+----------------------------------------------------------------------------
  79. // Function: DllGetClassObject
  80. //
  81. // Synopsis:
  82. //
  83. // NOTE: This code limits class objects to supporting IUnknown and IClassFactory
  84. //
  85. //-----------------------------------------------------------------------------
  86. STDAPI
  87. DllGetClassObject(
  88. REFCLSID rclsid,
  89. REFIID riid,
  90. void ** ppv)
  91. {
  92. CLASSFACTORY * pcf;
  93. HRESULT hr;
  94. hr = EnsureThreadState();
  95. if (hr)
  96. return hr;
  97. if (!ppv)
  98. return E_INVALIDARG;
  99. *ppv = NULL;
  100. if (riid != IID_IClassFactory &&
  101. riid != IID_IClassFactory2)
  102. return E_NOINTERFACE;
  103. for (pcf=g_acf; pcf->pclsid; pcf++)
  104. {
  105. if (*(pcf->pclsid) == rclsid)
  106. break;
  107. }
  108. if (!pcf)
  109. return CLASS_E_CLASSNOTAVAILABLE;
  110. if (riid == IID_IClassFactory2 && !pcf->pfnLicense)
  111. return E_NOINTERFACE;
  112. CClassFactory * pCF = new CClassFactory(pcf);
  113. if (!pCF)
  114. return E_OUTOFMEMORY;
  115. *ppv = (void *)(IClassFactory2 *)pCF;
  116. return S_OK;
  117. }
  118. //+----------------------------------------------------------------------------
  119. // Function: DllCanUnloadNow
  120. //
  121. // Synopsis:
  122. //
  123. //-----------------------------------------------------------------------------
  124. STDAPI
  125. DllCanUnloadNow()
  126. {
  127. return ((g_cUsage==0)
  128. ? S_OK
  129. : S_FALSE);
  130. }
  131. //+----------------------------------------------------------------------------
  132. // Function: DllProcessAttach
  133. //
  134. // Synopsis:
  135. //
  136. //-----------------------------------------------------------------------------
  137. HRESULT
  138. DllProcessAttach()
  139. {
  140. PFN_PATTACH * ppfnPAttach;
  141. HRESULT hr = S_OK;
  142. g_tlsThreadState = TlsAlloc();
  143. if (g_tlsThreadState == NULL_TLS)
  144. {
  145. return GetWin32Hresult();
  146. }
  147. INIT_LOCK(DLL);
  148. g_heap = GetProcessHeap();
  149. for (ppfnPAttach=g_apfnPAttach; *ppfnPAttach; ppfnPAttach++)
  150. {
  151. hr = (**ppfnPAttach)();
  152. if (hr)
  153. goto Error;
  154. }
  155. Cleanup:
  156. return hr;
  157. Error:
  158. DllProcessDetach();
  159. goto Cleanup;
  160. }
  161. //+----------------------------------------------------------------------------
  162. // Function: DllProcessDetach
  163. //
  164. // Synopsis:
  165. //
  166. //-----------------------------------------------------------------------------
  167. void
  168. DllProcessDetach()
  169. {
  170. THREADSTATE * pts;
  171. PFN_PDETACH * ppfnPDetach;
  172. Implies(g_pts, g_tlsThreadState != NULL_TLS);
  173. while (g_pts)
  174. {
  175. pts = g_pts;
  176. Verify(TlsSetValue(g_tlsThreadState, pts));
  177. DllThreadDetach(pts);
  178. Assert(!TlsGetValue(g_tlsThreadState));
  179. Assert(g_pts != pts);
  180. }
  181. for (ppfnPDetach=g_apfnPDetach; *ppfnPDetach; ppfnPDetach++)
  182. (**ppfnPDetach)();
  183. DEINIT_LOCK(DLL);
  184. if (g_tlsThreadState != NULL_TLS)
  185. {
  186. TlsFree(g_tlsThreadState);
  187. }
  188. }
  189. //+----------------------------------------------------------------------------
  190. // Function: DllThreadAttach
  191. //
  192. // Synopsis:
  193. //
  194. //-----------------------------------------------------------------------------
  195. HRESULT
  196. DllThreadAttach()
  197. {
  198. THREADSTATE * pts;
  199. PFN_TATTACH * ppfnTAttach;
  200. HRESULT hr;
  201. LOCK(DLL);
  202. Assert(g_tlsThreadState != NULL_TLS);
  203. Assert(!::TlsGetValue(g_tlsThreadState));
  204. hr = AllocateThreadState(&pts);
  205. if (hr)
  206. goto Error;
  207. Assert(pts);
  208. pts->dll.idThread = GetCurrentThreadId();
  209. Verify(TlsSetValue(g_tlsThreadState, pts));
  210. Verify(SUCCEEDED(::CoGetMalloc(1, &pts->dll.pmalloc)));
  211. for (ppfnTAttach=g_apfnTAttach; *ppfnTAttach; ppfnTAttach++)
  212. {
  213. hr = (**ppfnTAttach)(pts);
  214. if (hr)
  215. goto Error;
  216. }
  217. pts->ptsNext = g_pts;
  218. g_pts = pts;
  219. Cleanup:
  220. return hr;
  221. Error:
  222. DllThreadDetach(pts);
  223. goto Cleanup;
  224. }
  225. //+----------------------------------------------------------------------------
  226. // Function: DllThreadDetach
  227. //
  228. // Synopsis:
  229. //
  230. // NOTE: Under Win95, DllThreadDetach may be called to clear memory on a
  231. // thread which did not allocate the memory.
  232. //
  233. //-----------------------------------------------------------------------------
  234. void
  235. DllThreadDetach(
  236. THREADSTATE * pts)
  237. {
  238. THREADSTATE ** ppts;
  239. PFN_TDETACH * ppfnTDetach;
  240. LOCK(DLL);
  241. if (!pts)
  242. return;
  243. Assert(!pts->dll.cUsage);
  244. Assert(pts == (THREADSTATE *)TlsGetValue(g_tlsThreadState));
  245. for (ppfnTDetach=g_apfnTDetach; *ppfnTDetach; ppfnTDetach++)
  246. (**ppfnTDetach)(pts);
  247. ::SRelease(pts->dll.pmalloc);
  248. ::TlsSetValue(g_tlsThreadState, NULL);
  249. for (ppts=&g_pts; *ppts && *ppts != pts; ppts=&((*ppts)->ptsNext));
  250. if (*ppts)
  251. {
  252. *ppts = pts->ptsNext;
  253. }
  254. delete pts;
  255. }
  256. //+----------------------------------------------------------------------------
  257. // Function: DllProcessPassivate
  258. //
  259. // Synopsis:
  260. //
  261. //-----------------------------------------------------------------------------
  262. void
  263. DllProcessPassivate()
  264. {
  265. PFN_PPASSIVATE * ppfnPPassivate;
  266. LOCK(DLL);
  267. Assert(!g_cUsage);
  268. // BUGBUG: What are the respective roles of process/thread passivation?
  269. // BUGBUG: This is an unsafe add into g_cUsage...fix this!
  270. g_cUsage += REF_GUARD;
  271. for (ppfnPPassivate=g_apfnPPassivate; *ppfnPPassivate; ppfnPPassivate++)
  272. (**ppfnPPassivate)();
  273. g_cUsage -= REF_GUARD;
  274. }
  275. //+----------------------------------------------------------------------------
  276. // Function: DllThreadPassivate
  277. //
  278. // Synopsis:
  279. //
  280. //-----------------------------------------------------------------------------
  281. void
  282. DllThreadPassivate()
  283. {
  284. THREADSTATE * pts = GetThreadState();
  285. PFN_TPASSIVATE * ppfnTPassivate;
  286. Assert(!pts->dll.cUsage);
  287. pts->dll.cUsage += REF_GUARD;
  288. for (ppfnTPassivate=g_apfnTPassivate; *ppfnTPassivate; ppfnTPassivate++)
  289. (**ppfnTPassivate)(pts);
  290. pts->dll.cUsage -= REF_GUARD;
  291. }
  292. //+----------------------------------------------------------------------------
  293. // Function: CClassFactory
  294. //
  295. // Synopsis:
  296. //
  297. //-----------------------------------------------------------------------------
  298. CClassFactory::CClassFactory(
  299. CLASSFACTORY * pcf)
  300. : CComponent(NULL)
  301. {
  302. Assert(pcf);
  303. Assert(pcf->pfnFactory);
  304. _pcf = pcf;
  305. }
  306. //+----------------------------------------------------------------------------
  307. // Function: CreateInstance
  308. //
  309. // Synopsis:
  310. //
  311. //-----------------------------------------------------------------------------
  312. STDMETHODIMP
  313. CClassFactory::CreateInstance(
  314. IUnknown * pUnkOuter,
  315. REFIID riid,
  316. void ** ppvObj)
  317. {
  318. if (!ppvObj)
  319. return E_INVALIDARG;
  320. *ppvObj = NULL;
  321. // BUGBUG: What error should be returned?
  322. if (pUnkOuter && riid != IID_IUnknown)
  323. return E_INVALIDARG;
  324. // BUGBUG: Should the factory just create the object and let this
  325. // code perform the appropriate QI?
  326. // BUGBUG: This code should automatically handle aggregation
  327. Assert(_pcf);
  328. Assert(_pcf->pfnFactory);
  329. return _pcf->pfnFactory(pUnkOuter, riid, ppvObj);
  330. }
  331. //+----------------------------------------------------------------------------
  332. // Function: LockServer
  333. //
  334. // Synopsis:
  335. //
  336. //-----------------------------------------------------------------------------
  337. STDMETHODIMP
  338. CClassFactory::LockServer(
  339. BOOL fLock)
  340. {
  341. if (fLock)
  342. {
  343. AddRef();
  344. IncrementThreadUsage();
  345. }
  346. else
  347. {
  348. DecrementThreadUsage();
  349. Release();
  350. }
  351. return S_OK;
  352. }
  353. //+----------------------------------------------------------------------------
  354. // Function: GetLicInfo
  355. //
  356. // Synopsis:
  357. //
  358. //-----------------------------------------------------------------------------
  359. STDMETHODIMP
  360. CClassFactory::GetLicInfo(
  361. LICINFO * pLicInfo)
  362. {
  363. Assert(_pcf->pfnLicense);
  364. return _pcf->pfnLicense(LICREQUEST_INFO, pLicInfo);
  365. }
  366. //+----------------------------------------------------------------------------
  367. // Function: RequestLicKey
  368. //
  369. // Synopsis:
  370. //
  371. //-----------------------------------------------------------------------------
  372. STDMETHODIMP
  373. CClassFactory::RequestLicKey(
  374. DWORD , // dwReserved
  375. BSTR * pbstrKey)
  376. {
  377. Assert(_pcf->pfnLicense);
  378. return _pcf->pfnLicense(LICREQUEST_OBTAIN, pbstrKey);
  379. }
  380. //+----------------------------------------------------------------------------
  381. // Function: CreateInstanceLic
  382. //
  383. // Synopsis:
  384. //
  385. //-----------------------------------------------------------------------------
  386. STDMETHODIMP
  387. CClassFactory::CreateInstanceLic(
  388. IUnknown * pUnkOuter,
  389. IUnknown * , // pUnkReserved
  390. REFIID riid,
  391. BSTR bstrKey,
  392. void ** ppvObj)
  393. {
  394. Assert(_pcf->pfnLicense);
  395. if (!ppvObj)
  396. return E_INVALIDARG;
  397. *ppvObj = NULL;
  398. if (_pcf->pfnLicense(LICREQUEST_VALIDATE, bstrKey) != S_OK)
  399. {
  400. return CLASS_E_NOTLICENSED;
  401. }
  402. return CreateInstance(pUnkOuter, riid, ppvObj);
  403. }
  404. //+----------------------------------------------------------------------------
  405. // Function: PrivateQueryInterface
  406. //
  407. // Synopsis:
  408. //
  409. //-----------------------------------------------------------------------------
  410. HRESULT
  411. CClassFactory::PrivateQueryInterface(
  412. REFIID riid,
  413. void ** ppvObj)
  414. {
  415. if (riid == IID_IClassFactory)
  416. {
  417. *ppvObj = (void *)(IClassFactory *)this;
  418. }
  419. else if (riid == IID_IClassFactory2)
  420. {
  421. if (_pcf->pfnLicense)
  422. {
  423. *ppvObj = (void *)(IClassFactory2 *)this;
  424. }
  425. else
  426. {
  427. return E_NOINTERFACE;
  428. }
  429. }
  430. else
  431. {
  432. return parent::PrivateQueryInterface(riid, ppvObj);
  433. }
  434. return S_OK;
  435. }
  436. //+----------------------------------------------------------------------------
  437. // Function: GetWin32Hresult
  438. //
  439. // Synopsis: Return an HRESULT derived from the current Win32 error
  440. //
  441. //-----------------------------------------------------------------------------
  442. HRESULT
  443. GetWin32Hresult()
  444. {
  445. return HRESULT_FROM_WIN32(GetLastError());
  446. }
  447. //+----------------------------------------------------------------------------
  448. // Function: EnsureThreadState
  449. //
  450. // Synopsis:
  451. //
  452. //-----------------------------------------------------------------------------
  453. HRESULT
  454. EnsureThreadState()
  455. {
  456. extern DWORD g_tlsThreadState;
  457. Assert(g_tlsThreadState != NULL_TLS);
  458. if (!TlsGetValue(g_tlsThreadState))
  459. return DllThreadAttach();
  460. return S_OK;
  461. }
  462. //+----------------------------------------------------------------------------
  463. // Function: IncrementProcessUsage
  464. //
  465. // Synopsis:
  466. //
  467. //-----------------------------------------------------------------------------
  468. void
  469. IncrementProcessUsage()
  470. {
  471. #ifdef _DEBUG
  472. Verify(InterlockedIncrement(&g_cUsage) > 0);
  473. #else
  474. InterlockedIncrement(&g_cUsage);
  475. #endif
  476. }
  477. //+----------------------------------------------------------------------------
  478. // Function: DecrementProcessUsage
  479. //
  480. // Synopsis:
  481. //
  482. //-----------------------------------------------------------------------------
  483. void
  484. DecrementProcessUsage()
  485. {
  486. #if DBG==1
  487. if( 0 == g_cUsage )
  488. {
  489. DebugBreak(); // ref counting problem
  490. }
  491. #endif
  492. if (!InterlockedDecrement(&g_cUsage))
  493. {
  494. DllProcessPassivate();
  495. }
  496. }
  497. //+----------------------------------------------------------------------------
  498. // Function: IncrementThreadUsage
  499. //
  500. // Synopsis:
  501. //
  502. //-----------------------------------------------------------------------------
  503. void
  504. IncrementThreadUsage()
  505. {
  506. #ifdef _DEBUG
  507. Verify(++TLS(dll.cUsage) > 0);
  508. #else
  509. ++TLS(dll.cUsage);
  510. #endif
  511. IncrementProcessUsage();
  512. }
  513. //+----------------------------------------------------------------------------
  514. // Function: DecrementThreadUsage
  515. //
  516. // Synopsis:
  517. //
  518. //-----------------------------------------------------------------------------
  519. void
  520. DecrementThreadUsage()
  521. {
  522. THREADSTATE * pts = GetThreadState();
  523. if(pts)
  524. {
  525. pts->dll.cUsage--;
  526. Assert(pts->dll.cUsage >= 0);
  527. if (!pts->dll.cUsage)
  528. {
  529. DllThreadPassivate();
  530. }
  531. }
  532. DecrementProcessUsage();
  533. }