Source code of Windows XP (NT5)
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2359 lines
52 KiB

  1. // This is a part of the Active Template Library.
  2. // Copyright (C) 1996-2001 Microsoft Corporation
  3. // All rights reserved.
  4. //
  5. // This source code is only intended as a supplement to the
  6. // Active Template Library Reference and related
  7. // electronic documentation provided with the library.
  8. // See these sources for detailed information regarding the
  9. // Active Template Library product.
  10. #ifndef __ATLSESSION_H__
  11. #define __ATLSESSION_H__
  12. #pragma once
  13. #pragma warning(push)
  14. #pragma warning(disable: 4702) // unreachable code
  15. #include <atldbcli.h>
  16. #include <atlcom.h>
  17. #include <atlstr.h>
  18. #include <stdio.h>
  19. #include <atlcoll.h>
  20. #include <atltime.h>
  21. #include <atlcrypt.h>
  22. #include <atlenc.h>
  23. #include <atlutil.h>
  24. #include <atlcache.h>
  25. #include <atlspriv.h>
  26. #include <atlsiface.h>
  27. #ifndef SESSION_KEY_LENGTH
  28. #define SESSION_KEY_LENGTH 37
  29. #endif
  30. #ifndef MAX_SESSION_KEY_LEN
  31. #define MAX_SESSION_KEY_LEN 128
  32. #endif
  33. #ifndef MAX_VARIABLE_NAME_LENGTH
  34. #define MAX_VARIABLE_NAME_LENGTH 50
  35. #endif
  36. #ifndef MAX_VARIABLE_VALUE_LENGTH
  37. #define MAX_VARIABLE_VALUE_LENGTH 128
  38. #endif
  39. #ifndef DEFAULT_SQL_LEN
  40. #define DEFAULT_SQL_LEN 1024
  41. #endif
  42. #ifndef MAX_CONNECTION_STRING_LEN
  43. #define MAX_CONNECTION_STRING_LEN 2048
  44. #endif
  45. #ifndef SESSION_COOKIE_NAME
  46. #define SESSION_COOKIE_NAME "SESSIONID"
  47. #endif
  48. #ifndef ATL_SESSION_TIMEOUT
  49. #define ATL_SESSION_TIMEOUT 600000 //10 min
  50. #endif
  51. #ifndef ATL_SESSION_SWEEPER_TIMEOUT
  52. #define ATL_SESSION_SWEEPER_TIMEOUT 1000 // 1sec
  53. #endif
  54. #define INVALID_DB_SESSION_POS 0x0
  55. #define ATL_DBSESSION_ID _T("__ATL_SESSION_DB_CONNECTION")
  56. namespace ATL {
  57. // CSessionNameGenerator
  58. // This is a helper class that generates random data for session key
  59. // names. This class tries to use the CryptoApi to generate random
  60. // bytes for the session key name. If the CryptoApi isn't available
  61. // then the CRT rand() is used to generate the random bytes. This
  62. // class's GetNewSessionName member function is used to actually
  63. // generate the session name.
  64. class CSessionNameGenerator :
  65. public CCryptProv
  66. {
  67. public:
  68. bool m_bCryptNotAvailable;
  69. enum {MIN_SESSION_KEY_LEN=5};
  70. CSessionNameGenerator() throw() :
  71. m_bCryptNotAvailable(false)
  72. {
  73. // Note that the crypto api is being
  74. // initialized with no private key
  75. // information
  76. HRESULT hr = InitVerifyContext();
  77. m_bCryptNotAvailable = FAILED(hr) ? true : false;
  78. }
  79. // This function creates a new session name and base64 encodes it.
  80. // The base64 encoding algorithm used needs at least MIN_SESSION_KEY_LEN
  81. // bytes to work correctly. Since we stack allocate the temporary
  82. // buffer that holds the key name, the buffer must be less than or equal to
  83. // the MAX_SESSION_KEY_LEN in size.
  84. HRESULT GetNewSessionName(LPSTR szNewID, DWORD *pdwSize) throw()
  85. {
  86. HRESULT hr = E_FAIL;
  87. if (!pdwSize)
  88. return E_POINTER;
  89. if (*pdwSize < MIN_SESSION_KEY_LEN ||
  90. *pdwSize > MAX_SESSION_KEY_LEN)
  91. return E_INVALIDARG;
  92. if (!szNewID)
  93. return E_POINTER;
  94. BYTE key[MAX_SESSION_KEY_LEN] = {0x0};
  95. // calculate the number of bytes that will fit in the
  96. // buffer we've been passed
  97. DWORD dwDataSize = CalcMaxInputSize(*pdwSize);
  98. if (dwDataSize && *pdwSize >= (DWORD)(Base64EncodeGetRequiredLength(dwDataSize,
  99. ATL_BASE64_FLAG_NOCRLF)))
  100. {
  101. int dwKeySize = *pdwSize;
  102. hr = GenerateRandomName(key, dwDataSize);
  103. if (SUCCEEDED(hr))
  104. {
  105. if( Base64Encode(key,
  106. dwDataSize,
  107. szNewID,
  108. &dwKeySize,
  109. ATL_BASE64_FLAG_NOCRLF) )
  110. {
  111. //null terminate
  112. szNewID[dwKeySize]=0;
  113. *pdwSize = dwKeySize+1;
  114. }
  115. else
  116. hr = E_FAIL;
  117. }
  118. else
  119. {
  120. *pdwSize = (DWORD)(Base64EncodeGetRequiredLength(dwDataSize,
  121. ATL_BASE64_FLAG_NOCRLF));
  122. return E_OUTOFMEMORY;
  123. }
  124. }
  125. return hr;
  126. }
  127. DWORD CalcMaxInputSize(DWORD nOutputSize) throw()
  128. {
  129. if (nOutputSize < (DWORD)MIN_SESSION_KEY_LEN)
  130. return 0;
  131. // subtract one from the output size to make room
  132. // for the NULL terminator in the output then
  133. // calculate the biggest number of input bytes that
  134. // when base64 encoded will fit in a buffer of size
  135. // nOutputSize (including base64 padding)
  136. int nInputSize = ((nOutputSize-1)*3)/4;
  137. int factor = ((nInputSize*4)/3)%4;
  138. if (factor)
  139. nInputSize -= factor;
  140. return nInputSize;
  141. }
  142. HRESULT GenerateRandomName(BYTE *pBuff, DWORD dwBuffSize) throw()
  143. {
  144. if (!pBuff)
  145. return E_POINTER;
  146. if (!dwBuffSize)
  147. return E_UNEXPECTED;
  148. if (!m_bCryptNotAvailable && GetHandle())
  149. {
  150. // Use the crypto api to generate random data.
  151. return GenRandom(dwBuffSize, pBuff);
  152. }
  153. // CryptoApi isn't available so we generate
  154. // random data using rand. We seed the random
  155. // number generator with a seed that is a combination
  156. // of bytes from an arbitrary number and the system
  157. // time which changes every millisecond so it will
  158. // be different for every call to this function.
  159. FILETIME ft;
  160. GetSystemTimeAsFileTime(&ft);
  161. static DWORD dwVal = 0x21;
  162. DWORD dwSeed = (dwVal++ << 0x18) | (ft.dwLowDateTime & 0x00ffff00) | dwVal++ & 0x000000ff;
  163. srand(dwSeed);
  164. BYTE *pCurr = pBuff;
  165. // fill buffer with random bytes
  166. for (int i=0; i < (int)dwBuffSize; i++)
  167. {
  168. *pCurr = (BYTE) (rand() & 0x000000ff);
  169. pCurr++;
  170. }
  171. return S_OK;
  172. }
  173. };
  174. //
  175. // CDefaultQueryClass
  176. // returns Query strings for use in SQL queries used
  177. // by the database persisted session service.
  178. class CDefaultQueryClass
  179. {
  180. public:
  181. LPCTSTR GetSessionRefDelete() throw()
  182. {
  183. return _T("DELETE FROM SessionReferences ")
  184. _T("WHERE SessionID=? AND RefCount <= 0 ")
  185. _T("AND DATEDIFF(millisecond, LastAccess, getdate()) > TimeoutMs");
  186. }
  187. LPCTSTR GetSessionRefIsExpired() throw()
  188. {
  189. return _T("SELECT SessionID FROM SessionReferences ")
  190. _T("WHERE (SessionID=?) AND (DATEDIFF(millisecond, LastAccess, getdate()) > TimeoutMs)");
  191. }
  192. LPCTSTR GetSessionRefDeleteFinal() throw()
  193. {
  194. return _T("DELETE FROM SessionReferences ")
  195. _T("WHERE SessionID=?");
  196. }
  197. LPCTSTR GetSessionRefCreate() throw()
  198. {
  199. return _T("INSERT INTO SessionReferences ")
  200. _T("(SessionID, LastAccess, RefCount, TimeoutMs) ")
  201. _T("VALUES (?, getdate(), 1, ?)");
  202. }
  203. LPCTSTR GetSessionRefUpdateTimeout() throw()
  204. {
  205. return _T("UPDATE SessionReferences ")
  206. _T("SET TimeoutMs=? WHERE SessionID=?");
  207. }
  208. LPCTSTR GetSessionRefAddRef() throw()
  209. {
  210. return _T("UPDATE SessionReferences ")
  211. _T("SET RefCount=RefCount+1, ")
  212. _T("LastAccess=getdate() ")
  213. _T("WHERE SessionID=?");
  214. }
  215. LPCTSTR GetSessionRefRemoveRef() throw()
  216. {
  217. return _T("UPDATE SessionReferences ")
  218. _T("SET RefCount=RefCount-1, ")
  219. _T("LastAccess=getdate() ")
  220. _T("WHERE SessionID=?");
  221. }
  222. LPCTSTR GetSessionRefAccess() throw()
  223. {
  224. return _T("UPDATE SessionReferences ")
  225. _T("SET LastAccess=getdate() ")
  226. _T("WHERE SessionID=?");
  227. }
  228. LPCTSTR GetSessionRefSelect() throw()
  229. {
  230. return _T("SELECT * FROM SessionReferences ")
  231. _T("WHERE SessionID=?");
  232. }
  233. LPCTSTR GetSessionRefGetCount() throw()
  234. {
  235. return _T("SELECT COUNT(*) FROM SessionReferences");
  236. }
  237. LPCTSTR GetSessionVarCount() throw()
  238. {
  239. return _T("SELECT COUNT(*) FROM SessionVariables WHERE SessionID=?");
  240. }
  241. LPCTSTR GetSessionVarInsert() throw()
  242. {
  243. return _T("INSERT INTO SessionVariables ")
  244. _T("(VariableValue, SessionID, VariableName) ")
  245. _T("VALUES (?, ?, ?)");
  246. }
  247. LPCTSTR GetSessionVarUpdate() throw()
  248. {
  249. return _T("UPDATE SessionVariables ")
  250. _T("SET VariableValue=? ")
  251. _T("WHERE SessionID=? AND VariableName=?");
  252. }
  253. LPCTSTR GetSessionVarDeleteVar() throw()
  254. {
  255. return _T("DELETE FROM SessionVariables ")
  256. _T("WHERE SessionID=? AND VariableName=?");
  257. }
  258. LPCTSTR GetSessionVarDeleteAllVars() throw()
  259. {
  260. return _T("DELETE FROM SessionVariables WHERE (SessionID=?)");
  261. }
  262. LPCTSTR GetSessionVarSelectVar()throw()
  263. {
  264. return _T("SELECT SessionID, VariableName, VariableValue ")
  265. _T("FROM SessionVariables ")
  266. _T("WHERE SessionID=? AND VariableName=?");
  267. }
  268. LPCTSTR GetSessionVarSelectAllVars() throw()
  269. {
  270. return _T("SELECT SessionID, VariableName, VariableValue ")
  271. _T("FROM SessionVariables ")
  272. _T("WHERE SessionID=?");
  273. }
  274. LPCTSTR GetSessionReferencesSet() throw()
  275. {
  276. return _T("UPDATE SessionReferences SET TimeoutMs=?");
  277. }
  278. };
  279. // Contains the data for the session variable accessors
  280. class CSessionDataBase
  281. {
  282. public:
  283. TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
  284. TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH];
  285. BYTE m_VariableValue[MAX_VARIABLE_VALUE_LENGTH];
  286. DWORD m_VariableLen;
  287. CSessionDataBase() throw()
  288. {
  289. m_szSessionID[0] = '\0';
  290. m_VariableName[0] = '\0';
  291. m_VariableValue[0] = '\0';
  292. m_VariableLen = 0;
  293. }
  294. HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName, VARIANT *pVal) throw()
  295. {
  296. HRESULT hr = S_OK;
  297. CVariantStream stream;
  298. if ( szSessionID )
  299. {
  300. if ( _tcslen(szSessionID)< MAX_SESSION_KEY_LEN)
  301. _tcscpy(m_szSessionID, szSessionID);
  302. else
  303. hr = E_OUTOFMEMORY;
  304. }
  305. else
  306. return E_INVALIDARG;
  307. if (szVarName)
  308. if ( _tcslen(szVarName) < MAX_VARIABLE_NAME_LENGTH)
  309. _tcscpy(m_VariableName, szVarName);
  310. else
  311. hr = E_OUTOFMEMORY;
  312. if (pVal)
  313. {
  314. hr = stream.InsertVariant(pVal);
  315. if (hr == S_OK)
  316. {
  317. BYTE *pBytes = stream.m_stream;
  318. size_t size = stream.GetVariantSize();
  319. if (pBytes && size && size < MAX_VARIABLE_VALUE_LENGTH)
  320. {
  321. memcpy(m_VariableValue, pBytes, stream.GetVariantSize());
  322. m_VariableLen = (DWORD)size;
  323. }
  324. else
  325. hr = E_UNEXPECTED;
  326. }
  327. }
  328. return hr;
  329. }
  330. };
  331. // Use to select a session variable given the name
  332. // of a session and the name of a variable.
  333. class CSessionDataSelector : public CSessionDataBase
  334. {
  335. public:
  336. BEGIN_COLUMN_MAP(CSessionDataSelector)
  337. COLUMN_ENTRY(1, m_szSessionID)
  338. COLUMN_ENTRY(2, m_VariableName)
  339. COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen)
  340. END_COLUMN_MAP()
  341. BEGIN_PARAM_MAP(CSessionDataSelector)
  342. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  343. COLUMN_ENTRY(1, m_szSessionID)
  344. COLUMN_ENTRY(2, m_VariableName)
  345. END_PARAM_MAP()
  346. };
  347. // Use to select all session variables given the name of
  348. // of a session.
  349. class CAllSessionDataSelector : public CSessionDataBase
  350. {
  351. public:
  352. BEGIN_COLUMN_MAP(CAllSessionDataSelector)
  353. COLUMN_ENTRY(1, m_szSessionID)
  354. COLUMN_ENTRY(2, m_VariableName)
  355. COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen)
  356. END_COLUMN_MAP()
  357. BEGIN_PARAM_MAP(CAllSessionDataSelector)
  358. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  359. COLUMN_ENTRY(1, m_szSessionID)
  360. END_PARAM_MAP()
  361. };
  362. // Use to update the value of a session variable
  363. class CSessionDataUpdator : public CSessionDataBase
  364. {
  365. public:
  366. BEGIN_PARAM_MAP(CSessionDataUpdator)
  367. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  368. COLUMN_ENTRY_LENGTH(1, m_VariableValue, m_VariableLen)
  369. COLUMN_ENTRY(2, m_szSessionID)
  370. COLUMN_ENTRY(3, m_VariableName)
  371. END_PARAM_MAP()
  372. };
  373. // Use to delete a session variable given the
  374. // session name and the name of the variable
  375. class CSessionDataDeletor
  376. {
  377. public:
  378. CSessionDataDeletor()
  379. {
  380. m_szSessionID[0] = '\0';
  381. m_VariableName[0] = '\0';
  382. }
  383. TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
  384. TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH];
  385. HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName) throw()
  386. {
  387. if (szSessionID)
  388. {
  389. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  390. _tcscpy(m_szSessionID, szSessionID);
  391. else
  392. return E_OUTOFMEMORY;
  393. }
  394. if (szVarName)
  395. {
  396. if(_tcslen(szVarName) < MAX_VARIABLE_NAME_LENGTH)
  397. _tcscpy(m_VariableName, szVarName);
  398. else
  399. return E_OUTOFMEMORY;
  400. }
  401. return S_OK;
  402. }
  403. BEGIN_PARAM_MAP(CSessionDataDeletor)
  404. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  405. COLUMN_ENTRY(1, m_szSessionID)
  406. COLUMN_ENTRY(2, m_VariableName)
  407. END_PARAM_MAP()
  408. };
  409. class CSessionDataDeleteAll
  410. {
  411. public:
  412. TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
  413. HRESULT Assign(LPCTSTR szSessionID) throw()
  414. {
  415. if (!szSessionID)
  416. return E_INVALIDARG;
  417. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  418. _tcscpy(m_szSessionID, szSessionID);
  419. else
  420. return E_OUTOFMEMORY;
  421. return S_OK;
  422. }
  423. BEGIN_PARAM_MAP(CSessionDataDeleteAll)
  424. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  425. COLUMN_ENTRY(1, m_szSessionID)
  426. END_PARAM_MAP()
  427. };
  428. // Used for retrieving the count of session variables for
  429. // a given session ID.
  430. class CCountAccessor
  431. {
  432. public:
  433. LONG m_nCount;
  434. TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
  435. CCountAccessor() throw()
  436. {
  437. m_szSessionID[0] = '\0';
  438. m_nCount = 0;
  439. }
  440. HRESULT Assign(LPCTSTR szSessionID) throw()
  441. {
  442. if (!szSessionID)
  443. return E_INVALIDARG;
  444. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  445. _tcscpy(m_szSessionID, szSessionID);
  446. else
  447. return E_OUTOFMEMORY;
  448. return S_OK;
  449. }
  450. BEGIN_COLUMN_MAP(CCountAccessor)
  451. COLUMN_ENTRY(1, m_nCount)
  452. END_COLUMN_MAP()
  453. BEGIN_PARAM_MAP(CCountAccessor)
  454. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  455. COLUMN_ENTRY(1, m_szSessionID)
  456. END_PARAM_MAP()
  457. };
  458. // Used for updating entries in the session
  459. // references table, given a session ID
  460. class CSessionRefUpdator
  461. {
  462. public:
  463. TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
  464. HRESULT Assign(LPCTSTR szSessionID) throw()
  465. {
  466. if (!szSessionID)
  467. return E_INVALIDARG;
  468. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  469. _tcscpy(m_SessionID, szSessionID);
  470. else
  471. return E_OUTOFMEMORY;
  472. return S_OK;
  473. }
  474. BEGIN_PARAM_MAP(CSessionRefUpdator)
  475. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  476. COLUMN_ENTRY(1, m_SessionID)
  477. END_PARAM_MAP()
  478. };
  479. class CSessionRefIsExpired
  480. {
  481. public:
  482. TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
  483. TCHAR m_SessionIDOut[MAX_SESSION_KEY_LEN];
  484. HRESULT Assign(LPCTSTR szSessionID) throw()
  485. {
  486. m_SessionIDOut[0]=0;
  487. if (!szSessionID)
  488. return E_INVALIDARG;
  489. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  490. _tcscpy(m_SessionID, szSessionID);
  491. else
  492. return E_OUTOFMEMORY;
  493. return S_OK;
  494. }
  495. BEGIN_COLUMN_MAP(CSessionRefIsExpired)
  496. COLUMN_ENTRY(1, m_SessionIDOut)
  497. END_COLUMN_MAP()
  498. BEGIN_PARAM_MAP(CSessionRefIsExpired)
  499. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  500. COLUMN_ENTRY(1, m_SessionID)
  501. END_PARAM_MAP()
  502. };
  503. class CSetAllTimeouts
  504. {
  505. public:
  506. unsigned __int64 m_dwNewTimeout;
  507. HRESULT Assign(unsigned __int64 dwNewValue)
  508. {
  509. m_dwNewTimeout = dwNewValue;
  510. return S_OK;
  511. }
  512. BEGIN_PARAM_MAP(CSetAllTimeouts)
  513. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  514. COLUMN_ENTRY(1, m_dwNewTimeout)
  515. END_PARAM_MAP()
  516. };
  517. class CSessionRefUpdateTimeout
  518. {
  519. public:
  520. TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
  521. unsigned __int64 m_nNewTimeout;
  522. HRESULT Assign(LPCTSTR szSessionID, unsigned __int64 nNewTimeout) throw()
  523. {
  524. if (!szSessionID)
  525. return E_INVALIDARG;
  526. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  527. _tcscpy(m_SessionID, szSessionID);
  528. else
  529. return E_OUTOFMEMORY;
  530. m_nNewTimeout = nNewTimeout;
  531. return S_OK;
  532. }
  533. BEGIN_PARAM_MAP(CSessionRefUpdateTimeout)
  534. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  535. COLUMN_ENTRY(1, m_nNewTimeout)
  536. COLUMN_ENTRY(2, m_SessionID)
  537. END_PARAM_MAP()
  538. };
  539. class CSessionRefSelector
  540. {
  541. public:
  542. TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
  543. int m_RefCount;
  544. HRESULT Assign(LPCTSTR szSessionID) throw()
  545. {
  546. if (!szSessionID)
  547. return E_INVALIDARG;
  548. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  549. _tcscpy(m_SessionID, szSessionID);
  550. else
  551. return E_OUTOFMEMORY;
  552. return S_OK;
  553. }
  554. BEGIN_COLUMN_MAP(CSessionRefSelector)
  555. COLUMN_ENTRY(1, m_SessionID)
  556. COLUMN_ENTRY(3, m_RefCount)
  557. END_COLUMN_MAP()
  558. BEGIN_PARAM_MAP(CSessionRefSelector)
  559. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  560. COLUMN_ENTRY(1, m_SessionID)
  561. END_PARAM_MAP()
  562. };
  563. class CSessionRefCount
  564. {
  565. public:
  566. LONG m_nCount;
  567. BEGIN_COLUMN_MAP(CSessionRefCount)
  568. COLUMN_ENTRY(1, m_nCount)
  569. END_COLUMN_MAP()
  570. };
  571. // Used for creating new entries in the session
  572. // references table.
  573. class CSessionRefCreator
  574. {
  575. public:
  576. TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
  577. unsigned __int64 m_TimeoutMs;
  578. HRESULT Assign(LPCTSTR szSessionID, unsigned __int64 timeout) throw()
  579. {
  580. if (!szSessionID)
  581. return E_INVALIDARG;
  582. if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
  583. {
  584. _tcscpy(m_SessionID, szSessionID);
  585. m_TimeoutMs = timeout;
  586. }
  587. else
  588. return E_OUTOFMEMORY;
  589. return S_OK;
  590. }
  591. BEGIN_PARAM_MAP(CSessionRefCreator)
  592. SET_PARAM_TYPE(DBPARAMIO_INPUT)
  593. COLUMN_ENTRY(1, m_SessionID)
  594. COLUMN_ENTRY(2, m_TimeoutMs)
  595. END_PARAM_MAP()
  596. };
