// This is a part of the Active Template Library. // Copyright (C) 1996-2001 Microsoft Corporation // All rights reserved. // // This source code is only intended as a supplement to the // Active Template Library Reference and related // electronic documentation provided with the library. // See these sources for detailed information regarding the // Active Template Library product. #ifndef __ATLSESSION_H__ #define __ATLSESSION_H__ #pragma once #pragma warning(push) #pragma warning(disable: 4702) // unreachable code #include #include #include #include #include #include #include #include #include #include #include #include #ifndef SESSION_KEY_LENGTH #define SESSION_KEY_LENGTH 37 #endif #ifndef MAX_SESSION_KEY_LEN #define MAX_SESSION_KEY_LEN 128 #endif #ifndef MAX_VARIABLE_NAME_LENGTH #define MAX_VARIABLE_NAME_LENGTH 50 #endif #ifndef MAX_VARIABLE_VALUE_LENGTH #define MAX_VARIABLE_VALUE_LENGTH 128 #endif #ifndef DEFAULT_SQL_LEN #define DEFAULT_SQL_LEN 1024 #endif #ifndef MAX_CONNECTION_STRING_LEN #define MAX_CONNECTION_STRING_LEN 2048 #endif #ifndef SESSION_COOKIE_NAME #define SESSION_COOKIE_NAME "SESSIONID" #endif #ifndef ATL_SESSION_TIMEOUT #define ATL_SESSION_TIMEOUT 600000 //10 min #endif #ifndef ATL_SESSION_SWEEPER_TIMEOUT #define ATL_SESSION_SWEEPER_TIMEOUT 1000 // 1sec #endif #define INVALID_DB_SESSION_POS 0x0 #define ATL_DBSESSION_ID _T("__ATL_SESSION_DB_CONNECTION") namespace ATL { // CSessionNameGenerator // This is a helper class that generates random data for session key // names. This class tries to use the CryptoApi to generate random // bytes for the session key name. If the CryptoApi isn't available // then the CRT rand() is used to generate the random bytes. This // class's GetNewSessionName member function is used to actually // generate the session name. 
class CSessionNameGenerator : public CCryptProv { public: bool m_bCryptNotAvailable; enum {MIN_SESSION_KEY_LEN=5}; CSessionNameGenerator() throw() : m_bCryptNotAvailable(false) { // Note that the crypto api is being // initialized with no private key // information HRESULT hr = InitVerifyContext(); m_bCryptNotAvailable = FAILED(hr) ? true : false; } // This function creates a new session name and base64 encodes it. // The base64 encoding algorithm used needs at least MIN_SESSION_KEY_LEN // bytes to work correctly. Since we stack allocate the temporary // buffer that holds the key name, the buffer must be less than or equal to // the MAX_SESSION_KEY_LEN in size. HRESULT GetNewSessionName(LPSTR szNewID, DWORD *pdwSize) throw() { HRESULT hr = E_FAIL; if (!pdwSize) return E_POINTER; if (*pdwSize < MIN_SESSION_KEY_LEN || *pdwSize > MAX_SESSION_KEY_LEN) return E_INVALIDARG; if (!szNewID) return E_POINTER; BYTE key[MAX_SESSION_KEY_LEN] = {0x0}; // calculate the number of bytes that will fit in the // buffer we've been passed DWORD dwDataSize = CalcMaxInputSize(*pdwSize); if (dwDataSize && *pdwSize >= (DWORD)(Base64EncodeGetRequiredLength(dwDataSize, ATL_BASE64_FLAG_NOCRLF))) { int dwKeySize = *pdwSize; hr = GenerateRandomName(key, dwDataSize); if (SUCCEEDED(hr)) { if( Base64Encode(key, dwDataSize, szNewID, &dwKeySize, ATL_BASE64_FLAG_NOCRLF) ) { //null terminate szNewID[dwKeySize]=0; *pdwSize = dwKeySize+1; } else hr = E_FAIL; } else { *pdwSize = (DWORD)(Base64EncodeGetRequiredLength(dwDataSize, ATL_BASE64_FLAG_NOCRLF)); return E_OUTOFMEMORY; } } return hr; } DWORD CalcMaxInputSize(DWORD nOutputSize) throw() { if (nOutputSize < (DWORD)MIN_SESSION_KEY_LEN) return 0; // subtract one from the output size to make room // for the NULL terminator in the output then // calculate the biggest number of input bytes that // when base64 encoded will fit in a buffer of size // nOutputSize (including base64 padding) int nInputSize = ((nOutputSize-1)*3)/4; int factor = 
((nInputSize*4)/3)%4; if (factor) nInputSize -= factor; return nInputSize; } HRESULT GenerateRandomName(BYTE *pBuff, DWORD dwBuffSize) throw() { if (!pBuff) return E_POINTER; if (!dwBuffSize) return E_UNEXPECTED; if (!m_bCryptNotAvailable && GetHandle()) { // Use the crypto api to generate random data. return GenRandom(dwBuffSize, pBuff); } // CryptoApi isn't available so we generate // random data using rand. We seed the random // number generator with a seed that is a combination // of bytes from an arbitrary number and the system // time which changes every millisecond so it will // be different for every call to this function. FILETIME ft; GetSystemTimeAsFileTime(&ft); static DWORD dwVal = 0x21; DWORD dwSeed = (dwVal++ << 0x18) | (ft.dwLowDateTime & 0x00ffff00) | dwVal++ & 0x000000ff; srand(dwSeed); BYTE *pCurr = pBuff; // fill buffer with random bytes for (int i=0; i < (int)dwBuffSize; i++) { *pCurr = (BYTE) (rand() & 0x000000ff); pCurr++; } return S_OK; } }; // // CDefaultQueryClass // returns Query strings for use in SQL queries used // by the database persisted session service. class CDefaultQueryClass { public: LPCTSTR GetSessionRefDelete() throw() { return _T("DELETE FROM SessionReferences ") _T("WHERE SessionID=? AND RefCount <= 0 ") _T("AND DATEDIFF(millisecond, LastAccess, getdate()) > TimeoutMs"); } LPCTSTR GetSessionRefIsExpired() throw() { return _T("SELECT SessionID FROM SessionReferences ") _T("WHERE (SessionID=?) AND (DATEDIFF(millisecond, LastAccess, getdate()) > TimeoutMs)"); } LPCTSTR GetSessionRefDeleteFinal() throw() { return _T("DELETE FROM SessionReferences ") _T("WHERE SessionID=?"); } LPCTSTR GetSessionRefCreate() throw() { return _T("INSERT INTO SessionReferences ") _T("(SessionID, LastAccess, RefCount, TimeoutMs) ") _T("VALUES (?, getdate(), 1, ?)"); } LPCTSTR GetSessionRefUpdateTimeout() throw() { return _T("UPDATE SessionReferences ") _T("SET TimeoutMs=? 
WHERE SessionID=?"); } LPCTSTR GetSessionRefAddRef() throw() { return _T("UPDATE SessionReferences ") _T("SET RefCount=RefCount+1, ") _T("LastAccess=getdate() ") _T("WHERE SessionID=?"); } LPCTSTR GetSessionRefRemoveRef() throw() { return _T("UPDATE SessionReferences ") _T("SET RefCount=RefCount-1, ") _T("LastAccess=getdate() ") _T("WHERE SessionID=?"); } LPCTSTR GetSessionRefAccess() throw() { return _T("UPDATE SessionReferences ") _T("SET LastAccess=getdate() ") _T("WHERE SessionID=?"); } LPCTSTR GetSessionRefSelect() throw() { return _T("SELECT * FROM SessionReferences ") _T("WHERE SessionID=?"); } LPCTSTR GetSessionRefGetCount() throw() { return _T("SELECT COUNT(*) FROM SessionReferences"); } LPCTSTR GetSessionVarCount() throw() { return _T("SELECT COUNT(*) FROM SessionVariables WHERE SessionID=?"); } LPCTSTR GetSessionVarInsert() throw() { return _T("INSERT INTO SessionVariables ") _T("(VariableValue, SessionID, VariableName) ") _T("VALUES (?, ?, ?)"); } LPCTSTR GetSessionVarUpdate() throw() { return _T("UPDATE SessionVariables ") _T("SET VariableValue=? ") _T("WHERE SessionID=? AND VariableName=?"); } LPCTSTR GetSessionVarDeleteVar() throw() { return _T("DELETE FROM SessionVariables ") _T("WHERE SessionID=? AND VariableName=?"); } LPCTSTR GetSessionVarDeleteAllVars() throw() { return _T("DELETE FROM SessionVariables WHERE (SessionID=?)"); } LPCTSTR GetSessionVarSelectVar()throw() { return _T("SELECT SessionID, VariableName, VariableValue ") _T("FROM SessionVariables ") _T("WHERE SessionID=? 
AND VariableName=?"); } LPCTSTR GetSessionVarSelectAllVars() throw() { return _T("SELECT SessionID, VariableName, VariableValue ") _T("FROM SessionVariables ") _T("WHERE SessionID=?"); } LPCTSTR GetSessionReferencesSet() throw() { return _T("UPDATE SessionReferences SET TimeoutMs=?"); } }; // Contains the data for the session variable accessors class CSessionDataBase { public: TCHAR m_szSessionID[MAX_SESSION_KEY_LEN]; TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH]; BYTE m_VariableValue[MAX_VARIABLE_VALUE_LENGTH]; DWORD m_VariableLen; CSessionDataBase() throw() { m_szSessionID[0] = '\0'; m_VariableName[0] = '\0'; m_VariableValue[0] = '\0'; m_VariableLen = 0; } HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName, VARIANT *pVal) throw() { HRESULT hr = S_OK; CVariantStream stream; if ( szSessionID ) { if ( _tcslen(szSessionID)< MAX_SESSION_KEY_LEN) _tcscpy(m_szSessionID, szSessionID); else hr = E_OUTOFMEMORY; } else return E_INVALIDARG; if (szVarName) if ( _tcslen(szVarName) < MAX_VARIABLE_NAME_LENGTH) _tcscpy(m_VariableName, szVarName); else hr = E_OUTOFMEMORY; if (pVal) { hr = stream.InsertVariant(pVal); if (hr == S_OK) { BYTE *pBytes = stream.m_stream; size_t size = stream.GetVariantSize(); if (pBytes && size && size < MAX_VARIABLE_VALUE_LENGTH) { memcpy(m_VariableValue, pBytes, stream.GetVariantSize()); m_VariableLen = (DWORD)size; } else hr = E_UNEXPECTED; } } return hr; } }; // Use to select a session variable given the name // of a session and the name of a variable. class CSessionDataSelector : public CSessionDataBase { public: BEGIN_COLUMN_MAP(CSessionDataSelector) COLUMN_ENTRY(1, m_szSessionID) COLUMN_ENTRY(2, m_VariableName) COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen) END_COLUMN_MAP() BEGIN_PARAM_MAP(CSessionDataSelector) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_szSessionID) COLUMN_ENTRY(2, m_VariableName) END_PARAM_MAP() }; // Use to select all session variables given the name of // of a session. 
class CAllSessionDataSelector : public CSessionDataBase { public: BEGIN_COLUMN_MAP(CAllSessionDataSelector) COLUMN_ENTRY(1, m_szSessionID) COLUMN_ENTRY(2, m_VariableName) COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen) END_COLUMN_MAP() BEGIN_PARAM_MAP(CAllSessionDataSelector) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_szSessionID) END_PARAM_MAP() }; // Use to update the value of a session variable class CSessionDataUpdator : public CSessionDataBase { public: BEGIN_PARAM_MAP(CSessionDataUpdator) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY_LENGTH(1, m_VariableValue, m_VariableLen) COLUMN_ENTRY(2, m_szSessionID) COLUMN_ENTRY(3, m_VariableName) END_PARAM_MAP() }; // Use to delete a session variable given the // session name and the name of the variable class CSessionDataDeletor { public: CSessionDataDeletor() { m_szSessionID[0] = '\0'; m_VariableName[0] = '\0'; } TCHAR m_szSessionID[MAX_SESSION_KEY_LEN]; TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH]; HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName) throw() { if (szSessionID) { if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) _tcscpy(m_szSessionID, szSessionID); else return E_OUTOFMEMORY; } if (szVarName) { if(_tcslen(szVarName) < MAX_VARIABLE_NAME_LENGTH) _tcscpy(m_VariableName, szVarName); else return E_OUTOFMEMORY; } return S_OK; } BEGIN_PARAM_MAP(CSessionDataDeletor) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_szSessionID) COLUMN_ENTRY(2, m_VariableName) END_PARAM_MAP() }; class CSessionDataDeleteAll { public: TCHAR m_szSessionID[MAX_SESSION_KEY_LEN]; HRESULT Assign(LPCTSTR szSessionID) throw() { if (!szSessionID) return E_INVALIDARG; if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) _tcscpy(m_szSessionID, szSessionID); else return E_OUTOFMEMORY; return S_OK; } BEGIN_PARAM_MAP(CSessionDataDeleteAll) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_szSessionID) END_PARAM_MAP() }; // Used for retrieving the count of session variables for // a given session ID. 
class CCountAccessor { public: LONG m_nCount; TCHAR m_szSessionID[MAX_SESSION_KEY_LEN]; CCountAccessor() throw() { m_szSessionID[0] = '\0'; m_nCount = 0; } HRESULT Assign(LPCTSTR szSessionID) throw() { if (!szSessionID) return E_INVALIDARG; if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) _tcscpy(m_szSessionID, szSessionID); else return E_OUTOFMEMORY; return S_OK; } BEGIN_COLUMN_MAP(CCountAccessor) COLUMN_ENTRY(1, m_nCount) END_COLUMN_MAP() BEGIN_PARAM_MAP(CCountAccessor) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_szSessionID) END_PARAM_MAP() }; // Used for updating entries in the session // references table, given a session ID class CSessionRefUpdator { public: TCHAR m_SessionID[MAX_SESSION_KEY_LEN]; HRESULT Assign(LPCTSTR szSessionID) throw() { if (!szSessionID) return E_INVALIDARG; if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) _tcscpy(m_SessionID, szSessionID); else return E_OUTOFMEMORY; return S_OK; } BEGIN_PARAM_MAP(CSessionRefUpdator) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_SessionID) END_PARAM_MAP() }; class CSessionRefIsExpired { public: TCHAR m_SessionID[MAX_SESSION_KEY_LEN]; TCHAR m_SessionIDOut[MAX_SESSION_KEY_LEN]; HRESULT Assign(LPCTSTR szSessionID) throw() { m_SessionIDOut[0]=0; if (!szSessionID) return E_INVALIDARG; if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) _tcscpy(m_SessionID, szSessionID); else return E_OUTOFMEMORY; return S_OK; } BEGIN_COLUMN_MAP(CSessionRefIsExpired) COLUMN_ENTRY(1, m_SessionIDOut) END_COLUMN_MAP() BEGIN_PARAM_MAP(CSessionRefIsExpired) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_SessionID) END_PARAM_MAP() }; class CSetAllTimeouts { public: unsigned __int64 m_dwNewTimeout; HRESULT Assign(unsigned __int64 dwNewValue) { m_dwNewTimeout = dwNewValue; return S_OK; } BEGIN_PARAM_MAP(CSetAllTimeouts) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_dwNewTimeout) END_PARAM_MAP() }; class CSessionRefUpdateTimeout { public: TCHAR m_SessionID[MAX_SESSION_KEY_LEN]; unsigned __int64 m_nNewTimeout; HRESULT 
Assign(LPCTSTR szSessionID, unsigned __int64 nNewTimeout) throw() { if (!szSessionID) return E_INVALIDARG; if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) _tcscpy(m_SessionID, szSessionID); else return E_OUTOFMEMORY; m_nNewTimeout = nNewTimeout; return S_OK; } BEGIN_PARAM_MAP(CSessionRefUpdateTimeout) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_nNewTimeout) COLUMN_ENTRY(2, m_SessionID) END_PARAM_MAP() }; class CSessionRefSelector { public: TCHAR m_SessionID[MAX_SESSION_KEY_LEN]; int m_RefCount; HRESULT Assign(LPCTSTR szSessionID) throw() { if (!szSessionID) return E_INVALIDARG; if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) _tcscpy(m_SessionID, szSessionID); else return E_OUTOFMEMORY; return S_OK; } BEGIN_COLUMN_MAP(CSessionRefSelector) COLUMN_ENTRY(1, m_SessionID) COLUMN_ENTRY(3, m_RefCount) END_COLUMN_MAP() BEGIN_PARAM_MAP(CSessionRefSelector) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_SessionID) END_PARAM_MAP() }; class CSessionRefCount { public: LONG m_nCount; BEGIN_COLUMN_MAP(CSessionRefCount) COLUMN_ENTRY(1, m_nCount) END_COLUMN_MAP() }; // Used for creating new entries in the session // references table. class CSessionRefCreator { public: TCHAR m_SessionID[MAX_SESSION_KEY_LEN]; unsigned __int64 m_TimeoutMs; HRESULT Assign(LPCTSTR szSessionID, unsigned __int64 timeout) throw() { if (!szSessionID) return E_INVALIDARG; if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN) { _tcscpy(m_SessionID, szSessionID); m_TimeoutMs = timeout; } else return E_OUTOFMEMORY; return S_OK; } BEGIN_PARAM_MAP(CSessionRefCreator) SET_PARAM_TYPE(DBPARAMIO_INPUT) COLUMN_ENTRY(1, m_SessionID) COLUMN_ENTRY(2, m_TimeoutMs) END_PARAM_MAP() }; // CDBSession // This session persistance class persists session variables to // an OLEDB datasource. The following table gives a general description // of the table schema for the tables this class uses. 
// // TableName: SessionVariables // Column Name Type Description // 1 SessionID char[MAX_SESSION_KEY_LEN] Session Key name // 2 VariableName char[MAX_VARIABLE_NAME_LENGTH] Variable Name // 3 VariableValue varbinary[MAX_VARIABLE_VALUE_LENGTH] Variable Value // // TableName: SessionReferences // Column Name Type Description // 1 SessionID char[MAX_SESSION_KEY_LEN] Session Key Name. // 2 LastAccess datetime Date and time of last access to this session. // 3 RefCount int Current references on this session. // 4 TimeoutMS int Timeout value for the session in milli seconds typedef bool (*PFN_GETPROVIDERINFO)(DWORD_PTR, wchar_t **); template class CDBSession: public ISession, public CComObjectRootEx { typedef CCommand > iterator_accessor; public: typedef QueryClass DBQUERYCLASS_TYPE; BEGIN_COM_MAP(CDBSession) COM_INTERFACE_ENTRY(ISession) END_COM_MAP() CDBSession() throw(): m_dwTimeout(ATL_SESSION_TIMEOUT) { m_szSessionName[0] = '\0'; } ~CDBSession() throw() { } void FinalRelease()throw() { SessionUnlock(); } STDMETHOD(SetVariable)(LPCSTR szName, VARIANT Val) throw() { HRESULT hr = E_FAIL; if (!szName) return E_INVALIDARG; // Get the data connection for this thread. CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // Update the last access time for this session hr = Access(); if (hr != S_OK) return hr; // Allocate an updator command and fill out it's input parameters. CCommand > command; _ATLTRY { CA2CT name(szName); hr = command.Assign(m_szSessionName, name, &Val); } _ATLCATCHALL() { hr = E_OUTOFMEMORY; } if (hr != S_OK) return hr; // Try an update. Update will fail if the variable is not already there. 
LONG nRows = 0; hr = command.Open(dataconn, m_QueryObj.GetSessionVarUpdate(), NULL, &nRows, DBGUID_DEFAULT, false); if (hr == S_OK && nRows <= 0) hr = E_UNEXPECTED; if (hr != S_OK) { // Try an insert hr = command.Open(dataconn, m_QueryObj.GetSessionVarInsert(), NULL, &nRows, DBGUID_DEFAULT, false); if (hr == S_OK && nRows <=0) hr = E_UNEXPECTED; } return hr; } // Warning: For string data types, depending on the configuration of // your database, strings might be returned with trailing white space. STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw() { HRESULT hr = E_FAIL; if (!szName) return E_INVALIDARG; if (pVal) VariantClear(pVal); else return E_POINTER; // Get the data connection for this thread CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // Update the last access time for this session hr = Access(); if (hr != S_OK) return hr; // Allocate a command a fill out it's input parameters. CCommand > command; _ATLTRY { CA2CT name(szName); hr = command.Assign(m_szSessionName, name, NULL); } _ATLCATCHALL() { hr = E_OUTOFMEMORY; } if (hr == S_OK) { hr = command.Open(dataconn, m_QueryObj.GetSessionVarSelectVar()); if (SUCCEEDED(hr)) { if ( S_OK == (hr = command.MoveFirst())) { CStreamOnByteArray stream(command.m_VariableValue); CComVariant vOut; hr = vOut.ReadFromStream(static_cast(&stream)); if (hr == S_OK) hr = vOut.Detach(pVal); } } } return hr; } STDMETHOD(RemoveVariable)(LPCSTR szName) throw() { HRESULT hr = E_FAIL; if (!szName) return E_INVALIDARG; // Get the data connection for this thread. 
CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // update the last access time for this session hr = Access(); if (hr != S_OK) return hr; // allocate a command and set it's input parameters CCommand > command; _ATLTRY { CA2CT name(szName); hr = command.Assign(m_szSessionName, name); } _ATLCATCHALL() { return E_OUTOFMEMORY; } // execute the command long nRows = 0; if (hr == S_OK) hr = command.Open(dataconn, m_QueryObj.GetSessionVarDeleteVar(), NULL, &nRows, DBGUID_DEFAULT, false); if (hr == S_OK && nRows <= 0) hr = E_UNEXPECTED; return hr; } // Gives the count of rows in the table for this session ID. STDMETHOD(GetCount)(long *pnCount) throw() { HRESULT hr = S_OK; if (pnCount) *pnCount = 0; else return E_POINTER; // Get the database connection for this thread. CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; hr = Access(); if (hr != S_OK) return hr; CCommand > command; hr = command.Assign(m_szSessionName); if (hr == S_OK) { hr = command.Open(dataconn, m_QueryObj.GetSessionVarCount()); if (hr == S_OK) { if (S_OK == (hr = command.MoveFirst())) { *pnCount = command.m_nCount; hr = S_OK; } } } return hr; } STDMETHOD(RemoveAllVariables)() throw() { HRESULT hr = E_UNEXPECTED; // Get the data connection for this thread. CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; CCommand > command; hr = command.Assign(m_szSessionName); if (hr != S_OK) return hr; // delete all session variables hr = command.Open(dataconn, m_QueryObj.GetSessionVarDeleteAllVars(), NULL, NULL, DBGUID_DEFAULT, false); return hr; } // Iteration of variables works by taking a snapshot // of the sessions at the point in time BeginVariableEnum // is called, and then keeping an index variable that you use to // move through the snapshot rowset. It is important to know // that the handle returned in phEnum is not thread safe. 
It // should only be used by the calling thread. STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnum, POSITION *pPOS) throw() { HRESULT hr = E_FAIL; if (!pPOS) return E_POINTER; if (phEnum) *phEnum = NULL; else return E_POINTER; // Get the data connection for this thread. CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // Update the last access time for this session. hr = Access(); if (hr != S_OK) return hr; // Allocate a new iterator accessor and initialize it's input parameters. iterator_accessor *pIteratorAccessor = NULL; ATLTRYALLOC(pIteratorAccessor = new iterator_accessor); if (!pIteratorAccessor) return E_OUTOFMEMORY; hr = pIteratorAccessor->Assign(m_szSessionName, NULL, NULL); if (hr == S_OK) { // execute the command and move to the first row of the recordset. hr = pIteratorAccessor->Open(dataconn, m_QueryObj.GetSessionVarSelectAllVars()); if (hr == S_OK) { hr = pIteratorAccessor->MoveFirst(); if (hr == S_OK) { *pPOS = (POSITION) INVALID_DB_SESSION_POS + 1; *phEnum = reinterpret_cast(pIteratorAccessor); } } if (hr != S_OK) { *pPOS = INVALID_DB_SESSION_POS; *phEnum = NULL; delete pIteratorAccessor; } } return hr; } // The values for hEnum and pPos must have been initialized in a previous // call to BeginVariableEnum. On success, the out variant will hold the next // variable STDMETHOD(GetNextVariable)(HSESSIONENUM hEnum, POSITION *pPOS, LPSTR szName, DWORD dwLen, VARIANT *pVal) throw() { if (!pPOS) return E_INVALIDARG; if (pVal) VariantInit(pVal); else return E_POINTER; if (!hEnum) return E_UNEXPECTED; if (*pPOS <= INVALID_DB_SESSION_POS) return E_UNEXPECTED; iterator_accessor *pIteratorAccessor = reinterpret_cast(hEnum); // update the last access time. 
HRESULT hr = Access(); POSITION posCurrent = *pPOS; if (szName) { // caller wants entry name size_t nNameLenChars = _tcslen(pIteratorAccessor->m_VariableName); if (dwLen > nNameLenChars) { _ATLTRY { CT2CA szVarName(pIteratorAccessor->m_VariableName); strcpy(szName, szVarName); } _ATLCATCHALL() { hr = E_OUTOFMEMORY; } } else hr = E_OUTOFMEMORY; // buffer not big enough } if (hr == S_OK) { CStreamOnByteArray stream(pIteratorAccessor->m_VariableValue); CComVariant vOut; hr = vOut.ReadFromStream(static_cast(&stream)); if (hr == S_OK) vOut.Detach(pVal); else return hr; } else return hr; hr = pIteratorAccessor->MoveNext(); *pPOS = ++posCurrent; if (hr == DB_S_ENDOFROWSET) { // We're done iterating, reset everything *pPOS = INVALID_DB_SESSION_POS; hr = S_OK; } if (hr != S_OK) { VariantClear(pVal); } return hr; } // CloseEnum frees up any resources allocated by the iterator STDMETHOD(CloseEnum)(HSESSIONENUM hEnum) throw() { iterator_accessor *pIteratorAccessor = reinterpret_cast(hEnum); if (!pIteratorAccessor) return E_INVALIDARG; pIteratorAccessor->Close(); delete pIteratorAccessor; return S_OK; } // // Returns S_FALSE if it's not expired // S_OK if it is expired and an error HRESULT // if an error occurred. STDMETHOD(IsExpired)() throw() { HRESULT hrRet = S_FALSE; HRESULT hr = E_UNEXPECTED; // Get the data connection for this thread. CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; CCommand > command; hr = command.Assign(m_szSessionName); if (hr != S_OK) return hr; hr = command.Open(dataconn, m_QueryObj.GetSessionRefIsExpired(), NULL, NULL, DBGUID_DEFAULT, true); if (hr == S_OK) { if (S_OK == command.MoveFirst()) { if (!_tcscmp(command.m_SessionIDOut, m_szSessionName)) hrRet = S_OK; } } if (hr == S_OK) return hrRet; return hr; } STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw() { HRESULT hr = E_UNEXPECTED; // Get the data connection for this thread. 
CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // allocate a command and set it's input parameters CCommand > command; hr = command.Assign(m_szSessionName, dwNewTimeout); if (hr != S_OK) return hr; hr = command.Open(dataconn, m_QueryObj.GetSessionRefUpdateTimeout(), NULL, NULL, DBGUID_DEFAULT, false); return hr; } // SessionLock increments the session reference count for this session. // If there is not a session by this name in the session references table, // a new session entry is created in the the table. HRESULT SessionLock() throw() { HRESULT hr = E_UNEXPECTED; if (!m_szSessionName || m_szSessionName[0]==0) return hr; // no session to lock. // retrieve the data connection for this thread CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // first try to update a session with this name LONG nRows = 0; CCommand > updator; if (S_OK == updator.Assign(m_szSessionName)) { if (S_OK != (hr = updator.Open(dataconn, m_QueryObj.GetSessionRefAddRef(), NULL, &nRows, DBGUID_DEFAULT, false)) || nRows == 0) { // No session to update. Use the creator accessor // to create a new session reference. CCommand > creator; hr = creator.Assign(m_szSessionName, m_dwTimeout); if (hr == S_OK) hr = creator.Open(dataconn, m_QueryObj.GetSessionRefCreate(), NULL, &nRows, DBGUID_DEFAULT, false); } } // We should have been able to create or update a session. ATLASSERT(nRows > 0); if (hr == S_OK && nRows <= 0) hr = E_UNEXPECTED; return hr; } // SessionUnlock decrements the session RefCount for this session. 
// Sessions cannot be removed from the database unless the session // refcount is 0 HRESULT SessionUnlock() throw() { HRESULT hr = E_UNEXPECTED; if (!m_szSessionName || m_szSessionName[0]==0) return hr; // get the data connection for this thread CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // The session must exist at this point in order to unlock it // so we can just use the session updator here. LONG nRows = 0; CCommand > updator; hr = updator.Assign(m_szSessionName); if (hr == S_OK) { hr = updator.Open( dataconn, m_QueryObj.GetSessionRefRemoveRef(), NULL, &nRows, DBGUID_DEFAULT, false); } if (hr != S_OK) return hr; // delete the session from the database if // nobody else is using it and it's expired. hr = FreeSession(); return hr; } // Access updates the last access time for the session. The access // time for sessions is updated using the SQL GETDATE function on the // database server so that all clients will be using the same clock // to compare access times against. HRESULT Access() throw() { HRESULT hr = E_UNEXPECTED; if (!m_szSessionName || m_szSessionName[0]==0) return hr; // no session to access // get the data connection for this thread CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; // The session reference entry in the references table must // be created prior to calling this function so we can just // use an updator to update the current entry. CCommand > updator; LONG nRows = 0; hr = updator.Assign(m_szSessionName); if (hr == S_OK) { hr = updator.Open( dataconn, m_QueryObj.GetSessionRefAccess(), NULL, &nRows, DBGUID_DEFAULT, false); } ATLASSERT(nRows > 0); if (hr == S_OK && nRows <= 0) hr = E_UNEXPECTED; return hr; } // If the session is expired and it's reference is 0, // it can be deleted. SessionUnlock calls this function to // unlock the session and delete it after we release a session // lock. 
Note that our SQL command will only delete the session // if it is expired and it's refcount is <= 0 HRESULT FreeSession() throw() { HRESULT hr = E_UNEXPECTED; if (!m_szSessionName || m_szSessionName[0]==0) return hr; // Get the data connection for this thread. CDataConnection dataconn; hr = GetSessionConnection(&dataconn, m_spServiceProvider); if (hr != S_OK) return hr; CCommand > updator; // The SQL for this command only deletes the // session reference from the references table if it's access // count is 0 and it has expired. return updator.Open(dataconn, m_QueryObj.GetSessionRefDelete(), NULL, NULL, DBGUID_DEFAULT, false); } // Initialize is called each time a new session is created. HRESULT Initialize( LPCSTR szSessionName, IServiceProvider *pServiceProvider, DWORD_PTR dwCookie, PFN_GETPROVIDERINFO pfnInfo) throw() { if (!szSessionName) return E_INVALIDARG; if (!pServiceProvider) return E_INVALIDARG; if (!pfnInfo) return E_INVALIDARG; m_pfnInfo = pfnInfo; m_dwProvCookie = dwCookie; m_spServiceProvider = pServiceProvider; _ATLTRY { CA2CT tcsSessionName(szSessionName); if (_tcslen(tcsSessionName) < MAX_SESSION_KEY_LEN) _tcscpy(m_szSessionName, tcsSessionName); else return E_OUTOFMEMORY; } _ATLCATCHALL() { return E_OUTOFMEMORY; } return SessionLock(); } HRESULT GetSessionConnection(CDataConnection *pConn, IServiceProvider *pProv) throw() { if (!pProv) return E_INVALIDARG; if (!m_pfnInfo || !m_dwProvCookie) return E_UNEXPECTED; wchar_t *wszProv = NULL; if (m_pfnInfo(m_dwProvCookie, &wszProv) && wszProv!=NULL) { return GetDataSource(pProv, ATL_DBSESSION_ID, wszProv, pConn); } return E_FAIL; } protected: TCHAR m_szSessionName[MAX_SESSION_KEY_LEN]; unsigned __int64 m_dwTimeout; CComPtr m_spServiceProvider; DWORD_PTR m_dwProvCookie; PFN_GETPROVIDERINFO m_pfnInfo; DBQUERYCLASS_TYPE m_QueryObj; }; // CDBSession template > class CDBSessionServiceImplT { wchar_t m_szConnectionString[MAX_CONNECTION_STRING_LEN]; CComPtr m_spServiceProvider; TDBSession::DBQUERYCLASS_TYPE 
m_QueryObj;
    // NOTE(review): several template argument lists in this class appear to be
    // missing from the text as found (for example `CComObject *pNewSession`,
    // `CCommand > command` and `reinterpret_cast(this)`). The code tokens are
    // left exactly as they are -- confirm against the shipped header.

public:
    // Type of the initialization data accepted by Initialize(): a wide-character
    // connection string.
    typedef const wchar_t* SERVICEIMPL_INITPARAM_TYPE;

    // Starts with the default session timeout and an empty connection string.
    CDBSessionServiceImplT() throw()
    {
        m_dwTimeout = ATL_SESSION_TIMEOUT;
        m_szConnectionString[0] = '\0';
    }

    // Callback passed to session objects together with `this` as the cookie
    // (see CreateNewSession/GetSession).
    //   dwProvCookie  - cookie that is cast back to a CDBSessionServiceImplT pointer.
    //   ppszProvInfo  - receives a pointer to that object's connection string
    //                   buffer (no copy is made).
    // Returns false when either argument is null, true otherwise.
    static bool GetProviderInfo(DWORD_PTR dwProvCookie, wchar_t **ppszProvInfo) throw()
    {
        if (dwProvCookie && ppszProvInfo)
        {
            CDBSessionServiceImplT *pSvc = reinterpret_cast*>(dwProvCookie);
            *ppszProvInfo = pSvc->m_szConnectionString;
            return true;
        }
        return false;
    }

    // Obtains the data connection for the session database through
    // GetDataSource(), keyed by ATL_DBSESSION_ID.
    // Returns E_INVALIDARG when pProv is null and E_UNEXPECTED when no
    // connection string has been stored yet; otherwise the GetDataSource() result.
    HRESULT GetSessionConnection(CDataConnection *pConn,
        IServiceProvider *pProv) throw()
    {
        if (!pProv)
            return E_INVALIDARG;

        if(!m_szConnectionString[0])
            return E_UNEXPECTED;

        return GetDataSource(pProv,
                            ATL_DBSESSION_ID,
                            m_szConnectionString,
                            pConn);
    }

    // Stores the connection string, the initial timeout and the service provider.
    //   pData            - connection string; must be shorter than
    //                      MAX_CONNECTION_STRING_LEN characters.
    //   pProvider        - service provider later used to obtain data connections.
    //   dwInitialTimeout - initial session timeout.
    // Returns E_INVALIDARG for null arguments, E_OUTOFMEMORY when the connection
    // string does not fit in m_szConnectionString, S_OK otherwise.
    HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE pData,
                       IServiceProvider *pProvider,
                       unsigned __int64 dwInitialTimeout) throw()
    {
        if (!pData || !pProvider)
            return E_INVALIDARG;

        // The length check guarantees room for the terminator before copying.
        if (wcslen(pData) < MAX_CONNECTION_STRING_LEN)
        {
            wcscpy(m_szConnectionString, pData);
        }
        else
            return E_OUTOFMEMORY;

        m_dwTimeout = dwInitialTimeout;
        m_spServiceProvider = pProvider;
        return S_OK;
    }

    // Creates a session object, generates a new session name into szNewID and
    // returns the object's ISession interface.
    //   szNewID    - buffer that receives the generated session name.
    //   pdwSize    - in/out size passed through to GetNewSessionName().
    //   ppSession  - receives the new session; set to NULL on entry.
    // Returns E_INVALIDARG/E_POINTER for null arguments, E_OUTOFMEMORY when the
    // object cannot be created, otherwise the result of name generation,
    // session initialization or QueryInterface.
    HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
    {
        HRESULT hr = E_FAIL;
        CComObject *pNewSession = NULL;

        if (!pdwSize)
            return E_INVALIDARG;

        if (ppSession)
            *ppSession = NULL;
        else
            return E_POINTER;

        if (szNewID)
            *szNewID = NULL;
        else
            return E_INVALIDARG;

        // Create new session
        CComObject::CreateInstance(&pNewSession);
        if (pNewSession == NULL)
            return E_OUTOFMEMORY;

        // Create a session name and initialize the object
        hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize);
        if (hr == S_OK)
        {
            hr = pNewSession->Initialize(szNewID,
                                        m_spServiceProvider,
                                        reinterpret_cast(this),
                                        GetProviderInfo);
            if (hr == S_OK)
            {
                // we don't hold a reference to the object
                hr = pNewSession->QueryInterface(ppSession);
            }
        }

        // Nothing was handed out to the caller, so destroy the object directly.
        if (hr != S_OK)
            delete pNewSession;
        return hr;
    }

    // Returns a session object for an existing session id.
    //   szID       - session id to look up (validated through IsValidSession()).
    //   ppSession  - receives the session; set to NULL on entry.
    // Returns E_INVALIDARG/E_POINTER for null arguments, E_OUTOFMEMORY when the
    // id conversion throws or the object cannot be created, otherwise the
    // result of validation, initialization or QueryInterface.
    HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw()
    {
        HRESULT hr = E_FAIL;
        if (!szID)
            return E_INVALIDARG;

        if (ppSession)
            *ppSession = NULL;
        else
            return E_POINTER;

        CComObject *pNewSession = NULL;

        // Check the DB to see if the session ID is a valid session
        _ATLTRY
        {
            CA2CT session(szID);
            hr = IsValidSession(session);
        }
        _ATLCATCHALL()
        {
            hr = E_OUTOFMEMORY;
        }

        if (hr == S_OK)
        {
            // Create new session object to represent this session
            CComObject::CreateInstance(&pNewSession);
            if (pNewSession == NULL)
                return E_OUTOFMEMORY;

            hr = pNewSession->Initialize(szID,
                                        m_spServiceProvider,
                                        reinterpret_cast(this),
                                        GetProviderInfo);
            if (hr == S_OK)
            {
                // we don't hold a reference to the object
                hr = pNewSession->QueryInterface(ppSession);
            }
        }

        // pNewSession is only non-null when validation succeeded above.
        if (hr != S_OK && pNewSession)
            delete pNewSession;
        return hr;
    }

    // Closes a session by running two commands against the session database:
    // first the "delete all variables" query, then the "delete final
    // reference" query, both bound to the session id.
    // Returns E_INVALIDARG for a null id, E_OUTOFMEMORY when the id conversion
    // throws, otherwise the first failing HRESULT or S_OK.
    HRESULT CloseSession(LPCSTR szID) throw()
    {
        if (!szID)
            return E_INVALIDARG;

        CDataConnection conn;
        HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
        if (hr != S_OK)
            return hr;

        // set up accessors
        CCommand > updator;
        CCommand > command;
        _ATLTRY
        {
            CA2CT session(szID);
            hr = updator.Assign(session);
            if (hr == S_OK)
                hr = command.Assign(session);
        }
        _ATLCATCHALL()
        {
            hr = E_OUTOFMEMORY;
        }

        if (hr == S_OK)
        {
            // delete all session variables
            hr = command.Open(conn,
                            m_QueryObj.GetSessionVarDeleteAllVars(),
                            NULL,
                            NULL,
                            DBGUID_DEFAULT,
                            false);
            if (hr == S_OK)
            {
                // delete references in the session references table
                hr = updator.Open(conn,
                                m_QueryObj.GetSessionRefDeleteFinal(),
                                NULL,
                                NULL,
                                DBGUID_DEFAULT,
                                false);
            }
        }
        return hr;
    }

    // Runs the "set session references" query with the new timeout and, only
    // if that succeeds, remembers the timeout in m_dwTimeout.
    HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw()
    {
        // Get the data connection for this thread
        CDataConnection conn;

        HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
        if (hr != S_OK)
            return hr;

        // all sessions get the same timeout
        CCommand > command;
        hr = command.Assign(nTimeout);
        if (hr == S_OK)
        {
            hr = command.Open(conn,
                            m_QueryObj.GetSessionReferencesSet(),
                            NULL,
                            NULL,
                            DBGUID_DEFAULT,
                            false);
            if (hr == S_OK)
            {
                m_dwTimeout = nTimeout;
            }
        }

        return hr;
    }

    // Returns the locally cached timeout (no database access).
    // Returns E_INVALIDARG when pnTimeout is null.
    HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw()
    {
        if (pnTimeout)
            *pnTimeout = m_dwTimeout;
        else
            return E_INVALIDARG;

        return S_OK;
    }

    // Runs the "get count" query and stores the value of the first row in
    // *pnCount. *pnCount is zeroed on entry and stays 0 on any failure.
    HRESULT GetSessionCount(DWORD *pnCount) throw()
    {
        if (pnCount)
            *pnCount = 0;
        else
            return E_INVALIDARG;

        CCommand > command;
        CDataConnection conn;
        HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
        if (hr != S_OK)
            return hr;

        hr = command.Open(conn,
                            m_QueryObj.GetSessionRefGetCount());
        if (hr == S_OK)
        {
            hr = command.MoveFirst();
            if (hr == S_OK)
            {
                *pnCount = (DWORD)command.m_nCount;
            }
        }

        return hr;
    }

    // This implementation keeps no in-memory sessions.
    void ReleaseAllSessions() throw()
    {
        // nothing to do
    }

    // This implementation performs no sweeping of its own.
    void SweepSessions() throw()
    {
        // nothing to do
    }

    // Helpers

    // Checks whether szID names a session known to the database.
    // Returns E_INVALIDARG for a null id, E_UNEXPECTED when no connection
    // string is set, a failure code from connecting/binding/opening, and
    // otherwise the result of MoveFirst() on the select command.
    HRESULT IsValidSession(LPCTSTR szID) throw()
    {
        if (!szID)
            return E_INVALIDARG;
        // Look in the sessionreferences table to see if there is an entry
        // for this session.
        if (m_szConnectionString[0] == 0)
            return E_UNEXPECTED;

        CDataConnection conn;
        HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
        if (hr != S_OK)
            return hr;

        // Check the session references table to see if
        // this is a valid session
        CCommand > selector;
        hr = selector.Assign(szID);
        if (hr != S_OK)
            return hr;

        // NOTE(review): the comment previously here described a command that
        // deletes expired session references, but the query opened is
        // GetSessionRefSelect(). Presumably this only selects the matching
        // reference row -- confirm against the query class.
        hr = selector.Open(conn,
                        m_QueryObj.GetSessionRefSelect(),
                        NULL,
                        NULL,
                        DBGUID_DEFAULT,
                        true);
        if (hr == S_OK)
            return selector.MoveFirst();
        return hr;
    }

    CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names
    unsigned __int64 m_dwTimeout;                 // Cached timeout handed out by GetSessionTimeout()
}; // CDBSessionServiceImplT

