Source code of Windows XP (NT5)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

460 lines
9.8 KiB

  1. /* - - - - - - - - */
  2. /*
  3. ** Copyright (C) Microsoft Corporation 1992-1995. All rights reserved.
  4. */
  5. /* - - - - - - - - */
  6. #include <windows.h>
  7. #include <windowsx.h>
  8. #include <shellapi.h>
  9. #include <ole2.h>
  10. #include <coguid.h>
  11. #define INTERNAL_(type) type
  12. const char aszRegServerKey[] = "InprocServer";
  13. const char aszServerEntry[] = "DllGetClassObject";
  14. const char aszServerQuery[] = "DllCanUnloadNow";
  15. const char aszCLSID[] = "CLSID\\";
  16. STDAPI_(void) MyFreeUnusedLibraries(void);
  17. /* - - - - - - - - */
  18. struct DllEntry {
  19. public:
  20. CLSID clsid;
  21. IClassFactory FAR* pFactory;
  22. LPFNCANUNLOADNOW lpfnCanUnloadNow;
  23. HINSTANCE hInstance;
  24. DllEntry FAR* pNextDll;
  25. };
  26. class FAR CTask {
  27. public:
  28. static CTask FAR* LookupTask(HTASK FAR& hTask);
  29. static IClassFactory FAR* LookupClass(CTask FAR* pTask, REFCLSID clsid, HINSTANCE hInstance);
  30. static void FreeUnusedLibraries(CTask FAR* pTask);
  31. HRESULT AddTaskDll(REFCLSID rclsid, IClassFactory FAR* pFactory, LPFNCANUNLOADNOW lpfnCanUnloadNow, HINSTANCE hInstance);
  32. IMalloc FAR* QueryMalloc(void)
  33. {
  34. return m_pMalloc;
  35. };
  36. CTask(HTASK hTask, IMalloc FAR* pMalloc);
  37. AddRef(void);
  38. Release(void);
  39. ~CTask(void);
  40. private:
  41. ULONG m_refs;
  42. HTASK m_hTask;
  43. IMalloc FAR* m_pMalloc;
  44. DllEntry FAR* m_pDllEntry;
  45. CTask FAR* m_pTaskNext;
  46. };
  47. /* - - - - - - - - */
  48. #define GlobalPtrHandle(pv) ((HGLOBAL)LOWORD(GlobalHandle(SELECTOROF(pv))))
  49. class CStdMalloc : public IMalloc {
  50. public:
  51. CStdMalloc(void)
  52. {
  53. m_refs = 0;
  54. }
  55. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void FAR* FAR* ppvObj)
  56. {
  57. if (iid == IID_IUnknown || iid == IID_IMalloc) {
  58. *ppvObj = this;
  59. m_refs++;
  60. return NULL;
  61. } else
  62. return ResultFromScode(E_NOINTERFACE);
  63. }
  64. ULONG STDMETHODCALLTYPE AddRef(void)
  65. {
  66. return ++m_refs;
  67. }
  68. ULONG STDMETHODCALLTYPE Release(void)
  69. {
  70. return --m_refs;
  71. }
  72. void FAR* STDMETHODCALLTYPE Alloc(ULONG cb)
  73. {
  74. return (void FAR*)GlobalLock(GlobalAlloc(GMEM_SHARE | GMEM_FIXED, cb));
  75. }
  76. void FAR* STDMETHODCALLTYPE Realloc(void FAR* pv, ULONG cb)
  77. {
  78. HGLOBAL h;
  79. h = GlobalPtrHandle(pv);
  80. GlobalUnlock(h);
  81. return (void FAR*)GlobalLock(GlobalReAlloc(h, cb, GMEM_FIXED));
  82. }
  83. void STDMETHODCALLTYPE Free(void FAR* pv)
  84. {
  85. GlobalFree(GlobalPtrHandle(pv));
  86. }
  87. ULONG STDMETHODCALLTYPE GetSize(void FAR* pv)
  88. {
  89. return GlobalSize(GlobalPtrHandle(pv));
  90. }
  91. int STDMETHODCALLTYPE DidAlloc(void FAR* pv)
  92. {
  93. return !IsBadWritePtr(pv, 0);
  94. }
  95. void STDMETHODCALLTYPE HeapMinimize(void)
  96. {
  97. GlobalCompact(-1);
  98. }
  99. private:
  100. ULONG m_refs;
  101. };
  102. /* - - - - - - - - */
  103. CTask FAR* pTaskList;
  104. CStdMalloc NEAR v_stdMalloc;
  105. /* - - - - - - - - */
  106. CTask FAR* CTask::LookupTask(
  107. HTASK FAR& hTask)
  108. {
  109. CTask FAR* pTaskCurrent;
  110. hTask = GetCurrentTask();
  111. for (pTaskCurrent = pTaskList; pTaskCurrent; pTaskCurrent = pTaskCurrent->m_pTaskNext)
  112. if (pTaskCurrent->m_hTask == hTask)
  113. return pTaskCurrent;
  114. return NULL;
  115. }
  116. /* - - - - - - - - */
  117. IClassFactory FAR* CTask::LookupClass(
  118. CTask FAR* pTask,
  119. REFCLSID rclsid,
  120. HINSTANCE hInstance)
  121. {
  122. DllEntry FAR* pDllEntry;
  123. for (pDllEntry = pTask->m_pDllEntry; pDllEntry; pDllEntry = pDllEntry->pNextDll) {
  124. if ((hInstance == pDllEntry->hInstance) &&
  125. (rclsid == pDllEntry->clsid))
  126. return pDllEntry->pFactory;
  127. }
  128. return NULL;
  129. }
  130. /* - - - - - - - - */
  131. void CTask::FreeUnusedLibraries(CTask FAR* pTask)
  132. {
  133. DllEntry FAR* pDllEntryPrev;
  134. DllEntry FAR* pDllEntryCur;
  135. pDllEntryPrev = NULL;
  136. pDllEntryCur = pTask->m_pDllEntry;
  137. for (; pDllEntryCur;)
  138. if (pDllEntryCur->lpfnCanUnloadNow() == S_OK) {
  139. pDllEntryCur->pFactory->Release();
  140. FreeModule(pDllEntryCur->hInstance);
  141. if (pDllEntryPrev == NULL) {
  142. pTask->m_pDllEntry = pDllEntryCur->pNextDll;
  143. pTask->m_pMalloc->Free(pDllEntryCur);
  144. pDllEntryCur = pTask->m_pDllEntry;
  145. } else {
  146. pDllEntryPrev->pNextDll = pDllEntryCur->pNextDll;
  147. pTask->m_pMalloc->Free(pDllEntryCur);
  148. pDllEntryCur = pDllEntryPrev->pNextDll;
  149. }
  150. } else {
  151. pDllEntryPrev = pDllEntryCur;
  152. pDllEntryCur = pDllEntryCur->pNextDll;
  153. }
  154. }
  155. /* - - - - - - - - */
  156. HRESULT CTask::AddTaskDll(
  157. REFCLSID rclsid,
  158. IClassFactory FAR* pFactory,
  159. LPFNCANUNLOADNOW lpfnCanUnloadNow,
  160. HINSTANCE hInstance)
  161. {
  162. DllEntry FAR* pDllEntry;
  163. pDllEntry = (DllEntry FAR*)(m_pMalloc->Alloc(sizeof(DllEntry)));
  164. if (!pDllEntry)
  165. return ResultFromScode(E_OUTOFMEMORY);
  166. pDllEntry->clsid = rclsid;
  167. pDllEntry->pFactory = pFactory;
  168. pDllEntry->lpfnCanUnloadNow = lpfnCanUnloadNow;
  169. pDllEntry->hInstance = hInstance;
  170. pDllEntry->pNextDll = m_pDllEntry;
  171. m_pDllEntry = pDllEntry;
  172. return NULL;
  173. }
  174. /* - - - - - - - - */
  175. CTask::CTask(
  176. HTASK hTask,
  177. IMalloc FAR* pMalloc)
  178. {
  179. m_refs = 1;
  180. m_hTask = hTask;
  181. m_pMalloc = pMalloc;
  182. m_pMalloc->AddRef();
  183. m_pDllEntry = NULL;
  184. m_pTaskNext = pTaskList;
  185. pTaskList = this;
  186. }
  187. CTask::AddRef(void)
  188. {
  189. ++m_refs;
  190. return 0;
  191. }
  192. CTask::Release(void)
  193. {
  194. if (m_refs == 1)
  195. delete this;
  196. else
  197. --m_refs;
  198. return 0;
  199. }
  200. /* - - - - - - - - */
  201. CTask::~CTask(
  202. void)
  203. {
  204. for (; m_pDllEntry;) {
  205. DllEntry FAR* pDllEntry;
  206. m_pDllEntry->pFactory->Release();
  207. FreeModule(m_pDllEntry->hInstance);
  208. pDllEntry = m_pDllEntry->pNextDll;
  209. m_pMalloc->Free(m_pDllEntry);
  210. m_pDllEntry = pDllEntry;
  211. }
  212. m_pMalloc->Release();
  213. if (this == pTaskList)
  214. pTaskList = m_pTaskNext;
  215. else {
  216. CTask FAR* pTask;
  217. for (pTask = pTaskList; pTask->m_pTaskNext != this; pTask = pTask->m_pTaskNext)
  218. ;
  219. pTask->m_pTaskNext = m_pTaskNext;
  220. }
  221. }
  222. /* - - - - - - - - */
  223. // converts GUID into (...) form without leading identifier; no errors
  224. INTERNAL_(int) StringFromGUID2(REFGUID rguid, LPSTR lpsz)
  225. {
  226. wsprintf(lpsz, "{%08lX-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
  227. rguid.Data1, rguid.Data2, rguid.Data3,
  228. rguid.Data4[0], rguid.Data4[1],
  229. rguid.Data4[2], rguid.Data4[3],
  230. rguid.Data4[4], rguid.Data4[5],
  231. rguid.Data4[6], rguid.Data4[7]);
  232. return _fstrlen(lpsz) + 1;
  233. }
  234. /* - - - - - - - - */
  235. #define GUIDSTR_MAX (1+ 3*sizeof(GUID) +sizeof(GUID)-1 +1 +1)
  236. #define CLSIDSTR_MAX (sizeof(aszCLSID)-1+GUIDSTR_MAX)
  237. // alternate to StringFromCLSID which puts string in caller-supplied buffer;
  238. // returns the amount of data copied including the zero terminator; 0 if none.
  239. STDAPI_(int) StringFromCLSID2(REFCLSID rclsid, LPSTR lpsz, int cbMax)
  240. {
  241. if (cbMax < CLSIDSTR_MAX)
  242. return 0;
  243. return sizeof(aszCLSID)-1 +
  244. StringFromGUID2(rclsid, _fstrchr(_fstrcpy(lpsz, aszCLSID),'\0'));
  245. }
  246. /* - - - - - - - - */
  247. static LONG RegQueryClassValue(REFCLSID rclsid, LPCSTR lpszSubKey, LPSTR lpszValue, int cbMax)
  248. {
  249. char szKey[256];
  250. int cbClsid;
  251. LONG cbValue = cbMax;
  252. // translate rclsid into string
  253. cbClsid = StringFromCLSID2(rclsid, &szKey[0], sizeof(szKey));
  254. szKey[cbClsid-1] = '\\';
  255. _fstrcpy(&szKey[cbClsid], lpszSubKey);
  256. return RegQueryValue(HKEY_CLASSES_ROOT, szKey, lpszValue, &cbValue);
  257. }
  258. /* - - - - - - - - */
  259. STDAPI CoGetClassObject(
  260. REFCLSID rclsid,
  261. DWORD dwClsContext,
  262. LPVOID pvReserved,
  263. REFIID riid,
  264. void FAR* FAR* ppv)
  265. {
  266. char aszServer[256];
  267. HTASK htask;
  268. CTask FAR* pTask;
  269. IClassFactory FAR* pFactory;
  270. HINSTANCE hInstance;
  271. HRESULT hr;
  272. if (pvReserved != NULL)
  273. return ResultFromScode(E_INVALIDARG);
  274. if (!(dwClsContext & CLSCTX_INPROC_SERVER))
  275. return ResultFromScode(E_INVALIDARG);
  276. if (!(pTask = CTask::LookupTask(htask)))
  277. return ResultFromScode(E_UNEXPECTED);
  278. if (RegQueryClassValue(rclsid, aszRegServerKey, aszServer, sizeof(aszServer)) != 0)
  279. return ResultFromScode(E_UNEXPECTED);
  280. hInstance = LoadLibrary(aszServer);
  281. if (hInstance < HINSTANCE_ERROR)
  282. return ResultFromScode(E_UNEXPECTED);
  283. if (pFactory = CTask::LookupClass(pTask, rclsid, hInstance))
  284. hr = pFactory->QueryInterface(riid, ppv);
  285. else {
  286. LPFNCANUNLOADNOW lpfnCanUnloadNow;
  287. LPFNGETCLASSOBJECT lpfnGetClassObject;
  288. lpfnCanUnloadNow = (LPFNCANUNLOADNOW)GetProcAddress(hInstance, aszServerQuery);
  289. if ((lpfnGetClassObject = (LPFNGETCLASSOBJECT)GetProcAddress(hInstance, aszServerEntry)) != NULL) {
  290. IMalloc FAR* pMalloc;
  291. pMalloc = pTask->QueryMalloc();
  292. hr = (*lpfnGetClassObject)(rclsid, IID_IClassFactory, (void FAR* FAR*)&pFactory);
  293. if (!hr) {
  294. hr = pTask->AddTaskDll(rclsid, pFactory, lpfnCanUnloadNow, hInstance);
  295. if (!hr)
  296. return pFactory->QueryInterface(riid, ppv);
  297. pFactory->Release();
  298. }
  299. } else
  300. hr = ResultFromScode(E_UNEXPECTED);
  301. }
  302. FreeLibrary(hInstance);
  303. return hr;
  304. }
  305. /* - - - - - - - - */
  306. STDAPI CoCreateInstance(
  307. REFCLSID rclsid,
  308. IUnknown FAR* pUnkOuter,
  309. DWORD dwClsContext,
  310. REFIID riid,
  311. LPVOID FAR* ppv)
  312. {
  313. HRESULT hr;
  314. IClassFactory FAR* pFactory;
  315. hr = CoGetClassObject(rclsid, dwClsContext, NULL, IID_IClassFactory, (void FAR* FAR*)&pFactory);
  316. if (!hr) {
  317. hr = pFactory->CreateInstance(pUnkOuter, riid, ppv);
  318. pFactory->Release();
  319. }
  320. return hr;
  321. }
  322. /* - - - - - - - - */
  323. STDAPI GetStandardTaskMalloc(
  324. IMalloc FAR* FAR* ppMalloc)
  325. {
  326. v_stdMalloc.AddRef();
  327. *ppMalloc = &v_stdMalloc;
  328. return NULL;
  329. }
  330. /* - - - - - - - - */
  331. STDAPI CoGetMalloc(
  332. DWORD dwMemContext,
  333. IMalloc FAR* FAR* ppMalloc)
  334. {
  335. HTASK htask;
  336. CTask FAR* pTask;
  337. IMalloc FAR* pMalloc;
  338. if (dwMemContext != MEMCTX_TASK)
  339. return ResultFromScode(E_UNEXPECTED);
  340. if (!(pTask = CTask::LookupTask(htask)))
  341. return ResultFromScode(E_UNEXPECTED);
  342. pMalloc = pTask->QueryMalloc();
  343. pMalloc->AddRef();
  344. *ppMalloc = pMalloc;
  345. return NULL;
  346. }
  347. /* - - - - - - - - */
  348. STDAPI CoInitialize(
  349. IMalloc FAR* pMalloc)
  350. {
  351. HTASK htask;
  352. CTask FAR* pTask;
  353. if (!pMalloc)
  354. pMalloc = (IMalloc FAR *) &v_stdMalloc;
  355. if (pTask = CTask::LookupTask(htask)) {
  356. pTask->AddRef();
  357. return ResultFromScode(S_FALSE);
  358. }
  359. pTask = new FAR CTask(htask, pMalloc);
  360. return pTask ? NULL : ResultFromScode(E_OUTOFMEMORY);
  361. }
  362. /* - - - - - - - - */
  363. STDAPI_(void) CoUninitialize(
  364. void)
  365. {
  366. HTASK htask;
  367. CTask FAR* pTask;
  368. if (pTask = CTask::LookupTask(htask))
  369. pTask->Release();
  370. }
  371. /* - - - - - - - - */
  372. STDAPI_(void) MyFreeUnusedLibraries(
  373. void)
  374. {
  375. HTASK htask;
  376. CTask FAR* pTask;
  377. if (pTask = CTask::LookupTask(htask))
  378. CTask::FreeUnusedLibraries(pTask);
  379. }
  380. /* - - - - - - - - */
  381. STDAPI_(BOOL) IsEqualGUID(REFGUID guid1, REFGUID guid2)
  382. {
  383. return guid1 == guid2;
  384. }