// CDBSession
// This session persistence class persists session variables to an OLE DB
// data source. The following tables give a general description of the
// schema of the tables this class uses.
//
// TableName: SessionVariables
// Column  Name           Type                                  Description
// 1       SessionID      char[MAX_SESSION_KEY_LEN]             Session key name
// 2       VariableName   char[MAX_VARIABLE_NAME_LENGTH]        Variable name
// 3       VariableValue  varbinary[MAX_VARIABLE_VALUE_LENGTH]  Variable value
//
// TableName: SessionReferences
// Column  Name        Type                       Description
// 1       SessionID   char[MAX_SESSION_KEY_LEN]  Session key name.
// 2       LastAccess  datetime                   Date and time of last access to this session.
// 3       RefCount    int                        Current references on this session.
// 4       TimeoutMs   int                        Timeout value for the session in milliseconds.
  614. typedef bool (*PFN_GETPROVIDERINFO)(DWORD_PTR, wchar_t **);
  615. template <class QueryClass=CDefaultQueryClass>
  616. class CDBSession:
  617. public ISession,
  618. public CComObjectRootEx<CComGlobalsThreadModel>
  619. {
  620. typedef CCommand<CAccessor<CAllSessionDataSelector> > iterator_accessor;
  621. public:
  622. typedef QueryClass DBQUERYCLASS_TYPE;
  623. BEGIN_COM_MAP(CDBSession)
  624. COM_INTERFACE_ENTRY(ISession)
  625. END_COM_MAP()
  626. CDBSession() throw():
  627. m_dwTimeout(ATL_SESSION_TIMEOUT)
  628. {
  629. m_szSessionName[0] = '\0';
  630. }
  631. ~CDBSession() throw()
  632. {
  633. }
  634. void FinalRelease()throw()
  635. {
  636. SessionUnlock();
  637. }
  638. STDMETHOD(SetVariable)(LPCSTR szName, VARIANT Val) throw()
  639. {
  640. HRESULT hr = E_FAIL;
  641. if (!szName)
  642. return E_INVALIDARG;
  643. // Get the data connection for this thread.
  644. CDataConnection dataconn;
  645. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  646. if (hr != S_OK)
  647. return hr;
  648. // Update the last access time for this session
  649. hr = Access();
  650. if (hr != S_OK)
  651. return hr;
  652. // Allocate an updator command and fill out it's input parameters.
  653. CCommand<CAccessor<CSessionDataUpdator> > command;
  654. _ATLTRY
  655. {
  656. CA2CT name(szName);
  657. hr = command.Assign(m_szSessionName, name, &Val);
  658. }
  659. _ATLCATCHALL()
  660. {
  661. hr = E_OUTOFMEMORY;
  662. }
  663. if (hr != S_OK)
  664. return hr;
  665. // Try an update. Update will fail if the variable is not already there.
  666. LONG nRows = 0;
  667. hr = command.Open(dataconn,
  668. m_QueryObj.GetSessionVarUpdate(),
  669. NULL, &nRows, DBGUID_DEFAULT, false);
  670. if (hr == S_OK && nRows <= 0)
  671. hr = E_UNEXPECTED;
  672. if (hr != S_OK)
  673. {
  674. // Try an insert
  675. hr = command.Open(dataconn, m_QueryObj.GetSessionVarInsert(), NULL, &nRows, DBGUID_DEFAULT, false);
  676. if (hr == S_OK && nRows <=0)
  677. hr = E_UNEXPECTED;
  678. }
  679. return hr;
  680. }
  681. // Warning: For string data types, depending on the configuration of
  682. // your database, strings might be returned with trailing white space.
  683. STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw()
  684. {
  685. HRESULT hr = E_FAIL;
  686. if (!szName)
  687. return E_INVALIDARG;
  688. if (pVal)
  689. VariantClear(pVal);
  690. else
  691. return E_POINTER;
  692. // Get the data connection for this thread
  693. CDataConnection dataconn;
  694. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  695. if (hr != S_OK)
  696. return hr;
  697. // Update the last access time for this session
  698. hr = Access();
  699. if (hr != S_OK)
  700. return hr;
  701. // Allocate a command a fill out it's input parameters.
  702. CCommand<CAccessor<CSessionDataSelector> > command;
  703. _ATLTRY
  704. {
  705. CA2CT name(szName);
  706. hr = command.Assign(m_szSessionName, name, NULL);
  707. }
  708. _ATLCATCHALL()
  709. {
  710. hr = E_OUTOFMEMORY;
  711. }
  712. if (hr == S_OK)
  713. {
  714. hr = command.Open(dataconn, m_QueryObj.GetSessionVarSelectVar());
  715. if (SUCCEEDED(hr))
  716. {
  717. if ( S_OK == (hr = command.MoveFirst()))
  718. {
  719. CStreamOnByteArray stream(command.m_VariableValue);
  720. CComVariant vOut;
  721. hr = vOut.ReadFromStream(static_cast<IStream*>(&stream));
  722. if (hr == S_OK)
  723. hr = vOut.Detach(pVal);
  724. }
  725. }
  726. }
  727. return hr;
  728. }
  729. STDMETHOD(RemoveVariable)(LPCSTR szName) throw()
  730. {
  731. HRESULT hr = E_FAIL;
  732. if (!szName)
  733. return E_INVALIDARG;
  734. // Get the data connection for this thread.
  735. CDataConnection dataconn;
  736. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  737. if (hr != S_OK)
  738. return hr;
  739. // update the last access time for this session
  740. hr = Access();
  741. if (hr != S_OK)
  742. return hr;
  743. // allocate a command and set it's input parameters
  744. CCommand<CAccessor<CSessionDataDeletor> > command;
  745. _ATLTRY
  746. {
  747. CA2CT name(szName);
  748. hr = command.Assign(m_szSessionName, name);
  749. }
  750. _ATLCATCHALL()
  751. {
  752. return E_OUTOFMEMORY;
  753. }
  754. // execute the command
  755. long nRows = 0;
  756. if (hr == S_OK)
  757. hr = command.Open(dataconn, m_QueryObj.GetSessionVarDeleteVar(),
  758. NULL, &nRows, DBGUID_DEFAULT, false);
  759. if (hr == S_OK && nRows <= 0)
  760. hr = E_UNEXPECTED;
  761. return hr;
  762. }
  763. // Gives the count of rows in the table for this session ID.
  764. STDMETHOD(GetCount)(long *pnCount) throw()
  765. {
  766. HRESULT hr = S_OK;
  767. if (pnCount)
  768. *pnCount = 0;
  769. else
  770. return E_POINTER;
  771. // Get the database connection for this thread.
  772. CDataConnection dataconn;
  773. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  774. if (hr != S_OK)
  775. return hr;
  776. hr = Access();
  777. if (hr != S_OK)
  778. return hr;
  779. CCommand<CAccessor<CCountAccessor> > command;
  780. hr = command.Assign(m_szSessionName);
  781. if (hr == S_OK)
  782. {
  783. hr = command.Open(dataconn, m_QueryObj.GetSessionVarCount());
  784. if (hr == S_OK)
  785. {
  786. if (S_OK == (hr = command.MoveFirst()))
  787. {
  788. *pnCount = command.m_nCount;
  789. hr = S_OK;
  790. }
  791. }
  792. }
  793. return hr;
  794. }
  795. STDMETHOD(RemoveAllVariables)() throw()
  796. {
  797. HRESULT hr = E_UNEXPECTED;
  798. // Get the data connection for this thread.
  799. CDataConnection dataconn;
  800. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  801. if (hr != S_OK)
  802. return hr;
  803. CCommand<CAccessor<CSessionDataDeleteAll> > command;
  804. hr = command.Assign(m_szSessionName);
  805. if (hr != S_OK)
  806. return hr;
  807. // delete all session variables
  808. hr = command.Open(dataconn, m_QueryObj.GetSessionVarDeleteAllVars(), NULL, NULL, DBGUID_DEFAULT, false);
  809. return hr;
  810. }
  811. // Iteration of variables works by taking a snapshot
  812. // of the sessions at the point in time BeginVariableEnum
  813. // is called, and then keeping an index variable that you use to
  814. // move through the snapshot rowset. It is important to know
  815. // that the handle returned in phEnum is not thread safe. It
  816. // should only be used by the calling thread.
  817. STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnum, POSITION *pPOS) throw()
  818. {
  819. HRESULT hr = E_FAIL;
  820. if (!pPOS)
  821. return E_POINTER;
  822. if (phEnum)
  823. *phEnum = NULL;
  824. else
  825. return E_POINTER;
  826. // Get the data connection for this thread.
  827. CDataConnection dataconn;
  828. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  829. if (hr != S_OK)
  830. return hr;
  831. // Update the last access time for this session.
  832. hr = Access();
  833. if (hr != S_OK)
  834. return hr;
  835. // Allocate a new iterator accessor and initialize it's input parameters.
  836. iterator_accessor *pIteratorAccessor = NULL;
  837. ATLTRYALLOC(pIteratorAccessor = new iterator_accessor);
  838. if (!pIteratorAccessor)
  839. return E_OUTOFMEMORY;
  840. hr = pIteratorAccessor->Assign(m_szSessionName, NULL, NULL);
  841. if (hr == S_OK)
  842. {
  843. // execute the command and move to the first row of the recordset.
  844. hr = pIteratorAccessor->Open(dataconn,
  845. m_QueryObj.GetSessionVarSelectAllVars());
  846. if (hr == S_OK)
  847. {
  848. hr = pIteratorAccessor->MoveFirst();
  849. if (hr == S_OK)
  850. {
  851. *pPOS = (POSITION) INVALID_DB_SESSION_POS + 1;
  852. *phEnum = reinterpret_cast<HSESSIONENUM>(pIteratorAccessor);
  853. }
  854. }
  855. if (hr != S_OK)
  856. {
  857. *pPOS = INVALID_DB_SESSION_POS;
  858. *phEnum = NULL;
  859. delete pIteratorAccessor;
  860. }
  861. }
  862. return hr;
  863. }
  864. // The values for hEnum and pPos must have been initialized in a previous
  865. // call to BeginVariableEnum. On success, the out variant will hold the next
  866. // variable
  867. STDMETHOD(GetNextVariable)(HSESSIONENUM hEnum, POSITION *pPOS, LPSTR szName, DWORD dwLen, VARIANT *pVal) throw()
  868. {
  869. if (!pPOS)
  870. return E_INVALIDARG;
  871. if (pVal)
  872. VariantInit(pVal);
  873. else
  874. return E_POINTER;
  875. if (!hEnum)
  876. return E_UNEXPECTED;
  877. if (*pPOS <= INVALID_DB_SESSION_POS)
  878. return E_UNEXPECTED;
  879. iterator_accessor *pIteratorAccessor = reinterpret_cast<iterator_accessor*>(hEnum);
  880. // update the last access time.
  881. HRESULT hr = Access();
  882. POSITION posCurrent = *pPOS;
  883. if (szName)
  884. {
  885. // caller wants entry name
  886. size_t nNameLenChars = _tcslen(pIteratorAccessor->m_VariableName);
  887. if (dwLen > nNameLenChars)
  888. {
  889. _ATLTRY
  890. {
  891. CT2CA szVarName(pIteratorAccessor->m_VariableName);
  892. strcpy(szName, szVarName);
  893. }
  894. _ATLCATCHALL()
  895. {
  896. hr = E_OUTOFMEMORY;
  897. }
  898. }
  899. else
  900. hr = E_OUTOFMEMORY; // buffer not big enough
  901. }
  902. if (hr == S_OK)
  903. {
  904. CStreamOnByteArray stream(pIteratorAccessor->m_VariableValue);
  905. CComVariant vOut;
  906. hr = vOut.ReadFromStream(static_cast<IStream*>(&stream));
  907. if (hr == S_OK)
  908. vOut.Detach(pVal);
  909. else
  910. return hr;
  911. }
  912. else
  913. return hr;
  914. hr = pIteratorAccessor->MoveNext();
  915. *pPOS = ++posCurrent;
  916. if (hr == DB_S_ENDOFROWSET)
  917. {
  918. // We're done iterating, reset everything
  919. *pPOS = INVALID_DB_SESSION_POS;
  920. hr = S_OK;
  921. }
  922. if (hr != S_OK)
  923. {
  924. VariantClear(pVal);
  925. }
  926. return hr;
  927. }
  928. // CloseEnum frees up any resources allocated by the iterator
  929. STDMETHOD(CloseEnum)(HSESSIONENUM hEnum) throw()
  930. {
  931. iterator_accessor *pIteratorAccessor = reinterpret_cast<iterator_accessor*>(hEnum);
  932. if (!pIteratorAccessor)
  933. return E_INVALIDARG;
  934. pIteratorAccessor->Close();
  935. delete pIteratorAccessor;
  936. return S_OK;
  937. }
  938. //
  939. // Returns S_FALSE if it's not expired
  940. // S_OK if it is expired and an error HRESULT
  941. // if an error occurred.
  942. STDMETHOD(IsExpired)() throw()
  943. {
  944. HRESULT hrRet = S_FALSE;
  945. HRESULT hr = E_UNEXPECTED;
  946. // Get the data connection for this thread.
  947. CDataConnection dataconn;
  948. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  949. if (hr != S_OK)
  950. return hr;
  951. CCommand<CAccessor<CSessionRefIsExpired> > command;
  952. hr = command.Assign(m_szSessionName);
  953. if (hr != S_OK)
  954. return hr;
  955. hr = command.Open(dataconn, m_QueryObj.GetSessionRefIsExpired(),
  956. NULL, NULL, DBGUID_DEFAULT, true);
  957. if (hr == S_OK)
  958. {
  959. if (S_OK == command.MoveFirst())
  960. {
  961. if (!_tcscmp(command.m_SessionIDOut, m_szSessionName))
  962. hrRet = S_OK;
  963. }
  964. }
  965. if (hr == S_OK)
  966. return hrRet;
  967. return hr;
  968. }
  969. STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw()
  970. {
  971. HRESULT hr = E_UNEXPECTED;
  972. // Get the data connection for this thread.
  973. CDataConnection dataconn;
  974. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  975. if (hr != S_OK)
  976. return hr;
  977. // allocate a command and set it's input parameters
  978. CCommand<CAccessor<CSessionRefUpdateTimeout> > command;
  979. hr = command.Assign(m_szSessionName, dwNewTimeout);
  980. if (hr != S_OK)
  981. return hr;
  982. hr = command.Open(dataconn, m_QueryObj.GetSessionRefUpdateTimeout(),
  983. NULL, NULL, DBGUID_DEFAULT, false);
  984. return hr;
  985. }
  986. // SessionLock increments the session reference count for this session.
  987. // If there is not a session by this name in the session references table,
  // a new session entry is created in the table.
  989. HRESULT SessionLock() throw()
  990. {
  991. HRESULT hr = E_UNEXPECTED;
  992. if (!m_szSessionName || m_szSessionName[0]==0)
  993. return hr; // no session to lock.
  994. // retrieve the data connection for this thread
  995. CDataConnection dataconn;
  996. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  997. if (hr != S_OK)
  998. return hr;
  999. // first try to update a session with this name
  1000. LONG nRows = 0;
  1001. CCommand<CAccessor<CSessionRefUpdator> > updator;
  1002. if (S_OK == updator.Assign(m_szSessionName))
  1003. {
  1004. if (S_OK != (hr = updator.Open(dataconn, m_QueryObj.GetSessionRefAddRef(),
  1005. NULL, &nRows, DBGUID_DEFAULT, false)) ||
  1006. nRows == 0)
  1007. {
  1008. // No session to update. Use the creator accessor
  1009. // to create a new session reference.
  1010. CCommand<CAccessor<CSessionRefCreator> > creator;
  1011. hr = creator.Assign(m_szSessionName, m_dwTimeout);
  1012. if (hr == S_OK)
  1013. hr = creator.Open(dataconn, m_QueryObj.GetSessionRefCreate(),
  1014. NULL, &nRows, DBGUID_DEFAULT, false);
  1015. }
  1016. }
  1017. // We should have been able to create or update a session.
  1018. ATLASSERT(nRows > 0);
  1019. if (hr == S_OK && nRows <= 0)
  1020. hr = E_UNEXPECTED;
  1021. return hr;
  1022. }
  1023. // SessionUnlock decrements the session RefCount for this session.
  1024. // Sessions cannot be removed from the database unless the session
  1025. // refcount is 0
  1026. HRESULT SessionUnlock() throw()
  1027. {
  1028. HRESULT hr = E_UNEXPECTED;
  1029. if (!m_szSessionName ||
  1030. m_szSessionName[0]==0)
  1031. return hr;
  1032. // get the data connection for this thread
  1033. CDataConnection dataconn;
  1034. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  1035. if (hr != S_OK)
  1036. return hr;
  1037. // The session must exist at this point in order to unlock it
  1038. // so we can just use the session updator here.
  1039. LONG nRows = 0;
  1040. CCommand<CAccessor<CSessionRefUpdator> > updator;
  1041. hr = updator.Assign(m_szSessionName);
  1042. if (hr == S_OK)
  1043. {
  1044. hr = updator.Open( dataconn,
  1045. m_QueryObj.GetSessionRefRemoveRef(),
  1046. NULL,
  1047. &nRows,
  1048. DBGUID_DEFAULT,
  1049. false);
  1050. }
  1051. if (hr != S_OK)
  1052. return hr;
  1053. // delete the session from the database if
  1054. // nobody else is using it and it's expired.
  1055. hr = FreeSession();
  1056. return hr;
  1057. }
  1058. // Access updates the last access time for the session. The access
  1059. // time for sessions is updated using the SQL GETDATE function on the
  1060. // database server so that all clients will be using the same clock
  1061. // to compare access times against.
  1062. HRESULT Access() throw()
  1063. {
  1064. HRESULT hr = E_UNEXPECTED;
  1065. if (!m_szSessionName ||
  1066. m_szSessionName[0]==0)
  1067. return hr; // no session to access
  1068. // get the data connection for this thread
  1069. CDataConnection dataconn;
  1070. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  1071. if (hr != S_OK)
  1072. return hr;
  1073. // The session reference entry in the references table must
  1074. // be created prior to calling this function so we can just
  1075. // use an updator to update the current entry.
  1076. CCommand<CAccessor<CSessionRefUpdator> > updator;
  1077. LONG nRows = 0;
  1078. hr = updator.Assign(m_szSessionName);
  1079. if (hr == S_OK)
  1080. {
  1081. hr = updator.Open( dataconn,
  1082. m_QueryObj.GetSessionRefAccess(),
  1083. NULL,
  1084. &nRows,
  1085. DBGUID_DEFAULT,
  1086. false);
  1087. }
  1088. ATLASSERT(nRows > 0);
  1089. if (hr == S_OK && nRows <= 0)
  1090. hr = E_UNEXPECTED;
  1091. return hr;
  1092. }
  // If the session is expired and its reference count is 0,
  // it can be deleted. SessionUnlock calls this function after
  // it releases a session lock, so that the session is deleted
  // when possible. Note that our SQL command will only delete the session
  // if it is expired and its refcount is <= 0.
  1098. HRESULT FreeSession() throw()
  1099. {
  1100. HRESULT hr = E_UNEXPECTED;
  1101. if (!m_szSessionName ||
  1102. m_szSessionName[0]==0)
  1103. return hr;
  1104. // Get the data connection for this thread.
  1105. CDataConnection dataconn;
  1106. hr = GetSessionConnection(&dataconn, m_spServiceProvider);
  1107. if (hr != S_OK)
  1108. return hr;
  1109. CCommand<CAccessor<CSessionRefUpdator> > updator;
  1110. // The SQL for this command only deletes the
  1111. // session reference from the references table if it's access
  1112. // count is 0 and it has expired.
  1113. return updator.Open(dataconn,
  1114. m_QueryObj.GetSessionRefDelete(),
  1115. NULL,
  1116. NULL,
  1117. DBGUID_DEFAULT,
  1118. false);
  1119. }
  1120. // Initialize is called each time a new session is created.
  1121. HRESULT Initialize( LPCSTR szSessionName,
  1122. IServiceProvider *pServiceProvider,
  1123. DWORD_PTR dwCookie,
  1124. PFN_GETPROVIDERINFO pfnInfo) throw()
  1125. {
  1126. if (!szSessionName)
  1127. return E_INVALIDARG;
  1128. if (!pServiceProvider)
  1129. return E_INVALIDARG;
  1130. if (!pfnInfo)
  1131. return E_INVALIDARG;
  1132. m_pfnInfo = pfnInfo;
  1133. m_dwProvCookie = dwCookie;
  1134. m_spServiceProvider = pServiceProvider;
  1135. _ATLTRY
  1136. {
  1137. CA2CT tcsSessionName(szSessionName);
  1138. if (_tcslen(tcsSessionName) < MAX_SESSION_KEY_LEN)
  1139. _tcscpy(m_szSessionName, tcsSessionName);
  1140. else
  1141. return E_OUTOFMEMORY;
  1142. }
  1143. _ATLCATCHALL()
  1144. {
  1145. return E_OUTOFMEMORY;
  1146. }
  1147. return SessionLock();
  1148. }
  1149. HRESULT GetSessionConnection(CDataConnection *pConn,
  1150. IServiceProvider *pProv) throw()
  1151. {
  1152. if (!pProv)
  1153. return E_INVALIDARG;
  1154. if (!m_pfnInfo ||
  1155. !m_dwProvCookie)
  1156. return E_UNEXPECTED;
  1157. wchar_t *wszProv = NULL;
  1158. if (m_pfnInfo(m_dwProvCookie, &wszProv) && wszProv!=NULL)
  1159. {
  1160. return GetDataSource(pProv,
  1161. ATL_DBSESSION_ID,
  1162. wszProv,
  1163. pConn);
  1164. }
  1165. return E_FAIL;
  1166. }
  1167. protected:
  1168. TCHAR m_szSessionName[MAX_SESSION_KEY_LEN];
  1169. unsigned __int64 m_dwTimeout;
  1170. CComPtr<IServiceProvider> m_spServiceProvider;
  1171. DWORD_PTR m_dwProvCookie;
  1172. PFN_GETPROVIDERINFO m_pfnInfo;
  1173. DBQUERYCLASS_TYPE m_QueryObj;
  1174. }; // CDBSession
  1175. template <class TDBSession=CDBSession<> >
  1176. class CDBSessionServiceImplT
  1177. {
  1178. wchar_t m_szConnectionString[MAX_CONNECTION_STRING_LEN];
  1179. CComPtr<IServiceProvider> m_spServiceProvider;
  1180. TDBSession::DBQUERYCLASS_TYPE m_QueryObj;
  1181. public:
  1182. typedef const wchar_t* SERVICEIMPL_INITPARAM_TYPE;
  1183. CDBSessionServiceImplT() throw()
  1184. {
  1185. m_dwTimeout = ATL_SESSION_TIMEOUT;
  1186. m_szConnectionString[0] = '\0';
  1187. }
  1188. static bool GetProviderInfo(DWORD_PTR dwProvCookie, wchar_t **ppszProvInfo) throw()
  1189. {
  1190. if (dwProvCookie &&
  1191. ppszProvInfo)
  1192. {
  1193. CDBSessionServiceImplT<TDBSession> *pSvc =
  1194. reinterpret_cast<CDBSessionServiceImplT<TDBSession>*>(dwProvCookie);
  1195. *ppszProvInfo = pSvc->m_szConnectionString;
  1196. return true;
  1197. }
  1198. return false;
  1199. }
  1200. HRESULT GetSessionConnection(CDataConnection *pConn,
  1201. IServiceProvider *pProv) throw()
  1202. {
  1203. if (!pProv)
  1204. return E_INVALIDARG;
  1205. if(!m_szConnectionString[0])
  1206. return E_UNEXPECTED;
  1207. return GetDataSource(pProv,
  1208. ATL_DBSESSION_ID,
  1209. m_szConnectionString,
  1210. pConn);
  1211. }
  1212. HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE pData,
  1213. IServiceProvider *pProvider,
  1214. unsigned __int64 dwInitialTimeout) throw()
  1215. {
  1216. if (!pData || !pProvider)
  1217. return E_INVALIDARG;
  1218. if (wcslen(pData) < MAX_CONNECTION_STRING_LEN)
  1219. {
  1220. wcscpy(m_szConnectionString, pData);
  1221. }
  1222. else
  1223. return E_OUTOFMEMORY;
  1224. m_dwTimeout = dwInitialTimeout;
  1225. m_spServiceProvider = pProvider;
  1226. return S_OK;
  1227. }
  1228. HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
  1229. {
  1230. HRESULT hr = E_FAIL;
  1231. CComObject<TDBSession> *pNewSession = NULL;
  1232. if (!pdwSize)
  1233. return E_INVALIDARG;
  1234. if (ppSession)
  1235. *ppSession = NULL;
  1236. else
  1237. return E_POINTER;
  1238. if (szNewID)
  1239. *szNewID = NULL;
  1240. else
  1241. return E_INVALIDARG;
  1242. // Create new session
  1243. CComObject<TDBSession>::CreateInstance(&pNewSession);
  1244. if (pNewSession == NULL)
  1245. return E_OUTOFMEMORY;
  1246. // Create a session name and initialize the object
  1247. hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize);
  1248. if (hr == S_OK)
  1249. {
  1250. hr = pNewSession->Initialize(szNewID,
  1251. m_spServiceProvider,
  1252. reinterpret_cast<DWORD_PTR>(this),
  1253. GetProviderInfo);
  1254. if (hr == S_OK)
  1255. {
  1256. // we don't hold a reference to the object
  1257. hr = pNewSession->QueryInterface(ppSession);
  1258. }
  1259. }
  1260. if (hr != S_OK)
  1261. delete pNewSession;
  1262. return hr;
  1263. }
  1264. HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw()
  1265. {
  1266. HRESULT hr = E_FAIL;
  1267. if (!szID)
  1268. return E_INVALIDARG;
  1269. if (ppSession)
  1270. *ppSession = NULL;
  1271. else
  1272. return E_POINTER;
  1273. CComObject<TDBSession> *pNewSession = NULL;
  1274. // Check the DB to see if the session ID is a valid session
  1275. _ATLTRY
  1276. {
  1277. CA2CT session(szID);
  1278. hr = IsValidSession(session);
  1279. }
  1280. _ATLCATCHALL()
  1281. {
  1282. hr = E_OUTOFMEMORY;
  1283. }
  1284. if (hr == S_OK)
  1285. {
  1286. // Create new session object to represent this session
  1287. CComObject<TDBSession>::CreateInstance(&pNewSession);
  1288. if (pNewSession == NULL)
  1289. return E_OUTOFMEMORY;
  1290. hr = pNewSession->Initialize(szID,
  1291. m_spServiceProvider,
  1292. reinterpret_cast<DWORD_PTR>(this),
  1293. GetProviderInfo);
  1294. if (hr == S_OK)
  1295. {
  1296. // we don't hold a reference to the object
  1297. hr = pNewSession->QueryInterface(ppSession);
  1298. }
  1299. }
  1300. if (hr != S_OK && pNewSession)
  1301. delete pNewSession;
  1302. return hr;
  1303. }
  1304. HRESULT CloseSession(LPCSTR szID) throw()
  1305. {
  1306. if (!szID)
  1307. return E_INVALIDARG;
  1308. CDataConnection conn;
  1309. HRESULT hr = GetSessionConnection(&conn,
  1310. m_spServiceProvider);
  1311. if (hr != S_OK)
  1312. return hr;
  1313. // set up accessors
  1314. CCommand<CAccessor<CSessionRefUpdator> > updator;
  1315. CCommand<CAccessor<CSessionDataDeleteAll> > command;
  1316. _ATLTRY
  1317. {
  1318. CA2CT session(szID);
  1319. hr = updator.Assign(session);
  1320. if (hr == S_OK)
  1321. hr = command.Assign(session);
  1322. }
  1323. _ATLCATCHALL()
  1324. {
  1325. hr = E_OUTOFMEMORY;
  1326. }
  1327. if (hr == S_OK)
  1328. {
  1329. // delete all session variables
  1330. hr = command.Open(conn,
  1331. m_QueryObj.GetSessionVarDeleteAllVars(),
  1332. NULL,
  1333. NULL,
  1334. DBGUID_DEFAULT,
  1335. false);
  1336. if (hr == S_OK)
  1337. {
  1338. // delete references in the session references table
  1339. hr = updator.Open(conn,
  1340. m_QueryObj.GetSessionRefDeleteFinal(),
  1341. NULL,
  1342. NULL,
  1343. DBGUID_DEFAULT,
  1344. false);
  1345. }
  1346. }
  1347. return hr;
  1348. }
  1349. HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw()
  1350. {
  1351. // Get the data connection for this thread
  1352. CDataConnection conn;
  1353. HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
  1354. if (hr != S_OK)
  1355. return hr;
  1356. // all sessions get the same timeout
  1357. CCommand<CAccessor<CSetAllTimeouts> > command;
  1358. hr = command.Assign(nTimeout);
  1359. if (hr == S_OK)
  1360. {
  1361. hr = command.Open(conn, m_QueryObj.GetSessionReferencesSet(),
  1362. NULL,
  1363. NULL,
  1364. DBGUID_DEFAULT,
  1365. false);
  1366. if (hr == S_OK)
  1367. {
  1368. m_dwTimeout = nTimeout;
  1369. }
  1370. }
  1371. return hr;
  1372. }
  1373. HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw()
  1374. {
  1375. if (pnTimeout)
  1376. *pnTimeout = m_dwTimeout;
  1377. else
  1378. return E_INVALIDARG;
  1379. return S_OK;
  1380. }
  1381. HRESULT GetSessionCount(DWORD *pnCount) throw()
  1382. {
  1383. if (pnCount)
  1384. *pnCount = 0;
  1385. else
  1386. return E_INVALIDARG;
  1387. CCommand<CAccessor<CSessionRefCount> > command;
  1388. CDataConnection conn;
  1389. HRESULT hr = GetSessionConnection(&conn,
  1390. m_spServiceProvider);
  1391. if (hr != S_OK)
  1392. return hr;
  1393. hr = command.Open(conn,
  1394. m_QueryObj.GetSessionRefGetCount());
  1395. if (hr == S_OK)
  1396. {
  1397. hr = command.MoveFirst();
  1398. if (hr == S_OK)
  1399. {
  1400. *pnCount = (DWORD)command.m_nCount;
  1401. }
  1402. }
  1403. return hr;
  1404. }
  1405. void ReleaseAllSessions() throw()
  1406. {
  1407. // nothing to do
  1408. }
  1409. void SweepSessions() throw()
  1410. {
  1411. // nothing to do
  1412. }
  1413. // Helpers
  1414. HRESULT IsValidSession(LPCTSTR szID) throw()
  1415. {
  1416. if (!szID)
  1417. return E_INVALIDARG;
  1418. // Look in the sessionreferences table to see if there is an entry
  1419. // for this session.
  1420. if (m_szConnectionString[0] == 0)
  1421. return E_UNEXPECTED;
  1422. CDataConnection conn;
  1423. HRESULT hr = GetSessionConnection(&conn,
  1424. m_spServiceProvider);
  1425. if (hr != S_OK)
  1426. return hr;
  1427. // Check the session references table to see if
  1428. // this is a valid session
  1429. CCommand<CAccessor<CSessionRefSelector> > selector;
  1430. hr = selector.Assign(szID);
  1431. if (hr != S_OK)
  1432. return hr;
  1433. // The SQL for this command only deletes the
  1434. // session reference from the references table if it's access
  1435. // count is 0 and it has expired.
  1436. hr = selector.Open(conn,
  1437. m_QueryObj.GetSessionRefSelect(),
  1438. NULL,
  1439. NULL,
  1440. DBGUID_DEFAULT,
  1441. true);
  1442. if (hr == S_OK)
  1443. return selector.MoveFirst();
  1444. return hr;
  1445. }
  1446. CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names
  1447. unsigned __int64 m_dwTimeout;
  1448. }; // CDBSessionServiceImplT
  1449. typedef CDBSessionServiceImplT<> CDBSessionServiceImpl;
  1450. //////////////////////////////////////////////////////////////////
  1451. //
  1452. // In-memory persisted session
  1453. //
  1454. //////////////////////////////////////////////////////////////////
  1455. // In-memory persisted session service keeps a pointer
  // to the session object around in memory. The pointer is
  1457. // contained in a CComPtr, which is stored in a CAtlMap, so
  1458. // we have to have a CElementTraits class for that.
  1459. typedef CComPtr<ISession> SESSIONPTRTYPE;
  1460. template<>
  1461. class CElementTraits<SESSIONPTRTYPE> :
  1462. public CElementTraitsBase<SESSIONPTRTYPE>
  1463. {
  1464. public:
  1465. static ULONG Hash( INARGTYPE obj ) throw()
  1466. {
  1467. return( (ULONG)(ULONG_PTR)obj.p);
  1468. }
  1469. static BOOL CompareElements( OUTARGTYPE element1, OUTARGTYPE element2 ) throw()
  1470. {
  1471. return element1.IsEqualObject(element2.p) ? TRUE : FALSE;
  1472. }
  1473. static int CompareElementsOrdered( INARGTYPE , INARGTYPE ) throw()
  1474. {
  1475. ATLASSERT(0); // NOT IMPLEMENTED
  1476. return 0;
  1477. }
  1478. };
  1479. // CMemSession
  // This session persistence class persists session variables in memory.
  // Note that this type of persistence should only be used on single-server
  // web sites.
  1483. class CMemSession :
  1484. public ISession,
  1485. public CComObjectRootEx<CComGlobalsThreadModel>
  1486. {
  1487. public:
  1488. BEGIN_COM_MAP(CMemSession)
  1489. COM_INTERFACE_ENTRY(ISession)
  1490. END_COM_MAP()
  1491. CMemSession() throw(...)
  1492. {
  1493. }
  1494. STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw()
  1495. {
  1496. if (!szName)
  1497. return E_INVALIDARG;
  1498. if (pVal)
  1499. VariantInit(pVal);
  1500. else
  1501. return E_POINTER;
  1502. HRESULT hr = Access();
  1503. if (hr == S_OK)
  1504. {
  1505. CSLockType lock(m_cs, false);
  1506. hr = lock.Lock();
  1507. if (FAILED(hr))
  1508. return hr;
  1509. _ATLTRY
  1510. {
  1511. CComVariant val;
  1512. if (m_Variables.Lookup(szName, val))
  1513. {
  1514. hr = VariantCopy(pVal, &val);
  1515. }
  1516. }
  1517. _ATLCATCHALL()
  1518. {
  1519. hr = E_UNEXPECTED;
  1520. }
  1521. }
  1522. return hr;
  1523. }
  1524. STDMETHOD(SetVariable)(LPCSTR szName, VARIANT vNewVal) throw()
  1525. {
  1526. if (!szName)
  1527. return E_INVALIDARG;
  1528. HRESULT hr = Access();
  1529. if (hr == S_OK)
  1530. {
  1531. CSLockType lock(m_cs, false);
  1532. hr = lock.Lock();
  1533. if (FAILED(hr))
  1534. return hr;
  1535. _ATLTRY
  1536. {
  1537. hr = m_Variables.SetAt(szName, vNewVal) ? S_OK : E_FAIL;
  1538. }
  1539. _ATLCATCHALL()
  1540. {
  1541. hr = E_UNEXPECTED;
  1542. }
  1543. }
  1544. return hr;
  1545. }
  1546. STDMETHOD(RemoveVariable)(LPCSTR szName) throw()
  1547. {
  1548. if (!szName)
  1549. return E_INVALIDARG;
  1550. HRESULT hr = Access();
  1551. if (hr == S_OK)
  1552. {
  1553. CSLockType lock(m_cs, false);
  1554. hr = lock.Lock();
  1555. if (FAILED(hr))
  1556. return hr;
  1557. _ATLTRY
  1558. {
  1559. hr = m_Variables.RemoveKey(szName) ? S_OK : E_FAIL;
  1560. }
  1561. _ATLCATCHALL()
  1562. {
  1563. hr = E_UNEXPECTED;
  1564. }
  1565. }
  1566. return hr;
  1567. }
  1568. STDMETHOD(GetCount)(long *pnCount) throw()
  1569. {
  1570. if (pnCount)
  1571. return *pnCount = 0;
  1572. else
  1573. return E_POINTER;
  1574. HRESULT hr = Access();
  1575. if (hr == S_OK)
  1576. {
  1577. CSLockType lock(m_cs, false);
  1578. hr = lock.Lock();
  1579. if (FAILED(hr))
  1580. return hr;
  1581. *pnCount = (long) m_Variables.GetCount();
  1582. }
  1583. return hr;
  1584. }
  1585. STDMETHOD(RemoveAllVariables)() throw()
  1586. {
  1587. HRESULT hr = Access();
  1588. if (hr == S_OK)
  1589. {
  1590. CSLockType lock(m_cs, false);
  1591. hr = lock.Lock();
  1592. if (FAILED(hr))
  1593. return hr;
  1594. m_Variables.RemoveAll();
  1595. }
  1596. return hr;
  1597. }
  1598. STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnumHandle, POSITION *pPOS) throw()
  1599. {
  1600. if (phEnumHandle)
  1601. *phEnumHandle = NULL;
  1602. else
  1603. return E_POINTER;
  1604. if (pPOS)
  1605. *pPOS = NULL;
  1606. else
  1607. return E_POINTER;
  1608. HRESULT hr = Access();
  1609. if (hr == S_OK)
  1610. {
  1611. CSLockType lock(m_cs, false);
  1612. hr = lock.Lock();
  1613. if (FAILED(hr))
  1614. return hr;
  1615. *pPOS = m_Variables.GetStartPosition();
  1616. }
  1617. return hr;
  1618. }
  1619. STDMETHOD(GetNextVariable)(HSESSIONENUM /*hEnum*/,
  1620. POSITION *pPOS, LPSTR szName,
  1621. DWORD dwLen, VARIANT *pVal) throw()
  1622. {
  1623. if (!szName)
  1624. return E_INVALIDARG;
  1625. if (pVal)
  1626. VariantInit(pVal);
  1627. else
  1628. return E_POINTER;
  1629. if (!pPOS)
  1630. return E_POINTER;
  1631. CComVariant val;
  1632. POSITION pos = *pPOS;
  1633. HRESULT hr = Access();
  1634. if (hr == S_OK)
  1635. {
  1636. CSLockType lock(m_cs, false);
  1637. hr = lock.Lock();
  1638. if (FAILED(hr))
  1639. return hr;
  1640. _ATLTRY
  1641. {
  1642. CStringA strName = m_Variables.GetKeyAt(pos);
  1643. if (strName.GetLength())
  1644. {
  1645. if (dwLen > (DWORD)strName.GetLength())
  1646. strcpy(szName, strName);
  1647. else
  1648. hr = E_OUTOFMEMORY;
  1649. }
  1650. if (hr == S_OK)
  1651. {
  1652. val = m_Variables.GetNextValue(pos);
  1653. hr = VariantCopy(pVal, &val);
  1654. if (hr == S_OK)
  1655. *pPOS = pos;
  1656. }
  1657. }
  1658. _ATLCATCHALL()
  1659. {
  1660. hr = E_UNEXPECTED;
  1661. }
  1662. }
  1663. return hr;
  1664. }
  1665. STDMETHOD(CloseEnum)(HSESSIONENUM /*hEnumHandle*/) throw()
  1666. {
  1667. return S_OK;
  1668. }
  1669. STDMETHOD(IsExpired)() throw()
  1670. {
  1671. CTime tmNow = CTime::GetCurrentTime();
  1672. CTimeSpan span = tmNow-m_tLastAccess;
  1673. if ((unsigned __int64)((span.GetTotalSeconds()*1000)) > m_dwTimeout)
  1674. return S_OK;
  1675. return S_FALSE;
  1676. }
  1677. HRESULT Access() throw()
  1678. {
  1679. // We lock here to protect against multiple threads
  1680. // updating the same member concurrently.
  1681. CSLockType lock(m_cs, false);
  1682. HRESULT hr = lock.Lock();
  1683. if (FAILED(hr))
  1684. return hr;
  1685. m_tLastAccess = CTime::GetCurrentTime();
  1686. return S_OK;
  1687. }
  1688. STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw()
  1689. {
  1690. // We lock here to protect against multiple threads
  1691. // updating the same member concurrently
  1692. CSLockType lock(m_cs, false);
  1693. HRESULT hr = lock.Lock();
  1694. if (FAILED(hr))
  1695. return hr;
  1696. m_dwTimeout = dwNewTimeout;
  1697. return S_OK;
  1698. }
  1699. HRESULT SessionLock() throw()
  1700. {
  1701. Access();
  1702. return S_OK;
  1703. }
  1704. HRESULT SessionUnlock() throw()
  1705. {
  1706. return S_OK;
  1707. }
  1708. protected:
  1709. typedef CAtlMap<CStringA,
  1710. CComVariant,
  1711. CStringElementTraits<CStringA> > VarMapType;
  1712. unsigned __int64 m_dwTimeout;
  1713. CTime m_tLastAccess;
  1714. VarMapType m_Variables;
  1715. CComAutoCriticalSection m_cs;
  1716. typedef CComCritSecLock<CComAutoCriticalSection> CSLockType;
  1717. }; // CMemSession
  1718. //
  1719. // CMemSessionServiceImpl
  1720. // Implements the service part of in-memory persisted session services.
  1721. //
  1722. class CMemSessionServiceImpl
  1723. {
  1724. public:
  1725. typedef void* SERVICEIMPL_INITPARAM_TYPE;
  1726. CMemSessionServiceImpl() throw()
  1727. {
  1728. m_dwTimeout = ATL_SESSION_TIMEOUT;
  1729. }
  1730. HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
  1731. {
  1732. HRESULT hr = E_FAIL;
  1733. CComObject<CMemSession> *pNewSession = NULL;
  1734. if (!szNewID)
  1735. return E_INVALIDARG;
  1736. if (!pdwSize)
  1737. return E_POINTER;
  1738. if (ppSession)
  1739. *ppSession = NULL;
  1740. else
  1741. return E_POINTER;
  1742. _ATLTRY
  1743. {
  1744. // Create new session
  1745. CComObject<CMemSession>::CreateInstance(&pNewSession);
  1746. if (pNewSession == NULL)
  1747. return E_OUTOFMEMORY;
  1748. // Initialize and add to list of CSessionData
  1749. hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize);
  1750. if (SUCCEEDED(hr))
  1751. {
  1752. CComPtr<ISession> spSession;
  1753. hr = pNewSession->QueryInterface(&spSession);
  1754. if (SUCCEEDED(hr))
  1755. {
  1756. pNewSession->SetTimeout(m_dwTimeout);
  1757. pNewSession->Access();
  1758. CSLockType lock(m_CritSec, false);
  1759. hr = lock.Lock();
  1760. if (FAILED(hr))
  1761. return hr;
  1762. m_Sessions.SetAt(szNewID, spSession);
  1763. *ppSession = spSession.Detach();
  1764. }
  1765. }
  1766. }
  1767. _ATLCATCHALL()
  1768. {
  1769. hr = E_UNEXPECTED;
  1770. }
  1771. return hr;
  1772. }
  1773. HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw()
  1774. {
  1775. HRESULT hr = E_FAIL;
  1776. SessMapType::CPair *pPair = NULL;
  1777. if (ppSession)
  1778. *ppSession = NULL;
  1779. else
  1780. return E_POINTER;
  1781. if (!szID)
  1782. return E_INVALIDARG;
  1783. CSLockType lock(m_CritSec, false);
  1784. hr = lock.Lock();
  1785. if (FAILED(hr))
  1786. return hr;
  1787. _ATLTRY
  1788. {
  1789. pPair = m_Sessions.Lookup(szID);
  1790. if (pPair) // the session exists and is in our local map of sessions
  1791. {
  1792. hr = pPair->m_value.QueryInterface(ppSession);
  1793. }
  1794. }
  1795. _ATLCATCHALL()
  1796. {
  1797. return E_UNEXPECTED;
  1798. }
  1799. return hr;
  1800. }
  1801. HRESULT CloseSession(LPCSTR szID) throw()
  1802. {
  1803. if (!szID)
  1804. return E_INVALIDARG;
  1805. HRESULT hr = E_FAIL;
  1806. CSLockType lock(m_CritSec, false);
  1807. hr = lock.Lock();
  1808. if (FAILED(hr))
  1809. return hr;
  1810. _ATLTRY
  1811. {
  1812. hr = m_Sessions.RemoveKey(szID) ? S_OK : E_FAIL;
  1813. }
  1814. _ATLCATCHALL()
  1815. {
  1816. hr = E_UNEXPECTED;
  1817. }
  1818. return hr;
  1819. }
  1820. void SweepSessions() throw()
  1821. {
  1822. POSITION posRemove = NULL;
  1823. const SessMapType::CPair *pPair = NULL;
  1824. POSITION pos = NULL;
  1825. CSLockType lock(m_CritSec, false);
  1826. if (FAILED(lock.Lock()))
  1827. return;
  1828. pos = m_Sessions.GetStartPosition();
  1829. while (pos)
  1830. {
  1831. posRemove = pos;
  1832. pPair = m_Sessions.GetNext(pos);
  1833. if (pPair)
  1834. {
  1835. if (pPair->m_value.p &&
  1836. S_OK == pPair->m_value->IsExpired())
  1837. {
  1838. // remove our reference on the session
  1839. m_Sessions.RemoveAtPos(posRemove);
  1840. }
  1841. }
  1842. }
  1843. }
  1844. HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw()
  1845. {
  1846. HRESULT hr = S_OK;
  1847. CComPtr<ISession> spSession;
  1848. m_dwTimeout = nTimeout;
  1849. POSITION pos = m_Sessions.GetStartPosition();
  1850. CSLockType lock(m_CritSec, false);
  1851. hr = lock.Lock();
  1852. if (FAILED(hr))
  1853. return hr;
  1854. while (pos)
  1855. {
  1856. SessMapType::CPair *pPair = const_cast<SessMapType::CPair*>(m_Sessions.GetNext(pos));
  1857. if (pPair)
  1858. {
  1859. spSession = pPair->m_value;
  1860. if (spSession)
  1861. {
  1862. // if we fail on any of the sets we will return the
  1863. // error code immediately
  1864. hr = spSession->SetTimeout(nTimeout);
  1865. spSession.Release();
  1866. if (hr != S_OK)
  1867. break;
  1868. }
  1869. }
  1870. }
  1871. return hr;
  1872. }
  1873. HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw()
  1874. {
  1875. if (pnTimeout)
  1876. *pnTimeout = m_dwTimeout;
  1877. else
  1878. return E_POINTER;
  1879. return S_OK;
  1880. }
  1881. HRESULT GetSessionCount(DWORD *pnCount) throw()
  1882. {
  1883. if (pnCount)
  1884. *pnCount = 0;
  1885. else
  1886. return E_POINTER;
  1887. CSLockType lock(m_CritSec, false);
  1888. HRESULT hr = lock.Lock();
  1889. if (FAILED(hr))
  1890. return hr;
  1891. *pnCount = (DWORD)m_Sessions.GetCount();
  1892. return S_OK;
  1893. }
  1894. void ReleaseAllSessions() throw()
  1895. {
  1896. CSLockType lock(m_CritSec, false);
  1897. if (FAILED(lock.Lock()))
  1898. return;
  1899. m_Sessions.RemoveAll();
  1900. }
  1901. HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE,
  1902. IServiceProvider*,
  1903. unsigned __int64 dwNewTimeout) throw()
  1904. {
  1905. m_dwTimeout = dwNewTimeout;
  1906. return m_CritSec.Init();
  1907. }
  1908. typedef CAtlMap<CStringA,
  1909. SESSIONPTRTYPE,
  1910. CStringElementTraits<CStringA>,
  1911. CElementTraitsBase<SESSIONPTRTYPE> > SessMapType;
  1912. SessMapType m_Sessions; // map for holding sessions in memory
  1913. CComCriticalSection m_CritSec; // for synchronizing access to map
  1914. typedef CComCritSecLock<CComCriticalSection> CSLockType;
  1915. CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names
  1916. unsigned __int64 m_dwTimeout;
  1917. }; // CMemSessionServiceImpl
  1918. //
  1919. // CSessionStateService
  1920. // This class implements the session state service which can be
  1921. // exposed to request handlers.
  1922. //
  1923. // Template Parameters:
  1924. // CMonitorClass: Provides periodic sweeping services for the session service class.
  1925. // TServiceImplClass: The class that actually implements the methods of the
  1926. // ISessionStateService and ISessionStateControl interfaces.
  1927. template <class CMonitorClass, class TServiceImplClass >
  1928. class CSessionStateService :
  1929. public ISessionStateService,
  1930. public ISessionStateControl,
  1931. public IWorkerThreadClient,
  1932. public CComObjectRootEx<CComGlobalsThreadModel>
  1933. {
  1934. protected:
  1935. CMonitorClass m_Monitor;
  1936. HANDLE m_hTimer;
  1937. CComPtr<IServiceProvider> m_spServiceProvider;
  1938. TServiceImplClass m_SessionServiceImpl;
  1939. public:
  1940. // Construction/Initialization
  1941. CSessionStateService() throw() :
  1942. m_hTimer(NULL)
  1943. {
  1944. }
  1945. ~CSessionStateService() throw()
  1946. {
  1947. ATLASSERT(m_hTimer == NULL);
  1948. }
  1949. BEGIN_COM_MAP(CSessionStateService)
  1950. COM_INTERFACE_ENTRY(ISessionStateService)
  1951. COM_INTERFACE_ENTRY(ISessionStateControl)
  1952. END_COM_MAP()
  1953. // ISessionStateServie methods
  1954. STDMETHOD(CreateNewSession)(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
  1955. {
  1956. return m_SessionServiceImpl.CreateNewSession(szNewID, pdwSize, ppSession);
  1957. }
  1958. STDMETHOD(GetSession)(LPCSTR szID, ISession **ppSession) throw()
  1959. {
  1960. return m_SessionServiceImpl.GetSession(szID, ppSession);
  1961. }
  1962. STDMETHOD(CloseSession)(LPCSTR szSessionID) throw()
  1963. {
  1964. return m_SessionServiceImpl.CloseSession(szSessionID);
  1965. }
  1966. STDMETHOD(SetSessionTimeout)(unsigned __int64 nTimeout) throw()
  1967. {
  1968. return m_SessionServiceImpl.SetSessionTimeout(nTimeout);
  1969. }
  1970. STDMETHOD(GetSessionTimeout)(unsigned __int64 *pnTimeout) throw()
  1971. {
  1972. return m_SessionServiceImpl.GetSessionTimeout(pnTimeout);
  1973. }
  1974. STDMETHOD(GetSessionCount)(DWORD *pnSessionCount) throw()
  1975. {
  1976. return m_SessionServiceImpl.GetSessionCount(pnSessionCount);
  1977. }
  1978. void SweepSessions() throw()
  1979. {
  1980. m_SessionServiceImpl.SweepSessions();
  1981. }
  1982. void ReleaseAllSessions() throw()
  1983. {
  1984. m_SessionServiceImpl.ReleaseAllSessions();
  1985. }
  1986. HRESULT Initialize(
  1987. IServiceProvider *pServiceProvider = NULL,
  1988. unsigned __int64 dwTimeout = ATL_SESSION_TIMEOUT,
  1989. TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw()
  1990. {
  1991. HRESULT hr = S_OK;
  1992. if (pServiceProvider)
  1993. m_spServiceProvider = pServiceProvider;
  1994. hr = m_SessionServiceImpl.Initialize(pInitData, pServiceProvider, dwTimeout);
  1995. return hr;
  1996. }
  1997. template <class ThreadTraits>
  1998. HRESULT Initialize(
  1999. CWorkerThread<ThreadTraits> *pWorker,
  2000. IServiceProvider *pServiceProvider = NULL,
  2001. unsigned __int64 dwTimeout = ATL_SESSION_TIMEOUT,
  2002. TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw()
  2003. {
  2004. if (!pWorker)
  2005. return E_INVALIDARG;
  2006. HRESULT hr = Initialize(pServiceProvider, dwTimeout, pInitData);
  2007. if (hr == S_OK)
  2008. {
  2009. hr = m_Monitor.Initialize(pWorker);
  2010. if (hr == S_OK)
  2011. {
  2012. //sweep every 500ms
  2013. hr = m_Monitor.AddTimer(ATL_SESSION_SWEEPER_TIMEOUT, this, 0, &m_hTimer);
  2014. }
  2015. }
  2016. return hr;
  2017. }
  2018. HRESULT Execute(DWORD_PTR /*dwParam*/, HANDLE /*hObject*/) throw()
  2019. {
  2020. SweepSessions();
  2021. return S_OK;
  2022. }
  2023. HRESULT CloseHandle(HANDLE hHandle) throw()
  2024. {
  2025. ::CloseHandle(hHandle);
  2026. m_hTimer = NULL;
  2027. return S_OK;
  2028. }
  2029. void Shutdown() throw()
  2030. {
  2031. if (m_hTimer)
  2032. {
  2033. m_Monitor.RemoveHandle(m_hTimer);
  2034. m_hTimer = NULL;
  2035. }
  2036. ReleaseAllSessions();
  2037. }
  2038. }; // CSessionStateService
  2039. } // namespace ATL
  2040. #pragma warning(pop)
  2041. #endif // __ATLSESSION_H__