typedef CDBSessionServiceImplT<> CDBSessionServiceImpl;


//////////////////////////////////////////////////////////////////
//
// In-memory persisted session
//
//////////////////////////////////////////////////////////////////

// In-memory persisted session service keeps a pointer
// to the session object around in memory. The pointer is
// contained in a CComPtr, which is stored in a CAtlMap, so
// we have to have a CElementTraits class for that.
typedef CComPtr SESSIONPTRTYPE; template<> class CElementTraits : public CElementTraitsBase { public: static ULONG Hash( INARGTYPE obj ) throw() { return( (ULONG)(ULONG_PTR)obj.p); } static BOOL CompareElements( OUTARGTYPE element1, OUTARGTYPE element2 ) throw() { return element1.IsEqualObject(element2.p) ? TRUE : FALSE; } static int CompareElementsOrdered( INARGTYPE , INARGTYPE ) throw() { ATLASSERT(0); // NOT IMPLEMENTED return 0; } }; // CMemSession // This session persistance class persists session variables in memory. // Note that this type of persistance should only be used on single server // web sites. class CMemSession : public ISession, public CComObjectRootEx { public: BEGIN_COM_MAP(CMemSession) COM_INTERFACE_ENTRY(ISession) END_COM_MAP() CMemSession() throw(...) { } STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw() { if (!szName) return E_INVALIDARG; if (pVal) VariantInit(pVal); else return E_POINTER; HRESULT hr = Access(); if (hr == S_OK) { CSLockType lock(m_cs, false); hr = lock.Lock(); if (FAILED(hr)) return hr; _ATLTRY { CComVariant val; if (m_Variables.Lookup(szName, val)) { hr = VariantCopy(pVal, &val); } } _ATLCATCHALL() { hr = E_UNEXPECTED; } } return hr; } STDMETHOD(SetVariable)(LPCSTR szName, VARIANT vNewVal) throw() { if (!szName) return E_INVALIDARG; HRESULT hr = Access(); if (hr == S_OK) { CSLockType lock(m_cs, false); hr = lock.Lock(); if (FAILED(hr)) return hr; _ATLTRY { hr = m_Variables.SetAt(szName, vNewVal) ? S_OK : E_FAIL; } _ATLCATCHALL() { hr = E_UNEXPECTED; } } return hr; } STDMETHOD(RemoveVariable)(LPCSTR szName) throw() { if (!szName) return E_INVALIDARG; HRESULT hr = Access(); if (hr == S_OK) { CSLockType lock(m_cs, false); hr = lock.Lock(); if (FAILED(hr)) return hr; _ATLTRY { hr = m_Variables.RemoveKey(szName) ? 
S_OK : E_FAIL; } _ATLCATCHALL() { hr = E_UNEXPECTED; } } return hr; } STDMETHOD(GetCount)(long *pnCount) throw() { if (pnCount) return *pnCount = 0; else return E_POINTER; HRESULT hr = Access(); if (hr == S_OK) { CSLockType lock(m_cs, false); hr = lock.Lock(); if (FAILED(hr)) return hr; *pnCount = (long) m_Variables.GetCount(); } return hr; } STDMETHOD(RemoveAllVariables)() throw() { HRESULT hr = Access(); if (hr == S_OK) { CSLockType lock(m_cs, false); hr = lock.Lock(); if (FAILED(hr)) return hr; m_Variables.RemoveAll(); } return hr; } STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnumHandle, POSITION *pPOS) throw() { if (phEnumHandle) *phEnumHandle = NULL; else return E_POINTER; if (pPOS) *pPOS = NULL; else return E_POINTER; HRESULT hr = Access(); if (hr == S_OK) { CSLockType lock(m_cs, false); hr = lock.Lock(); if (FAILED(hr)) return hr; *pPOS = m_Variables.GetStartPosition(); } return hr; } STDMETHOD(GetNextVariable)(HSESSIONENUM /*hEnum*/, POSITION *pPOS, LPSTR szName, DWORD dwLen, VARIANT *pVal) throw() { if (!szName) return E_INVALIDARG; if (pVal) VariantInit(pVal); else return E_POINTER; if (!pPOS) return E_POINTER; CComVariant val; POSITION pos = *pPOS; HRESULT hr = Access(); if (hr == S_OK) { CSLockType lock(m_cs, false); hr = lock.Lock(); if (FAILED(hr)) return hr; _ATLTRY { CStringA strName = m_Variables.GetKeyAt(pos); if (strName.GetLength()) { if (dwLen > (DWORD)strName.GetLength()) strcpy(szName, strName); else hr = E_OUTOFMEMORY; } if (hr == S_OK) { val = m_Variables.GetNextValue(pos); hr = VariantCopy(pVal, &val); if (hr == S_OK) *pPOS = pos; } } _ATLCATCHALL() { hr = E_UNEXPECTED; } } return hr; } STDMETHOD(CloseEnum)(HSESSIONENUM /*hEnumHandle*/) throw() { return S_OK; } STDMETHOD(IsExpired)() throw() { CTime tmNow = CTime::GetCurrentTime(); CTimeSpan span = tmNow-m_tLastAccess; if ((unsigned __int64)((span.GetTotalSeconds()*1000)) > m_dwTimeout) return S_OK; return S_FALSE; } HRESULT Access() throw() { // We lock here to protect against multiple 
threads // updating the same member concurrently. CSLockType lock(m_cs, false); HRESULT hr = lock.Lock(); if (FAILED(hr)) return hr; m_tLastAccess = CTime::GetCurrentTime(); return S_OK; } STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw() { // We lock here to protect against multiple threads // updating the same member concurrently CSLockType lock(m_cs, false); HRESULT hr = lock.Lock(); if (FAILED(hr)) return hr; m_dwTimeout = dwNewTimeout; return S_OK; } HRESULT SessionLock() throw() { Access(); return S_OK; } HRESULT SessionUnlock() throw() { return S_OK; } protected: typedef CAtlMap > VarMapType; unsigned __int64 m_dwTimeout; CTime m_tLastAccess; VarMapType m_Variables; CComAutoCriticalSection m_cs; typedef CComCritSecLock CSLockType; }; // CMemSession // // CMemSessionServiceImpl // Implements the service part of in-memory persisted session services. // class CMemSessionServiceImpl { public: typedef void* SERVICEIMPL_INITPARAM_TYPE; CMemSessionServiceImpl() throw() { m_dwTimeout = ATL_SESSION_TIMEOUT; } HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw() { HRESULT hr = E_FAIL; CComObject *pNewSession = NULL; if (!szNewID) return E_INVALIDARG; if (!pdwSize) return E_POINTER; if (ppSession) *ppSession = NULL; else return E_POINTER; _ATLTRY { // Create new session CComObject::CreateInstance(&pNewSession); if (pNewSession == NULL) return E_OUTOFMEMORY; // Initialize and add to list of CSessionData hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize); if (SUCCEEDED(hr)) { CComPtr spSession; hr = pNewSession->QueryInterface(&spSession); if (SUCCEEDED(hr)) { pNewSession->SetTimeout(m_dwTimeout); pNewSession->Access(); CSLockType lock(m_CritSec, false); hr = lock.Lock(); if (FAILED(hr)) return hr; m_Sessions.SetAt(szNewID, spSession); *ppSession = spSession.Detach(); } } } _ATLCATCHALL() { hr = E_UNEXPECTED; } return hr; } HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw() { HRESULT hr = E_FAIL; 
SessMapType::CPair *pPair = NULL; if (ppSession) *ppSession = NULL; else return E_POINTER; if (!szID) return E_INVALIDARG; CSLockType lock(m_CritSec, false); hr = lock.Lock(); if (FAILED(hr)) return hr; _ATLTRY { pPair = m_Sessions.Lookup(szID); if (pPair) // the session exists and is in our local map of sessions { hr = pPair->m_value.QueryInterface(ppSession); } } _ATLCATCHALL() { return E_UNEXPECTED; } return hr; } HRESULT CloseSession(LPCSTR szID) throw() { if (!szID) return E_INVALIDARG; HRESULT hr = E_FAIL; CSLockType lock(m_CritSec, false); hr = lock.Lock(); if (FAILED(hr)) return hr; _ATLTRY { hr = m_Sessions.RemoveKey(szID) ? S_OK : E_FAIL; } _ATLCATCHALL() { hr = E_UNEXPECTED; } return hr; } void SweepSessions() throw() { POSITION posRemove = NULL; const SessMapType::CPair *pPair = NULL; POSITION pos = NULL; CSLockType lock(m_CritSec, false); if (FAILED(lock.Lock())) return; pos = m_Sessions.GetStartPosition(); while (pos) { posRemove = pos; pPair = m_Sessions.GetNext(pos); if (pPair) { if (pPair->m_value.p && S_OK == pPair->m_value->IsExpired()) { // remove our reference on the session m_Sessions.RemoveAtPos(posRemove); } } } } HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw() { HRESULT hr = S_OK; CComPtr spSession; m_dwTimeout = nTimeout; POSITION pos = m_Sessions.GetStartPosition(); CSLockType lock(m_CritSec, false); hr = lock.Lock(); if (FAILED(hr)) return hr; while (pos) { SessMapType::CPair *pPair = const_cast(m_Sessions.GetNext(pos)); if (pPair) { spSession = pPair->m_value; if (spSession) { // if we fail on any of the sets we will return the // error code immediately hr = spSession->SetTimeout(nTimeout); spSession.Release(); if (hr != S_OK) break; } } } return hr; } HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw() { if (pnTimeout) *pnTimeout = m_dwTimeout; else return E_POINTER; return S_OK; } HRESULT GetSessionCount(DWORD *pnCount) throw() { if (pnCount) *pnCount = 0; else return E_POINTER; CSLockType lock(m_CritSec, 
false); HRESULT hr = lock.Lock(); if (FAILED(hr)) return hr; *pnCount = (DWORD)m_Sessions.GetCount(); return S_OK; } void ReleaseAllSessions() throw() { CSLockType lock(m_CritSec, false); if (FAILED(lock.Lock())) return; m_Sessions.RemoveAll(); } HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE, IServiceProvider*, unsigned __int64 dwNewTimeout) throw() { m_dwTimeout = dwNewTimeout; return m_CritSec.Init(); } typedef CAtlMap, CElementTraitsBase > SessMapType; SessMapType m_Sessions; // map for holding sessions in memory CComCriticalSection m_CritSec; // for synchronizing access to map typedef CComCritSecLock CSLockType; CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names unsigned __int64 m_dwTimeout; }; // CMemSessionServiceImpl // // CSessionStateService // This class implements the session state service which can be // exposed to request handlers. // // Template Parameters: // CMonitorClass: Provides periodic sweeping services for the session service class. // TServiceImplClass: The class that actually implements the methods of the // ISessionStateService and ISessionStateControl interfaces. 
template class CSessionStateService : public ISessionStateService, public ISessionStateControl, public IWorkerThreadClient, public CComObjectRootEx { protected: CMonitorClass m_Monitor; HANDLE m_hTimer; CComPtr m_spServiceProvider; TServiceImplClass m_SessionServiceImpl; public: // Construction/Initialization CSessionStateService() throw() : m_hTimer(NULL) { } ~CSessionStateService() throw() { ATLASSERT(m_hTimer == NULL); } BEGIN_COM_MAP(CSessionStateService) COM_INTERFACE_ENTRY(ISessionStateService) COM_INTERFACE_ENTRY(ISessionStateControl) END_COM_MAP() // ISessionStateServie methods STDMETHOD(CreateNewSession)(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw() { return m_SessionServiceImpl.CreateNewSession(szNewID, pdwSize, ppSession); } STDMETHOD(GetSession)(LPCSTR szID, ISession **ppSession) throw() { return m_SessionServiceImpl.GetSession(szID, ppSession); } STDMETHOD(CloseSession)(LPCSTR szSessionID) throw() { return m_SessionServiceImpl.CloseSession(szSessionID); } STDMETHOD(SetSessionTimeout)(unsigned __int64 nTimeout) throw() { return m_SessionServiceImpl.SetSessionTimeout(nTimeout); } STDMETHOD(GetSessionTimeout)(unsigned __int64 *pnTimeout) throw() { return m_SessionServiceImpl.GetSessionTimeout(pnTimeout); } STDMETHOD(GetSessionCount)(DWORD *pnSessionCount) throw() { return m_SessionServiceImpl.GetSessionCount(pnSessionCount); } void SweepSessions() throw() { m_SessionServiceImpl.SweepSessions(); } void ReleaseAllSessions() throw() { m_SessionServiceImpl.ReleaseAllSessions(); } HRESULT Initialize( IServiceProvider *pServiceProvider = NULL, unsigned __int64 dwTimeout = ATL_SESSION_TIMEOUT, TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw() { HRESULT hr = S_OK; if (pServiceProvider) m_spServiceProvider = pServiceProvider; hr = m_SessionServiceImpl.Initialize(pInitData, pServiceProvider, dwTimeout); return hr; } template HRESULT Initialize( CWorkerThread *pWorker, IServiceProvider *pServiceProvider = NULL, unsigned __int64 
dwTimeout = ATL_SESSION_TIMEOUT, TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw() { if (!pWorker) return E_INVALIDARG; HRESULT hr = Initialize(pServiceProvider, dwTimeout, pInitData); if (hr == S_OK) { hr = m_Monitor.Initialize(pWorker); if (hr == S_OK) { //sweep every 500ms hr = m_Monitor.AddTimer(ATL_SESSION_SWEEPER_TIMEOUT, this, 0, &m_hTimer); } } return hr; } HRESULT Execute(DWORD_PTR /*dwParam*/, HANDLE /*hObject*/) throw() { SweepSessions(); return S_OK; } HRESULT CloseHandle(HANDLE hHandle) throw() { ::CloseHandle(hHandle); m_hTimer = NULL; return S_OK; } void Shutdown() throw() { if (m_hTimer) { m_Monitor.RemoveHandle(m_hTimer); m_hTimer = NULL; } ReleaseAllSessions(); } }; // CSessionStateService } // namespace ATL #pragma warning(pop) #endif // __ATLSESSION_H__