Leaked source code of Windows Server 2003
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

476 lines
12 KiB

  1. /*++
  2. Copyright (C) 1996-2001 Microsoft Corporation
  3. Module Name:
  4. EXEMAIN.CPP
  5. Abstract:
  6. EXE/COM Helpers
  7. History:
  8. --*/
  9. #include "precomp.h"
  10. #include <stdio.h>
  11. #include "commain.cpp"
  12. #include <strutils.h>
  13. DWORD g_dwMainThreadId = 0xFFFFFFFF;
  14. class CExeClassFactory : public IClassFactory, IExternalConnection
  15. {
  16. long m_lRef;
  17. IClassFactory* m_pFactory;
  18. CUnkInternal* m_pInternalUnk;
  19. public:
  20. CExeClassFactory( CUnkInternal* pInternalUnk )
  21. : m_pInternalUnk( pInternalUnk ), m_lRef(0), m_pFactory(0)
  22. {
  23. m_pInternalUnk->InternalAddRef();
  24. m_pInternalUnk->InternalQueryInterface( IID_IClassFactory,
  25. (void**)&m_pFactory );
  26. }
  27. ~CExeClassFactory()
  28. {
  29. m_pInternalUnk->InternalRelease();
  30. }
  31. STDMETHOD(QueryInterface)( REFIID riid, void** ppv )
  32. {
  33. HRESULT hr;
  34. if ( riid == IID_IUnknown || riid == IID_IClassFactory )
  35. {
  36. *ppv = this;
  37. AddRef();
  38. hr = S_OK;
  39. }
  40. else if ( riid == IID_IExternalConnection )
  41. {
  42. *ppv = (IExternalConnection*)this;
  43. AddRef();
  44. hr = S_OK;
  45. }
  46. else
  47. {
  48. *ppv = NULL;
  49. hr = E_NOINTERFACE;
  50. }
  51. return hr;
  52. }
  53. STDMETHOD_(ULONG, AddRef)()
  54. {
  55. return InterlockedIncrement( &m_lRef );
  56. }
  57. STDMETHOD_(ULONG, Release)()
  58. {
  59. long lRef = InterlockedDecrement( &m_lRef );
  60. if ( lRef == 0 )
  61. {
  62. delete this;
  63. }
  64. return lRef;
  65. }
  66. STDMETHOD(CreateInstance)(IUnknown* pOuter, REFIID riid, void** ppv)
  67. {
  68. return m_pFactory->CreateInstance( pOuter, riid, ppv );
  69. }
  70. STDMETHOD(LockServer)(BOOL fLock)
  71. {
  72. return m_pFactory->LockServer( fLock );
  73. }
  74. STDMETHOD_(DWORD,AddConnection)( DWORD exconn, DWORD dwreserved )
  75. {
  76. m_pFactory->LockServer( TRUE );
  77. return 1;
  78. }
  79. STDMETHOD_(DWORD,ReleaseConnection)( DWORD exconn,
  80. DWORD dwreserved,
  81. BOOL fLastReleaseCloses )
  82. {
  83. m_pFactory->LockServer( FALSE );
  84. return 1;
  85. }
  86. };
  87. class CExeLifeControl : public CLifeControl
  88. {
  89. protected:
  90. long m_lNumObjects;
  91. BOOL m_bUnloading;
  92. CMyCritSec m_cs;
  93. protected:
  94. virtual void Quit()
  95. {
  96. PostThreadMessage(g_dwMainThreadId, WM_QUIT, 0, 0);
  97. }
  98. public:
  99. CExeLifeControl() : m_lNumObjects(0), m_bUnloading(FALSE){}
  100. virtual BOOL ObjectCreated(IUnknown* pv)
  101. {
  102. CMyInCritSec ics(&m_cs);
  103. if(m_bUnloading)
  104. return FALSE;
  105. m_lNumObjects++;
  106. return TRUE;
  107. }
  108. virtual void ObjectDestroyed(IUnknown* pv)
  109. {
  110. EnterCriticalSection(&m_cs);
  111. long l = --m_lNumObjects;
  112. if(l == 0)
  113. {
  114. m_bUnloading = TRUE;
  115. LeaveCriticalSection(&m_cs);
  116. Quit();
  117. }
  118. else
  119. {
  120. LeaveCriticalSection(&m_cs);
  121. }
  122. }
  123. virtual void AddRef(IUnknown* pv){}
  124. virtual void Release(IUnknown* pv){}
  125. };
  126. BOOL ContainsSubstring( LPCTSTR szStr, LPCTSTR szSubStr )
  127. {
  128. BOOL bContains;
  129. #ifdef UNICODE
  130. bContains = wcsstr(szStr,szSubStr) != NULL;
  131. #else
  132. bContains = strstr(szStr,szSubStr) != NULL;
  133. #endif
  134. return bContains;
  135. }
  136. void MessageLoop()
  137. {
  138. MSG msg;
  139. while(GetMessage(&msg, NULL, 0, 0))
  140. {
  141. TranslateMessage(&msg);
  142. DispatchMessage(&msg);
  143. }
  144. }
  145. struct ServiceInfo
  146. {
  147. BOOL m_bUsed;
  148. LPTSTR m_szServiceName;
  149. LPTSTR m_szDisplayName;
  150. BOOL m_bAuto;
  151. HANDLE m_hEvent;
  152. SERVICE_STATUS_HANDLE m_hStatus;
  153. ServiceInfo() : m_bUsed(FALSE){}
  154. } g_ServiceInfo;
  155. void SetServiceInfo(LPTSTR szServiceName, LPTSTR szDisplayName, BOOL bAuto)
  156. {
  157. g_ServiceInfo.m_bUsed = TRUE;
  158. g_ServiceInfo.m_szServiceName = szServiceName;
  159. g_ServiceInfo.m_szDisplayName = szDisplayName;
  160. g_ServiceInfo.m_bAuto = bAuto;
  161. }
  162. void WINAPI ServiceHandler(DWORD dwControl)
  163. {
  164. SERVICE_STATUS Status;
  165. Status.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
  166. Status.dwCurrentState = SERVICE_RUNNING;
  167. Status.dwControlsAccepted = SERVICE_ACCEPT_STOP;
  168. Status.dwWin32ExitCode = NO_ERROR;
  169. Status.dwCheckPoint = 0;
  170. Status.dwWaitHint = 0;
  171. if(!SetServiceStatus(g_ServiceInfo.m_hStatus, &Status))
  172. {
  173. long lRes = GetLastError();
  174. return;
  175. }
  176. switch(dwControl)
  177. {
  178. case SERVICE_CONTROL_STOP:
  179. Status.dwCurrentState = SERVICE_STOPPED;
  180. SetServiceStatus(g_ServiceInfo.m_hStatus, &Status);
  181. SetEvent(g_ServiceInfo.m_hEvent);
  182. ExitProcess(0);
  183. return;
  184. case SERVICE_CONTROL_PAUSE:
  185. case SERVICE_CONTROL_CONTINUE:
  186. case SERVICE_CONTROL_INTERROGATE:
  187. case SERVICE_CONTROL_SHUTDOWN:
  188. return;
  189. };
  190. }
  191. void WINAPI ServiceMain(DWORD dwArgc, LPTSTR* lpszArgv)
  192. {
  193. g_ServiceInfo.m_hEvent = CreateEvent(NULL, FALSE, FALSE, NULL);
  194. g_ServiceInfo.m_hStatus = RegisterServiceCtrlHandler(
  195. g_ServiceInfo.m_szServiceName,
  196. (LPHANDLER_FUNCTION)&ServiceHandler);
  197. if(g_ServiceInfo.m_hStatus == NULL)
  198. {
  199. long lRes = GetLastError();
  200. return;
  201. }
  202. SERVICE_STATUS Status;
  203. Status.dwServiceType = SERVICE_WIN32_OWN_PROCESS;
  204. Status.dwCurrentState = SERVICE_START_PENDING;
  205. Status.dwControlsAccepted = SERVICE_ACCEPT_STOP;
  206. Status.dwWin32ExitCode = NO_ERROR;
  207. Status.dwCheckPoint = 0;
  208. Status.dwWaitHint = 10000;
  209. if(!SetServiceStatus(g_ServiceInfo.m_hStatus, &Status))
  210. {
  211. long lRes = GetLastError();
  212. return;
  213. }
  214. Status.dwCurrentState = SERVICE_RUNNING;
  215. if(!SetServiceStatus(g_ServiceInfo.m_hStatus, &Status))
  216. {
  217. long lRes = GetLastError();
  218. return;
  219. }
  220. MessageLoop();
  221. }
  222. BOOL StartService()
  223. {
  224. SERVICE_TABLE_ENTRY aEntries[2];
  225. aEntries[0].lpServiceName = g_ServiceInfo.m_szServiceName;
  226. aEntries[0].lpServiceProc = (LPSERVICE_MAIN_FUNCTION)&ServiceMain;
  227. aEntries[1].lpServiceName = NULL;
  228. aEntries[1].lpServiceProc = NULL;
  229. if(!StartServiceCtrlDispatcher(aEntries))
  230. {
  231. long lRes = GetLastError();
  232. return FALSE;
  233. }
  234. return TRUE;
  235. }
  236. BOOL InstallService()
  237. {
  238. SC_HANDLE hManager = OpenSCManager(NULL, NULL, SC_MANAGER_ALL_ACCESS);
  239. TCHAR szFilename[1024];
  240. GetModuleFileName(NULL, szFilename, 1023);
  241. SC_HANDLE hService = CreateService(hManager,
  242. g_ServiceInfo.m_szServiceName,
  243. g_ServiceInfo.m_szDisplayName,
  244. SERVICE_ALL_ACCESS,
  245. SERVICE_WIN32_OWN_PROCESS,
  246. g_ServiceInfo.m_bAuto?SERVICE_AUTO_START : SERVICE_DEMAND_START,
  247. SERVICE_ERROR_NORMAL,
  248. szFilename, NULL, NULL, NULL,
  249. NULL, //Local System
  250. NULL // no password
  251. );
  252. if(hService == NULL)
  253. {
  254. long lRes = GetLastError();
  255. return FALSE;
  256. }
  257. /*
  258. // Create AppId key
  259. // ================
  260. GUID AppId = *g_aClassInfos[0].m_pClsid;
  261. char szAppId[128];
  262. WCHAR wszAppId[128];
  263. char szAppIdKey[128];
  264. StringFromGUID2(*pInfo->m_pClsid, wszAppId, 128);
  265. wcstombs(szAppId, wszAppId, 128);
  266. strcpy(szAppIdKey, "SOFTWARE\\Classes\\AppId\\");
  267. strcat(szAppIdKey, szAppId);
  268. HKEY hKey1;
  269. RegCreateKey(HKEY_LOCAL_MACHINE, szAppIdKey, &hKey1);
  270. RegSetValueEx(hKey1, "LocalService", 0, REG_SZ,
  271. g_ServiceInfo.m_szServiceName,
  272. strlen(g_ServiceInfo.m_szServiceName)+1);
  273. */
  274. return TRUE;
  275. }
  276. BOOL DeinstallService()
  277. {
  278. SC_HANDLE hManager = OpenSCManager(NULL, NULL, SC_MANAGER_ALL_ACCESS);
  279. SC_HANDLE hService = OpenService(hManager, g_ServiceInfo.m_szServiceName,
  280. SERVICE_ALL_ACCESS);
  281. if(hService == NULL)
  282. {
  283. long lRes = GetLastError();
  284. return FALSE;
  285. }
  286. if(!DeleteService(hService))
  287. {
  288. long lRes = GetLastError();
  289. return FALSE;
  290. }
  291. return TRUE;
  292. }
  293. CExeLifeControl g_LifeControl;
  294. CLifeControl* g_pLifeControl = &g_LifeControl;
  295. void CALLBACK MyTimerProc(HWND hWnd, UINT uMsg, UINT idEvent, DWORD dwTime)
  296. {
  297. PostQuitMessage(0);
  298. }
  299. void __cdecl main()
  300. {
  301. LPTSTR szOrigCommandLine = GetCommandLine();
  302. size_t cchLen = lstrlen(szOrigCommandLine)+1;
  303. LPTSTR szCommandLine = new TCHAR[cchLen];
  304. if (!szCommandLine)
  305. return;
  306. StringCchCopy( szCommandLine,cchLen,szOrigCommandLine );
  307. TCHAR * pc = szCommandLine;
  308. while(*pc)
  309. *(pc++) = (TCHAR)wbem_towupper(*pc);
  310. GlobalInitialize();
  311. if ( ContainsSubstring(szCommandLine, TEXT("-REGSERVER")) ||
  312. ContainsSubstring(szCommandLine, TEXT("/REGSERVER")) )
  313. {
  314. GlobalRegister();
  315. for(LIST_ENTRY * pEntry = g_ClassInfoHead.Flink;
  316. pEntry != &g_ClassInfoHead;
  317. pEntry = pEntry->Flink)
  318. {
  319. CClassInfo* pInfo = CONTAINING_RECORD(pEntry,CClassInfo,m_Entry);
  320. HRESULT hres = RegisterServer(pInfo, TRUE);
  321. if(FAILED(hres)) return;
  322. }
  323. if(g_ServiceInfo.m_bUsed)
  324. {
  325. InstallService();
  326. }
  327. }
  328. else if ( ContainsSubstring( szCommandLine, TEXT("-UNREGSERVER")) ||
  329. ContainsSubstring( szCommandLine, TEXT("/UNREGSERVER")))
  330. {
  331. GlobalUnregister();
  332. for(LIST_ENTRY * pEntry = g_ClassInfoHead.Flink;
  333. pEntry != &g_ClassInfoHead;
  334. pEntry = pEntry->Flink)
  335. {
  336. CClassInfo* pInfo = CONTAINING_RECORD(pEntry,CClassInfo,m_Entry);
  337. HRESULT hres = UnregisterServer(pInfo, TRUE);
  338. if(FAILED(hres)) return;
  339. }
  340. if(g_ServiceInfo.m_bUsed)
  341. {
  342. DeinstallService();
  343. }
  344. }
  345. else if( !ContainsSubstring(szCommandLine, TEXT("EMBEDDING")) &&
  346. !g_ServiceInfo.m_bUsed )
  347. {
  348. printf("Cannot run standalone\n");
  349. }
  350. else
  351. {
  352. int i;
  353. if(FAILED(GlobalInitializeCom()))
  354. return;
  355. for(LIST_ENTRY * pEntry = g_ClassInfoHead.Flink;
  356. pEntry != &g_ClassInfoHead;
  357. pEntry = pEntry->Flink)
  358. {
  359. CClassInfo* pInfo = CONTAINING_RECORD(pEntry,CClassInfo,m_Entry);
  360. IClassFactory* pFactory = new CExeClassFactory( pInfo->m_pFactory);
  361. if ( pFactory == NULL )
  362. return;
  363. HRESULT hres = CoRegisterClassObject(
  364. *pInfo->m_pClsid, pFactory, CLSCTX_SERVER,
  365. REGCLS_MULTIPLEUSE, &pInfo->m_dwCookie);
  366. if(FAILED(hres)) return;
  367. }
  368. if(g_ServiceInfo.m_bUsed)
  369. {
  370. StartService();
  371. }
  372. else
  373. {
  374. g_dwMainThreadId = GetCurrentThreadId();
  375. MessageLoop();
  376. }
  377. for(LIST_ENTRY * pEntry = g_ClassInfoHead.Flink;
  378. pEntry != &g_ClassInfoHead;
  379. pEntry = pEntry->Flink)
  380. {
  381. CClassInfo* pInfo = CONTAINING_RECORD(pEntry,CClassInfo,m_Entry);
  382. HRESULT hres = CoRevokeClassObject(pInfo->m_dwCookie);
  383. if(FAILED(hres)) return;
  384. }
  385. SetTimer(NULL, 0, 1000, (TIMERPROC)MyTimerProc);
  386. MessageLoop();
  387. GlobalUninitialize();
  388. }
  389. }
  390. int WINAPI WinMain(HINSTANCE hInst, HINSTANCE hPrev, LPSTR lpCmdLine,
  391. int nCmdShow)
  392. {
  393. main();
  394. return 0;
  395. }