Leaked source code of windows server 2003
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

447 lines
8.7 KiB

  1. //=============================================================================
  2. // session.h -- definition of session collection class.
  3. //
  4. // Copyright (c) 1998-2002 Microsoft Corporation, All Rights Reserved
  5. //=============================================================================
  6. #include "ctoken.h"
  7. typedef NTSTATUS (NTAPI *PFN_NT_QUERY_SYSTEM_INFORMATION)
  8. (
  9. IN SYSTEM_INFORMATION_CLASS SystemInformationClass,
  10. OUT PVOID SystemInformation,
  11. IN ULONG SystemInformationLength,
  12. OUT PULONG ReturnLength OPTIONAL
  13. );
  14. class CProcess;
  15. class CSession;
  16. class CUser;
  17. class CUserComp;
  18. class CUserSessionCollection;
  19. class CUser
  20. {
  21. public:
  22. CUser() : m_sidUser(NULL) {}
  23. CUser(
  24. PSID psidUser);
  25. CUser(
  26. const CUser& user);
  27. virtual ~CUser();
  28. bool IsValid();
  29. PSID GetPSID() const
  30. {
  31. return m_sidUser;
  32. }
  33. void GetSidString(
  34. CHString& str) const;
  35. private:
  36. void Copy(
  37. CUser& out) const;
  38. PSID m_sidUser;
  39. bool m_fValid;
  40. };
  41. // Comparison class required for multimap
  42. // costructor involving non-standard key
  43. // type (i.e., a CUser) in the map.
  44. class CUserComp
  45. {
  46. public:
  47. CUserComp() {}
  48. virtual ~CUserComp() {}
  49. bool operator()(
  50. const CUser& userFirst,
  51. const CUser& userSecond) const
  52. {
  53. bool fRet;
  54. CHString chstr1, chstr2;
  55. userFirst.GetSidString(chstr1);
  56. userSecond.GetSidString(chstr2);
  57. long lcmp = chstr1.CompareNoCase(chstr2);
  58. (lcmp < 0) ? fRet = true : fRet = false;
  59. return fRet;
  60. }
  61. };
  62. class CProcess
  63. {
  64. public:
  65. // Constructors and destructors
  66. CProcess();
  67. CProcess(
  68. DWORD dwPID,
  69. LPCWSTR wstrImageName);
  70. CProcess(
  71. const CProcess& process);
  72. virtual ~CProcess();
  73. // Accessor functions
  74. DWORD GetPID() const;
  75. CHString GetImageName() const;
  76. private:
  77. DWORD m_dwPID;
  78. CHString m_chstrImageName;
  79. void Copy(
  80. CProcess& process) const;
  81. };
  82. // vector and iterator for getting a session's processes...
  83. typedef std::vector<CProcess> PROCESS_VECTOR;
  84. typedef PROCESS_VECTOR::iterator PROCESS_ITERATOR;
  85. class CSession
  86. {
  87. public:
  88. // Constructors and destructors
  89. CSession() {}
  90. CSession(
  91. const LUID& luidSessionID);
  92. CSession(
  93. const CSession& ses);
  94. virtual ~CSession() {}
  95. // Accessor functions
  96. LUID GetLUID() const;
  97. __int64 GetLUIDint64() const;
  98. CHString GetAuthenticationPkg() const;
  99. ULONG GetLogonType() const;
  100. __int64 GetLogonTime() const;
  101. // Enumerate list of processes
  102. CProcess* GetFirstProcess(
  103. PROCESS_ITERATOR& pos);
  104. CProcess* GetNextProcess(
  105. PROCESS_ITERATOR& pos);
  106. // Allow easy impersonation of
  107. // the session's first process
  108. HANDLE Impersonate();
  109. DWORD GetImpProcPID();
  110. friend CUserSessionCollection;
  111. // Checks a string representation
  112. // of a session id for validity
  113. bool IsSessionIDValid(
  114. LPCWSTR wstrSessionID);
  115. private:
  116. void Copy(
  117. CSession& sesCopy) const;
  118. CHString m_chstrAuthPkg;
  119. ULONG m_ulLogonType;
  120. __int64 i64LogonTime;
  121. LUID m_luid;
  122. PROCESS_VECTOR m_vecProcesses;
  123. };
  124. // map and iterator for relating users and sessions...
  125. typedef std::multimap<CUser, CSession, CUserComp> USER_SESSION_MAP;
  126. typedef USER_SESSION_MAP::iterator USER_SESSION_ITERATOR;
  127. // Custom iterator used in enumerating processes from
  128. // CUserSessionCollection.
  129. struct USER_SESSION_PROCESS_ITERATOR
  130. {
  131. friend CUserSessionCollection;
  132. private:
  133. USER_SESSION_ITERATOR usIter;
  134. PROCESS_ITERATOR procIter;
  135. };
  136. class CUserSessionCollection
  137. {
  138. public:
  139. // Constructors and destructors
  140. CUserSessionCollection();
  141. CUserSessionCollection(
  142. const CUserSessionCollection& sescol);
  143. virtual ~CUserSessionCollection() {}
  144. // Method to refresh map
  145. DWORD Refresh();
  146. // Methods to check whether a particular
  147. // session is in the map
  148. bool IsSessionMapped(
  149. LUID& luidSes);
  150. bool CUserSessionCollection::IsSessionMapped(
  151. __int64 i64luidSes);
  152. // Support enumeration of users
  153. CUser* GetFirstUser(
  154. USER_SESSION_ITERATOR& pos);
  155. CUser* GetNextUser(
  156. USER_SESSION_ITERATOR& pos);
  157. // Support enumeration of sessions
  158. // belonging to a particular user.
  159. CSession* GetFirstSessionOfUser(
  160. CUser& usr,
  161. USER_SESSION_ITERATOR& pos);
  162. CSession* GetNextSessionOfUser(
  163. USER_SESSION_ITERATOR& pos);
  164. // Support enumeration of all sessions
  165. CSession* GetFirstSession(
  166. USER_SESSION_ITERATOR& pos);
  167. CSession* GetNextSession(
  168. USER_SESSION_ITERATOR& pos);
  169. // Support finding a particular session
  170. CSession* FindSession(
  171. LUID& luidSes);
  172. CSession* FindSession(
  173. __int64 i64luidSes);
  174. // Support enumeration of processes
  175. // belonging to a particular user
  176. CProcess* GetFirstProcessOfUser(
  177. CUser& usr,
  178. USER_SESSION_PROCESS_ITERATOR& pos);
  179. CProcess* GetNextProcessOfUser(
  180. USER_SESSION_PROCESS_ITERATOR& pos);
  181. // Support enumeration of all processes
  182. CProcess* GetFirstProcess(
  183. USER_SESSION_PROCESS_ITERATOR& pos);
  184. CProcess* GetNextProcess(
  185. USER_SESSION_PROCESS_ITERATOR& pos);
  186. private:
  187. DWORD CollectSessions();
  188. DWORD CollectNoProcessesSessions();
  189. void Copy(
  190. CUserSessionCollection& out) const;
  191. DWORD GetProcessList(
  192. std::vector<CProcess>& vecProcesses) const;
  193. DWORD EnablePrivilegeOnCurrentThread(
  194. LPCTSTR szPriv) const;
  195. bool FindSessionInternal(
  196. LUID& luidSes,
  197. USER_SESSION_ITERATOR& usiOut);
  198. USER_SESSION_MAP m_usr2ses;
  199. };
  200. // This version is a smart handle
  201. // for use with thread tokens we
  202. // are impersonating. On destruction,
  203. // it reverts to the handle it
  204. // encapsulates.
  205. class SmartRevertTokenHANDLE
  206. {
  207. private:
  208. HANDLE m_h;
  209. public:
  210. SmartRevertTokenHANDLE()
  211. : m_h(INVALID_HANDLE_VALUE) {}
  212. SmartRevertTokenHANDLE(
  213. HANDLE h)
  214. : m_h(h) {}
  215. ~SmartRevertTokenHANDLE()
  216. {
  217. if ( FALSE == Revert () )
  218. {
  219. throw CFramework_Exception(L"SetThreadToken failed", GetLastError());
  220. }
  221. }
  222. HANDLE operator =(HANDLE h)
  223. {
  224. if ( FALSE == Revert () )
  225. {
  226. throw CFramework_Exception(L"SetThreadToken failed", GetLastError());
  227. }
  228. m_h = h;
  229. return h;
  230. }
  231. operator HANDLE() const
  232. {
  233. return m_h;
  234. }
  235. HANDLE* operator &()
  236. {
  237. if ( FALSE == Revert () )
  238. {
  239. throw CFramework_Exception(L"SetThreadToken failed", GetLastError());
  240. }
  241. m_h = INVALID_HANDLE_VALUE;
  242. return &m_h;
  243. }
  244. private :
  245. BOOL Revert ()
  246. {
  247. BOOL bRet = FALSE ;
  248. if ( m_h && INVALID_HANDLE_VALUE != m_h )
  249. {
  250. CThreadToken cpt ( m_h );
  251. if ( cpt.IsValidToken () )
  252. {
  253. HANDLE hCurThread = ::GetCurrentThread();
  254. TOKEN_TYPE type;
  255. if ( cpt.GetTokenType ( type ) )
  256. {
  257. if ( TokenPrimary == type )
  258. {
  259. CToken ct;
  260. if ( ct.Duplicate ( cpt, FALSE ) )
  261. {
  262. bRet = ::SetThreadToken ( &hCurThread, ct.GetTokenHandle () );
  263. }
  264. }
  265. else
  266. {
  267. bRet = ::SetThreadToken ( &hCurThread, cpt.GetTokenHandle () ) ;
  268. }
  269. if (!bRet)
  270. {
  271. LogMessage2( L"Failed to SetThreadToken in SmartRevertTokenHANDLE with error %d", ::GetLastError() );
  272. }
  273. }
  274. }
  275. CloseHandle(m_h);
  276. }
  277. else
  278. {
  279. //
  280. // smart revert was created from invalid handle
  281. // there is nothing we should do here !
  282. //
  283. bRet = TRUE ;
  284. }
  285. return bRet ;
  286. }
  287. };
  // Helper for automatic cleanup of pointers returned from the various
  // enumeration functions. Holds a single T* and deletes it when the
  // wrapper is destroyed or when a different pointer is assigned.
  //
  // NOTE(review): copying a SmartDelete duplicates the raw pointer and
  // would lead to a double delete; copies are not blocked here in order
  // to keep the existing interface untouched.
  template<class T>
  class SmartDelete
  {
  private:
      T* m_ptr;

      // Deletes the held object (if any) and leaves the wrapper empty.
      void Clear()
      {
          delete m_ptr;   // deleting a NULL pointer is a no-op
          m_ptr = NULL;
      }

  public:
      // Starts out empty.
      SmartDelete()
          : m_ptr(NULL) {}

      // Takes ownership of ptr.
      SmartDelete(T* ptr)
          : m_ptr(ptr) {}

      virtual ~SmartDelete()
      {
          Clear();
      }

      // Deletes the held object and takes ownership of ptrRight.
      // Assigning the pointer already held is a no-op, so the wrapper is
      // never left pointing at an object it has just deleted.
      T* operator =(T* ptrRight)
      {
          if(ptrRight != m_ptr)
          {
              Clear();
              m_ptr = ptrRight;
          }
          return ptrRight;
      }

      // Exposes the raw pointer without giving up ownership.
      operator T*() const
      {
          return m_ptr;
      }

      // Deletes the held object and returns the (now NULL) pointer value.
      // NOTE(review): this yields the pointer by value rather than its
      // address, unlike SmartRevertTokenHANDLE::operator& -- confirm intent.
      T* operator &()
      {
          Clear();
          return m_ptr;
      }

      // Member access on the held object; the pointer may be NULL.
      T* operator->() const
      {
          return m_ptr;
      }
  };