mirror of https://github.com/tongzx/nt5src
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
2359 lines
52 KiB
2359 lines
52 KiB
// This is a part of the Active Template Library.
|
|
// Copyright (C) 1996-2001 Microsoft Corporation
|
|
// All rights reserved.
|
|
//
|
|
// This source code is only intended as a supplement to the
|
|
// Active Template Library Reference and related
|
|
// electronic documentation provided with the library.
|
|
// See these sources for detailed information regarding the
|
|
// Active Template Library product.
|
|
|
|
#ifndef __ATLSESSION_H__
|
|
#define __ATLSESSION_H__
|
|
|
|
#pragma once
|
|
#pragma warning(push)
|
|
#pragma warning(disable: 4702) // unreachable code
|
|
|
|
#include <atldbcli.h>
|
|
#include <atlcom.h>
|
|
#include <atlstr.h>
|
|
#include <stdio.h>
|
|
#include <atlcoll.h>
|
|
#include <atltime.h>
|
|
#include <atlcrypt.h>
|
|
#include <atlenc.h>
|
|
#include <atlutil.h>
|
|
#include <atlcache.h>
|
|
#include <atlspriv.h>
|
|
#include <atlsiface.h>
|
|
|
|
// ---------------------------------------------------------------------------
// Session-state configuration. Every value below can be overridden by
// defining the macro before this header is included (except the last two,
// which are defined unconditionally).
// ---------------------------------------------------------------------------

// Session key length. NOTE(review): not referenced in this part of the
// header; presumably the buffer size requested for new session keys.
#ifndef SESSION_KEY_LENGTH
#define SESSION_KEY_LENGTH 37
#endif

// Capacity of every session-ID buffer (the TCHAR session ID members of the
// accessor classes) and of the stack buffer in
// CSessionNameGenerator::GetNewSessionName.
#ifndef MAX_SESSION_KEY_LEN
#define MAX_SESSION_KEY_LEN 128
#endif

// Capacity of the variable-name buffers in the accessor classes.
#ifndef MAX_VARIABLE_NAME_LENGTH
#define MAX_VARIABLE_NAME_LENGTH 50
#endif

// Capacity of the buffer holding a serialized VARIANT value.
#ifndef MAX_VARIABLE_VALUE_LENGTH
#define MAX_VARIABLE_VALUE_LENGTH 128
#endif

// Not referenced in this part of the header.
#ifndef DEFAULT_SQL_LEN
#define DEFAULT_SQL_LEN 1024
#endif

// Not referenced in this part of the header.
#ifndef MAX_CONNECTION_STRING_LEN
#define MAX_CONNECTION_STRING_LEN 2048
#endif

// Name of the session cookie (presumably; usage is outside this view).
#ifndef SESSION_COOKIE_NAME
#define SESSION_COOKIE_NAME "SESSIONID"
#endif

// Default session timeout in milliseconds (600000 ms == 10 minutes);
// CDBSession starts with this value.
#ifndef ATL_SESSION_TIMEOUT
#define ATL_SESSION_TIMEOUT 600000
#endif

// Sweeper interval in milliseconds (1000 ms == 1 second).
#ifndef ATL_SESSION_SWEEPER_TIMEOUT
#define ATL_SESSION_SWEEPER_TIMEOUT 1000
#endif

// Position value meaning "no current variable" in CDBSession's
// variable enumeration.
#define INVALID_DB_SESSION_POS 0x0
// Identifier string; NOTE(review): usage is outside this view.
#define ATL_DBSESSION_ID _T("__ATL_SESSION_DB_CONNECTION")
|
|
|
|
|
|
namespace ATL {
|
|
|
|
// CSessionNameGenerator
|
|
// This is a helper class that generates random data for session key
|
|
// names. This class tries to use the CryptoApi to generate random
|
|
// bytes for the session key name. If the CryptoApi isn't available
|
|
// then the CRT rand() is used to generate the random bytes. This
|
|
// class's GetNewSessionName member function is used to actually
|
|
// generate the session name.
|
|
class CSessionNameGenerator :
|
|
public CCryptProv
|
|
{
|
|
public:
|
|
bool m_bCryptNotAvailable;
|
|
enum {MIN_SESSION_KEY_LEN=5};
|
|
|
|
CSessionNameGenerator() throw() :
|
|
m_bCryptNotAvailable(false)
|
|
{
|
|
// Note that the crypto api is being
|
|
// initialized with no private key
|
|
// information
|
|
HRESULT hr = InitVerifyContext();
|
|
m_bCryptNotAvailable = FAILED(hr) ? true : false;
|
|
}
|
|
|
|
// This function creates a new session name and base64 encodes it.
|
|
// The base64 encoding algorithm used needs at least MIN_SESSION_KEY_LEN
|
|
// bytes to work correctly. Since we stack allocate the temporary
|
|
// buffer that holds the key name, the buffer must be less than or equal to
|
|
// the MAX_SESSION_KEY_LEN in size.
|
|
HRESULT GetNewSessionName(LPSTR szNewID, DWORD *pdwSize) throw()
|
|
{
|
|
HRESULT hr = E_FAIL;
|
|
|
|
if (!pdwSize)
|
|
return E_POINTER;
|
|
|
|
if (*pdwSize < MIN_SESSION_KEY_LEN ||
|
|
*pdwSize > MAX_SESSION_KEY_LEN)
|
|
return E_INVALIDARG;
|
|
|
|
if (!szNewID)
|
|
return E_POINTER;
|
|
|
|
BYTE key[MAX_SESSION_KEY_LEN] = {0x0};
|
|
|
|
|
|
// calculate the number of bytes that will fit in the
|
|
// buffer we've been passed
|
|
DWORD dwDataSize = CalcMaxInputSize(*pdwSize);
|
|
|
|
if (dwDataSize && *pdwSize >= (DWORD)(Base64EncodeGetRequiredLength(dwDataSize,
|
|
ATL_BASE64_FLAG_NOCRLF)))
|
|
{
|
|
int dwKeySize = *pdwSize;
|
|
hr = GenerateRandomName(key, dwDataSize);
|
|
if (SUCCEEDED(hr))
|
|
{
|
|
if( Base64Encode(key,
|
|
dwDataSize,
|
|
szNewID,
|
|
&dwKeySize,
|
|
ATL_BASE64_FLAG_NOCRLF) )
|
|
{
|
|
//null terminate
|
|
szNewID[dwKeySize]=0;
|
|
*pdwSize = dwKeySize+1;
|
|
}
|
|
else
|
|
hr = E_FAIL;
|
|
}
|
|
else
|
|
{
|
|
*pdwSize = (DWORD)(Base64EncodeGetRequiredLength(dwDataSize,
|
|
ATL_BASE64_FLAG_NOCRLF));
|
|
return E_OUTOFMEMORY;
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
DWORD CalcMaxInputSize(DWORD nOutputSize) throw()
|
|
{
|
|
if (nOutputSize < (DWORD)MIN_SESSION_KEY_LEN)
|
|
return 0;
|
|
// subtract one from the output size to make room
|
|
// for the NULL terminator in the output then
|
|
// calculate the biggest number of input bytes that
|
|
// when base64 encoded will fit in a buffer of size
|
|
// nOutputSize (including base64 padding)
|
|
int nInputSize = ((nOutputSize-1)*3)/4;
|
|
int factor = ((nInputSize*4)/3)%4;
|
|
if (factor)
|
|
nInputSize -= factor;
|
|
return nInputSize;
|
|
}
|
|
|
|
|
|
HRESULT GenerateRandomName(BYTE *pBuff, DWORD dwBuffSize) throw()
|
|
{
|
|
if (!pBuff)
|
|
return E_POINTER;
|
|
|
|
if (!dwBuffSize)
|
|
return E_UNEXPECTED;
|
|
|
|
if (!m_bCryptNotAvailable && GetHandle())
|
|
{
|
|
// Use the crypto api to generate random data.
|
|
return GenRandom(dwBuffSize, pBuff);
|
|
}
|
|
|
|
// CryptoApi isn't available so we generate
|
|
// random data using rand. We seed the random
|
|
// number generator with a seed that is a combination
|
|
// of bytes from an arbitrary number and the system
|
|
// time which changes every millisecond so it will
|
|
// be different for every call to this function.
|
|
FILETIME ft;
|
|
GetSystemTimeAsFileTime(&ft);
|
|
static DWORD dwVal = 0x21;
|
|
DWORD dwSeed = (dwVal++ << 0x18) | (ft.dwLowDateTime & 0x00ffff00) | dwVal++ & 0x000000ff;
|
|
srand(dwSeed);
|
|
BYTE *pCurr = pBuff;
|
|
// fill buffer with random bytes
|
|
for (int i=0; i < (int)dwBuffSize; i++)
|
|
{
|
|
*pCurr = (BYTE) (rand() & 0x000000ff);
|
|
pCurr++;
|
|
}
|
|
return S_OK;
|
|
}
|
|
};
|
|
|
|
|
|
//
|
|
// CDefaultQueryClass
|
|
// returns Query strings for use in SQL queries used
|
|
// by the database persisted session service.
|
|
class CDefaultQueryClass
|
|
{
|
|
public:
|
|
LPCTSTR GetSessionRefDelete() throw()
|
|
{
|
|
return _T("DELETE FROM SessionReferences ")
|
|
_T("WHERE SessionID=? AND RefCount <= 0 ")
|
|
_T("AND DATEDIFF(millisecond, LastAccess, getdate()) > TimeoutMs");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefIsExpired() throw()
|
|
{
|
|
return _T("SELECT SessionID FROM SessionReferences ")
|
|
_T("WHERE (SessionID=?) AND (DATEDIFF(millisecond, LastAccess, getdate()) > TimeoutMs)");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefDeleteFinal() throw()
|
|
{
|
|
return _T("DELETE FROM SessionReferences ")
|
|
_T("WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefCreate() throw()
|
|
{
|
|
return _T("INSERT INTO SessionReferences ")
|
|
_T("(SessionID, LastAccess, RefCount, TimeoutMs) ")
|
|
_T("VALUES (?, getdate(), 1, ?)");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefUpdateTimeout() throw()
|
|
{
|
|
return _T("UPDATE SessionReferences ")
|
|
_T("SET TimeoutMs=? WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefAddRef() throw()
|
|
{
|
|
return _T("UPDATE SessionReferences ")
|
|
_T("SET RefCount=RefCount+1, ")
|
|
_T("LastAccess=getdate() ")
|
|
_T("WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefRemoveRef() throw()
|
|
{
|
|
return _T("UPDATE SessionReferences ")
|
|
_T("SET RefCount=RefCount-1, ")
|
|
_T("LastAccess=getdate() ")
|
|
_T("WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefAccess() throw()
|
|
{
|
|
return _T("UPDATE SessionReferences ")
|
|
_T("SET LastAccess=getdate() ")
|
|
_T("WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefSelect() throw()
|
|
{
|
|
return _T("SELECT * FROM SessionReferences ")
|
|
_T("WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionRefGetCount() throw()
|
|
{
|
|
return _T("SELECT COUNT(*) FROM SessionReferences");
|
|
}
|
|
|
|
|
|
LPCTSTR GetSessionVarCount() throw()
|
|
{
|
|
return _T("SELECT COUNT(*) FROM SessionVariables WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionVarInsert() throw()
|
|
{
|
|
return _T("INSERT INTO SessionVariables ")
|
|
_T("(VariableValue, SessionID, VariableName) ")
|
|
_T("VALUES (?, ?, ?)");
|
|
}
|
|
|
|
LPCTSTR GetSessionVarUpdate() throw()
|
|
{
|
|
return _T("UPDATE SessionVariables ")
|
|
_T("SET VariableValue=? ")
|
|
_T("WHERE SessionID=? AND VariableName=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionVarDeleteVar() throw()
|
|
{
|
|
return _T("DELETE FROM SessionVariables ")
|
|
_T("WHERE SessionID=? AND VariableName=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionVarDeleteAllVars() throw()
|
|
{
|
|
return _T("DELETE FROM SessionVariables WHERE (SessionID=?)");
|
|
}
|
|
|
|
LPCTSTR GetSessionVarSelectVar()throw()
|
|
{
|
|
return _T("SELECT SessionID, VariableName, VariableValue ")
|
|
_T("FROM SessionVariables ")
|
|
_T("WHERE SessionID=? AND VariableName=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionVarSelectAllVars() throw()
|
|
{
|
|
return _T("SELECT SessionID, VariableName, VariableValue ")
|
|
_T("FROM SessionVariables ")
|
|
_T("WHERE SessionID=?");
|
|
}
|
|
|
|
LPCTSTR GetSessionReferencesSet() throw()
|
|
{
|
|
return _T("UPDATE SessionReferences SET TimeoutMs=?");
|
|
}
|
|
};
|
|
|
|
|
|
// Contains the data for the session variable accessors
|
|
class CSessionDataBase
|
|
{
|
|
public:
|
|
TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
|
|
TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH];
|
|
BYTE m_VariableValue[MAX_VARIABLE_VALUE_LENGTH];
|
|
DWORD m_VariableLen;
|
|
CSessionDataBase() throw()
|
|
{
|
|
m_szSessionID[0] = '\0';
|
|
m_VariableName[0] = '\0';
|
|
m_VariableValue[0] = '\0';
|
|
m_VariableLen = 0;
|
|
}
|
|
HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName, VARIANT *pVal) throw()
|
|
{
|
|
HRESULT hr = S_OK;
|
|
CVariantStream stream;
|
|
if ( szSessionID )
|
|
{
|
|
if ( _tcslen(szSessionID)< MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_szSessionID, szSessionID);
|
|
else
|
|
hr = E_OUTOFMEMORY;
|
|
}
|
|
else
|
|
return E_INVALIDARG;
|
|
|
|
if (szVarName)
|
|
if ( _tcslen(szVarName) < MAX_VARIABLE_NAME_LENGTH)
|
|
_tcscpy(m_VariableName, szVarName);
|
|
else
|
|
hr = E_OUTOFMEMORY;
|
|
|
|
if (pVal)
|
|
{
|
|
hr = stream.InsertVariant(pVal);
|
|
if (hr == S_OK)
|
|
{
|
|
BYTE *pBytes = stream.m_stream;
|
|
size_t size = stream.GetVariantSize();
|
|
if (pBytes && size && size < MAX_VARIABLE_VALUE_LENGTH)
|
|
{
|
|
memcpy(m_VariableValue, pBytes, stream.GetVariantSize());
|
|
m_VariableLen = (DWORD)size;
|
|
}
|
|
else
|
|
hr = E_UNEXPECTED;
|
|
}
|
|
}
|
|
|
|
return hr;
|
|
}
|
|
};
|
|
|
|
// Use to select a session variable given the name
|
|
// of a session and the name of a variable.
|
|
class CSessionDataSelector : public CSessionDataBase
|
|
{
|
|
public:
|
|
BEGIN_COLUMN_MAP(CSessionDataSelector)
|
|
COLUMN_ENTRY(1, m_szSessionID)
|
|
COLUMN_ENTRY(2, m_VariableName)
|
|
COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen)
|
|
END_COLUMN_MAP()
|
|
BEGIN_PARAM_MAP(CSessionDataSelector)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_szSessionID)
|
|
COLUMN_ENTRY(2, m_VariableName)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
// Use to select all session variables given the name of
|
|
// of a session.
|
|
class CAllSessionDataSelector : public CSessionDataBase
|
|
{
|
|
public:
|
|
BEGIN_COLUMN_MAP(CAllSessionDataSelector)
|
|
COLUMN_ENTRY(1, m_szSessionID)
|
|
COLUMN_ENTRY(2, m_VariableName)
|
|
COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen)
|
|
END_COLUMN_MAP()
|
|
BEGIN_PARAM_MAP(CAllSessionDataSelector)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_szSessionID)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
// Use to update the value of a session variable
|
|
class CSessionDataUpdator : public CSessionDataBase
|
|
{
|
|
public:
|
|
BEGIN_PARAM_MAP(CSessionDataUpdator)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY_LENGTH(1, m_VariableValue, m_VariableLen)
|
|
COLUMN_ENTRY(2, m_szSessionID)
|
|
COLUMN_ENTRY(3, m_VariableName)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
// Use to delete a session variable given the
|
|
// session name and the name of the variable
|
|
class CSessionDataDeletor
|
|
{
|
|
public:
|
|
CSessionDataDeletor()
|
|
{
|
|
m_szSessionID[0] = '\0';
|
|
m_VariableName[0] = '\0';
|
|
}
|
|
|
|
TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
|
|
TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH];
|
|
HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName) throw()
|
|
{
|
|
if (szSessionID)
|
|
{
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_szSessionID, szSessionID);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
}
|
|
|
|
if (szVarName)
|
|
{
|
|
if(_tcslen(szVarName) < MAX_VARIABLE_NAME_LENGTH)
|
|
_tcscpy(m_VariableName, szVarName);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
}
|
|
return S_OK;
|
|
}
|
|
|
|
BEGIN_PARAM_MAP(CSessionDataDeletor)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_szSessionID)
|
|
COLUMN_ENTRY(2, m_VariableName)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
class CSessionDataDeleteAll
|
|
{
|
|
public:
|
|
TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
|
|
HRESULT Assign(LPCTSTR szSessionID) throw()
|
|
{
|
|
if (!szSessionID)
|
|
return E_INVALIDARG;
|
|
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_szSessionID, szSessionID);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
|
|
return S_OK;
|
|
}
|
|
|
|
BEGIN_PARAM_MAP(CSessionDataDeleteAll)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_szSessionID)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
// Used for retrieving the count of session variables for
|
|
// a given session ID.
|
|
class CCountAccessor
|
|
{
|
|
public:
|
|
LONG m_nCount;
|
|
TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
|
|
CCountAccessor() throw()
|
|
{
|
|
m_szSessionID[0] = '\0';
|
|
m_nCount = 0;
|
|
}
|
|
|
|
HRESULT Assign(LPCTSTR szSessionID) throw()
|
|
{
|
|
if (!szSessionID)
|
|
return E_INVALIDARG;
|
|
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_szSessionID, szSessionID);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
|
|
return S_OK;
|
|
}
|
|
|
|
BEGIN_COLUMN_MAP(CCountAccessor)
|
|
COLUMN_ENTRY(1, m_nCount)
|
|
END_COLUMN_MAP()
|
|
BEGIN_PARAM_MAP(CCountAccessor)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_szSessionID)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
|
|
// Used for updating entries in the session
|
|
// references table, given a session ID
|
|
class CSessionRefUpdator
|
|
{
|
|
public:
|
|
TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
|
|
HRESULT Assign(LPCTSTR szSessionID) throw()
|
|
{
|
|
if (!szSessionID)
|
|
return E_INVALIDARG;
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_SessionID, szSessionID);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
return S_OK;
|
|
}
|
|
BEGIN_PARAM_MAP(CSessionRefUpdator)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_SessionID)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
class CSessionRefIsExpired
|
|
{
|
|
public:
|
|
TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
|
|
TCHAR m_SessionIDOut[MAX_SESSION_KEY_LEN];
|
|
HRESULT Assign(LPCTSTR szSessionID) throw()
|
|
{
|
|
m_SessionIDOut[0]=0;
|
|
if (!szSessionID)
|
|
return E_INVALIDARG;
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_SessionID, szSessionID);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
return S_OK;
|
|
}
|
|
BEGIN_COLUMN_MAP(CSessionRefIsExpired)
|
|
COLUMN_ENTRY(1, m_SessionIDOut)
|
|
END_COLUMN_MAP()
|
|
BEGIN_PARAM_MAP(CSessionRefIsExpired)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_SessionID)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
class CSetAllTimeouts
|
|
{
|
|
public:
|
|
unsigned __int64 m_dwNewTimeout;
|
|
HRESULT Assign(unsigned __int64 dwNewValue)
|
|
{
|
|
m_dwNewTimeout = dwNewValue;
|
|
return S_OK;
|
|
}
|
|
BEGIN_PARAM_MAP(CSetAllTimeouts)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_dwNewTimeout)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
class CSessionRefUpdateTimeout
|
|
{
|
|
public:
|
|
TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
|
|
unsigned __int64 m_nNewTimeout;
|
|
HRESULT Assign(LPCTSTR szSessionID, unsigned __int64 nNewTimeout) throw()
|
|
{
|
|
if (!szSessionID)
|
|
return E_INVALIDARG;
|
|
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_SessionID, szSessionID);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
|
|
m_nNewTimeout = nNewTimeout;
|
|
|
|
return S_OK;
|
|
}
|
|
|
|
BEGIN_PARAM_MAP(CSessionRefUpdateTimeout)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_nNewTimeout)
|
|
COLUMN_ENTRY(2, m_SessionID)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
class CSessionRefSelector
|
|
{
|
|
public:
|
|
TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
|
|
int m_RefCount;
|
|
HRESULT Assign(LPCTSTR szSessionID) throw()
|
|
{
|
|
if (!szSessionID)
|
|
return E_INVALIDARG;
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_SessionID, szSessionID);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
return S_OK;
|
|
}
|
|
BEGIN_COLUMN_MAP(CSessionRefSelector)
|
|
COLUMN_ENTRY(1, m_SessionID)
|
|
COLUMN_ENTRY(3, m_RefCount)
|
|
END_COLUMN_MAP()
|
|
BEGIN_PARAM_MAP(CSessionRefSelector)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_SessionID)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
class CSessionRefCount
|
|
{
|
|
public:
|
|
LONG m_nCount;
|
|
BEGIN_COLUMN_MAP(CSessionRefCount)
|
|
COLUMN_ENTRY(1, m_nCount)
|
|
END_COLUMN_MAP()
|
|
};
|
|
|
|
// Used for creating new entries in the session
|
|
// references table.
|
|
class CSessionRefCreator
|
|
{
|
|
public:
|
|
TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
|
|
unsigned __int64 m_TimeoutMs;
|
|
HRESULT Assign(LPCTSTR szSessionID, unsigned __int64 timeout) throw()
|
|
{
|
|
if (!szSessionID)
|
|
return E_INVALIDARG;
|
|
if (_tcslen(szSessionID) < MAX_SESSION_KEY_LEN)
|
|
{
|
|
_tcscpy(m_SessionID, szSessionID);
|
|
m_TimeoutMs = timeout;
|
|
}
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
return S_OK;
|
|
}
|
|
BEGIN_PARAM_MAP(CSessionRefCreator)
|
|
SET_PARAM_TYPE(DBPARAMIO_INPUT)
|
|
COLUMN_ENTRY(1, m_SessionID)
|
|
COLUMN_ENTRY(2, m_TimeoutMs)
|
|
END_PARAM_MAP()
|
|
};
|
|
|
|
|
|
// CDBSession
|
|
// This session persistence class persists session variables to
|
|
// an OLEDB datasource. The following table gives a general description
|
|
// of the table schema for the tables this class uses.
|
|
//
|
|
// TableName: SessionVariables
|
|
// Column Name Type Description
|
|
// 1 SessionID char[MAX_SESSION_KEY_LEN] Session Key name
|
|
// 2 VariableName char[MAX_VARIABLE_NAME_LENGTH] Variable Name
|
|
// 3 VariableValue varbinary[MAX_VARIABLE_VALUE_LENGTH] Variable Value
|
|
|
|
//
|
|
// TableName: SessionReferences
|
|
// Column Name Type Description
|
|
// 1 SessionID char[MAX_SESSION_KEY_LEN] Session Key Name.
|
|
// 2 LastAccess datetime Date and time of last access to this session.
|
|
// 3 RefCount int Current references on this session.
|
|
// 4 TimeoutMS int Timeout value for the session in milliseconds
|
|
|
|
// Pointer to a function taking a DWORD_PTR and a wchar_t** and returning bool.
// NOTE(review): not referenced in the visible part of this header; the name
// suggests it supplies OLE DB provider/connection information -- confirm.
typedef bool (*PFN_GETPROVIDERINFO)(DWORD_PTR, wchar_t **);
|
|
|
|
template <class QueryClass=CDefaultQueryClass>
|
|
class CDBSession:
|
|
public ISession,
|
|
public CComObjectRootEx<CComGlobalsThreadModel>
|
|
|
|
{
|
|
typedef CCommand<CAccessor<CAllSessionDataSelector> > iterator_accessor;
|
|
public:
|
|
typedef QueryClass DBQUERYCLASS_TYPE;
|
|
BEGIN_COM_MAP(CDBSession)
|
|
COM_INTERFACE_ENTRY(ISession)
|
|
END_COM_MAP()
|
|
|
|
CDBSession() throw():
|
|
m_dwTimeout(ATL_SESSION_TIMEOUT)
|
|
{
|
|
m_szSessionName[0] = '\0';
|
|
}
|
|
|
|
// Nothing to release here; the session lock is dropped in FinalRelease.
~CDBSession() throw() {}
|
|
|
|
void FinalRelease()throw()
|
|
{
|
|
SessionUnlock();
|
|
}
|
|
|
|
// Stores Val under szName for this session. First an UPDATE of an existing
// row is tried. If that does not return S_OK or affects no rows, the
// variable is INSERTed instead.
// Returns E_INVALIDARG for a NULL name, E_OUTOFMEMORY if the name
// conversion throws, E_UNEXPECTED if the INSERT affects no rows, otherwise
// the HRESULT of the last step executed.
STDMETHOD(SetVariable)(LPCSTR szName, VARIANT Val) throw()
{
	if (szName == NULL)
		return E_INVALIDARG;

	// Get the data connection for this thread.
	CDataConnection dataconn;
	HRESULT hr = GetSessionConnection(&dataconn, m_spServiceProvider);
	if (hr != S_OK)
		return hr;

	// Touch the session's last access time.
	hr = Access();
	if (hr != S_OK)
		return hr;

	CCommand<CAccessor<CSessionDataUpdator> > command;
	_ATLTRY
	{
		CA2CT tszName(szName);
		hr = command.Assign(m_szSessionName, tszName, &Val);
	}
	_ATLCATCHALL()
	{
		hr = E_OUTOFMEMORY;
	}
	if (hr != S_OK)
		return hr;

	LONG nRows = 0;
	hr = command.Open(dataconn,
					m_QueryObj.GetSessionVarUpdate(),
					NULL, &nRows, DBGUID_DEFAULT, false);

	bool bUpdated = (hr == S_OK) && (nRows > 0);
	if (!bUpdated)
	{
		// The variable is not there yet: insert it.
		hr = command.Open(dataconn, m_QueryObj.GetSessionVarInsert(), NULL, &nRows, DBGUID_DEFAULT, false);
		if (hr == S_OK && nRows <= 0)
			hr = E_UNEXPECTED;
	}

	return hr;
}
|
|
|
|
// Warning: For string data types, depending on the configuration of
|
|
// your database, strings might be returned with trailing white space.
|
|
STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw()
|
|
{
|
|
HRESULT hr = E_FAIL;
|
|
if (!szName)
|
|
return E_INVALIDARG;
|
|
if (pVal)
|
|
VariantClear(pVal);
|
|
else
|
|
return E_POINTER;
|
|
|
|
// Get the data connection for this thread
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// Update the last access time for this session
|
|
hr = Access();
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// Allocate a command a fill out it's input parameters.
|
|
CCommand<CAccessor<CSessionDataSelector> > command;
|
|
_ATLTRY
|
|
{
|
|
CA2CT name(szName);
|
|
hr = command.Assign(m_szSessionName, name, NULL);
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_OUTOFMEMORY;
|
|
}
|
|
|
|
if (hr == S_OK)
|
|
{
|
|
hr = command.Open(dataconn, m_QueryObj.GetSessionVarSelectVar());
|
|
if (SUCCEEDED(hr))
|
|
{
|
|
if ( S_OK == (hr = command.MoveFirst()))
|
|
{
|
|
CStreamOnByteArray stream(command.m_VariableValue);
|
|
CComVariant vOut;
|
|
hr = vOut.ReadFromStream(static_cast<IStream*>(&stream));
|
|
if (hr == S_OK)
|
|
hr = vOut.Detach(pVal);
|
|
}
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
// Deletes the variable szName from this session. Returns E_INVALIDARG for
// a NULL name, E_OUTOFMEMORY if the name conversion throws, and E_UNEXPECTED
// if the DELETE succeeded but removed no rows (no such variable).
STDMETHOD(RemoveVariable)(LPCSTR szName) throw()
{
	if (szName == NULL)
		return E_INVALIDARG;

	// Get the data connection for this thread.
	CDataConnection dataconn;
	HRESULT hr = GetSessionConnection(&dataconn, m_spServiceProvider);
	if (hr != S_OK)
		return hr;

	// Touch the session's last access time.
	hr = Access();
	if (hr != S_OK)
		return hr;

	CCommand<CAccessor<CSessionDataDeletor> > command;
	_ATLTRY
	{
		CA2CT tszName(szName);
		hr = command.Assign(m_szSessionName, tszName);
	}
	_ATLCATCHALL()
	{
		hr = E_OUTOFMEMORY;
	}
	if (hr != S_OK)
		return hr;

	long nRows = 0;
	hr = command.Open(dataconn, m_QueryObj.GetSessionVarDeleteVar(),
					NULL, &nRows, DBGUID_DEFAULT, false);
	if (hr == S_OK && nRows <= 0)
		hr = E_UNEXPECTED;
	return hr;
}
|
|
|
|
// Gives the count of rows in the table for this session ID.
|
|
STDMETHOD(GetCount)(long *pnCount) throw()
|
|
{
|
|
HRESULT hr = S_OK;
|
|
if (pnCount)
|
|
*pnCount = 0;
|
|
else
|
|
return E_POINTER;
|
|
|
|
// Get the database connection for this thread.
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
hr = Access();
|
|
if (hr != S_OK)
|
|
return hr;
|
|
CCommand<CAccessor<CCountAccessor> > command;
|
|
|
|
hr = command.Assign(m_szSessionName);
|
|
if (hr == S_OK)
|
|
{
|
|
hr = command.Open(dataconn, m_QueryObj.GetSessionVarCount());
|
|
if (hr == S_OK)
|
|
{
|
|
if (S_OK == (hr = command.MoveFirst()))
|
|
{
|
|
*pnCount = command.m_nCount;
|
|
hr = S_OK;
|
|
}
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
// Deletes every variable stored for this session and returns the HRESULT of
// the DELETE (or of the connection / parameter setup if that fails).
STDMETHOD(RemoveAllVariables)() throw()
{
	// Get the data connection for this thread.
	CDataConnection dataconn;
	HRESULT hr = GetSessionConnection(&dataconn, m_spServiceProvider);
	if (hr == S_OK)
	{
		CCommand<CAccessor<CSessionDataDeleteAll> > command;
		hr = command.Assign(m_szSessionName);
		if (hr == S_OK)
			hr = command.Open(dataconn, m_QueryObj.GetSessionVarDeleteAllVars(), NULL, NULL, DBGUID_DEFAULT, false);
	}
	return hr;
}
|
|
|
|
// Iteration of variables works by taking a snapshot
|
|
// of the sessions at the point in time BeginVariableEnum
|
|
// is called, and then keeping an index variable that you use to
|
|
// move through the snapshot rowset. It is important to know
|
|
// that the handle returned in phEnum is not thread safe. It
|
|
// should only be used by the calling thread.
|
|
STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnum, POSITION *pPOS) throw()
|
|
{
|
|
HRESULT hr = E_FAIL;
|
|
if (!pPOS)
|
|
return E_POINTER;
|
|
|
|
if (phEnum)
|
|
*phEnum = NULL;
|
|
else
|
|
return E_POINTER;
|
|
|
|
// Get the data connection for this thread.
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// Update the last access time for this session.
|
|
hr = Access();
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// Allocate a new iterator accessor and initialize it's input parameters.
|
|
iterator_accessor *pIteratorAccessor = NULL;
|
|
ATLTRYALLOC(pIteratorAccessor = new iterator_accessor);
|
|
if (!pIteratorAccessor)
|
|
return E_OUTOFMEMORY;
|
|
|
|
hr = pIteratorAccessor->Assign(m_szSessionName, NULL, NULL);
|
|
if (hr == S_OK)
|
|
{
|
|
// execute the command and move to the first row of the recordset.
|
|
hr = pIteratorAccessor->Open(dataconn,
|
|
m_QueryObj.GetSessionVarSelectAllVars());
|
|
if (hr == S_OK)
|
|
{
|
|
hr = pIteratorAccessor->MoveFirst();
|
|
if (hr == S_OK)
|
|
{
|
|
*pPOS = (POSITION) INVALID_DB_SESSION_POS + 1;
|
|
*phEnum = reinterpret_cast<HSESSIONENUM>(pIteratorAccessor);
|
|
}
|
|
}
|
|
|
|
if (hr != S_OK)
|
|
{
|
|
*pPOS = INVALID_DB_SESSION_POS;
|
|
*phEnum = NULL;
|
|
delete pIteratorAccessor;
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
// The values for hEnum and pPos must have been initialized in a previous
|
|
// call to BeginVariableEnum. On success, the out variant will hold the next
|
|
// variable
|
|
STDMETHOD(GetNextVariable)(HSESSIONENUM hEnum, POSITION *pPOS, LPSTR szName, DWORD dwLen, VARIANT *pVal) throw()
|
|
{
|
|
if (!pPOS)
|
|
return E_INVALIDARG;
|
|
|
|
if (pVal)
|
|
VariantInit(pVal);
|
|
else
|
|
return E_POINTER;
|
|
|
|
if (!hEnum)
|
|
return E_UNEXPECTED;
|
|
|
|
if (*pPOS <= INVALID_DB_SESSION_POS)
|
|
return E_UNEXPECTED;
|
|
|
|
iterator_accessor *pIteratorAccessor = reinterpret_cast<iterator_accessor*>(hEnum);
|
|
|
|
// update the last access time.
|
|
HRESULT hr = Access();
|
|
|
|
POSITION posCurrent = *pPOS;
|
|
|
|
if (szName)
|
|
{
|
|
// caller wants entry name
|
|
size_t nNameLenChars = _tcslen(pIteratorAccessor->m_VariableName);
|
|
if (dwLen > nNameLenChars)
|
|
{
|
|
_ATLTRY
|
|
{
|
|
CT2CA szVarName(pIteratorAccessor->m_VariableName);
|
|
strcpy(szName, szVarName);
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_OUTOFMEMORY;
|
|
}
|
|
}
|
|
else
|
|
hr = E_OUTOFMEMORY; // buffer not big enough
|
|
}
|
|
|
|
if (hr == S_OK)
|
|
{
|
|
CStreamOnByteArray stream(pIteratorAccessor->m_VariableValue);
|
|
CComVariant vOut;
|
|
hr = vOut.ReadFromStream(static_cast<IStream*>(&stream));
|
|
if (hr == S_OK)
|
|
vOut.Detach(pVal);
|
|
else
|
|
return hr;
|
|
}
|
|
else
|
|
return hr;
|
|
|
|
hr = pIteratorAccessor->MoveNext();
|
|
*pPOS = ++posCurrent;
|
|
|
|
if (hr == DB_S_ENDOFROWSET)
|
|
{
|
|
// We're done iterating, reset everything
|
|
*pPOS = INVALID_DB_SESSION_POS;
|
|
hr = S_OK;
|
|
}
|
|
|
|
if (hr != S_OK)
|
|
{
|
|
VariantClear(pVal);
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
// CloseEnum frees up any resources allocated by the iterator
|
|
STDMETHOD(CloseEnum)(HSESSIONENUM hEnum) throw()
|
|
{
|
|
iterator_accessor *pIteratorAccessor = reinterpret_cast<iterator_accessor*>(hEnum);
|
|
if (!pIteratorAccessor)
|
|
return E_INVALIDARG;
|
|
pIteratorAccessor->Close();
|
|
delete pIteratorAccessor;
|
|
return S_OK;
|
|
}
|
|
|
|
//
|
|
// Returns S_FALSE if it's not expired
|
|
// S_OK if it is expired and an error HRESULT
|
|
// if an error occurred.
|
|
STDMETHOD(IsExpired)() throw()
|
|
{
|
|
HRESULT hrRet = S_FALSE;
|
|
HRESULT hr = E_UNEXPECTED;
|
|
|
|
// Get the data connection for this thread.
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
CCommand<CAccessor<CSessionRefIsExpired> > command;
|
|
hr = command.Assign(m_szSessionName);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
hr = command.Open(dataconn, m_QueryObj.GetSessionRefIsExpired(),
|
|
NULL, NULL, DBGUID_DEFAULT, true);
|
|
if (hr == S_OK)
|
|
{
|
|
if (S_OK == command.MoveFirst())
|
|
{
|
|
if (!_tcscmp(command.m_SessionIDOut, m_szSessionName))
|
|
hrRet = S_OK;
|
|
}
|
|
}
|
|
|
|
if (hr == S_OK)
|
|
return hrRet;
|
|
return hr;
|
|
}
|
|
|
|
// Writes a new timeout for this session to the session references table and
// returns the HRESULT of the UPDATE (or of the setup step that failed).
// Only the database row changes; m_dwTimeout is not modified here.
STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw()
{
	// Get the data connection for this thread.
	CDataConnection dataconn;
	HRESULT hr = GetSessionConnection(&dataconn, m_spServiceProvider);
	if (hr == S_OK)
	{
		CCommand<CAccessor<CSessionRefUpdateTimeout> > command;
		hr = command.Assign(m_szSessionName, dwNewTimeout);
		if (hr == S_OK)
			hr = command.Open(dataconn, m_QueryObj.GetSessionRefUpdateTimeout(),
							NULL, NULL, DBGUID_DEFAULT, false);
	}
	return hr;
}
|
|
|
|
// SessionLock increments the session reference count for this session.
|
|
// If there is not a session by this name in the session references table,
|
|
// a new session entry is created in the the table.
|
|
HRESULT SessionLock() throw()
|
|
{
|
|
HRESULT hr = E_UNEXPECTED;
|
|
if (!m_szSessionName || m_szSessionName[0]==0)
|
|
return hr; // no session to lock.
|
|
|
|
// retrieve the data connection for this thread
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// first try to update a session with this name
|
|
LONG nRows = 0;
|
|
CCommand<CAccessor<CSessionRefUpdator> > updator;
|
|
if (S_OK == updator.Assign(m_szSessionName))
|
|
{
|
|
if (S_OK != (hr = updator.Open(dataconn, m_QueryObj.GetSessionRefAddRef(),
|
|
NULL, &nRows, DBGUID_DEFAULT, false)) ||
|
|
nRows == 0)
|
|
{
|
|
// No session to update. Use the creator accessor
|
|
// to create a new session reference.
|
|
CCommand<CAccessor<CSessionRefCreator> > creator;
|
|
hr = creator.Assign(m_szSessionName, m_dwTimeout);
|
|
if (hr == S_OK)
|
|
hr = creator.Open(dataconn, m_QueryObj.GetSessionRefCreate(),
|
|
NULL, &nRows, DBGUID_DEFAULT, false);
|
|
}
|
|
}
|
|
|
|
// We should have been able to create or update a session.
|
|
ATLASSERT(nRows > 0);
|
|
if (hr == S_OK && nRows <= 0)
|
|
hr = E_UNEXPECTED;
|
|
|
|
return hr;
|
|
}
|
|
|
|
// SessionUnlock decrements the session RefCount for this session.
|
|
// Sessions cannot be removed from the database unless the session
|
|
// refcount is 0
|
|
HRESULT SessionUnlock() throw()
|
|
{
|
|
HRESULT hr = E_UNEXPECTED;
|
|
if (!m_szSessionName ||
|
|
m_szSessionName[0]==0)
|
|
return hr;
|
|
|
|
// get the data connection for this thread
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// The session must exist at this point in order to unlock it
|
|
// so we can just use the session updator here.
|
|
LONG nRows = 0;
|
|
CCommand<CAccessor<CSessionRefUpdator> > updator;
|
|
hr = updator.Assign(m_szSessionName);
|
|
if (hr == S_OK)
|
|
{
|
|
hr = updator.Open( dataconn,
|
|
m_QueryObj.GetSessionRefRemoveRef(),
|
|
NULL,
|
|
&nRows,
|
|
DBGUID_DEFAULT,
|
|
false);
|
|
}
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// delete the session from the database if
|
|
// nobody else is using it and it's expired.
|
|
hr = FreeSession();
|
|
return hr;
|
|
}
|
|
|
|
// Access updates the last access time for the session. The access
|
|
// time for sessions is updated using the SQL GETDATE function on the
|
|
// database server so that all clients will be using the same clock
|
|
// to compare access times against.
|
|
HRESULT Access() throw()
|
|
{
|
|
HRESULT hr = E_UNEXPECTED;
|
|
|
|
if (!m_szSessionName ||
|
|
m_szSessionName[0]==0)
|
|
return hr; // no session to access
|
|
|
|
// get the data connection for this thread
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// The session reference entry in the references table must
|
|
// be created prior to calling this function so we can just
|
|
// use an updator to update the current entry.
|
|
CCommand<CAccessor<CSessionRefUpdator> > updator;
|
|
|
|
LONG nRows = 0;
|
|
hr = updator.Assign(m_szSessionName);
|
|
if (hr == S_OK)
|
|
{
|
|
hr = updator.Open( dataconn,
|
|
m_QueryObj.GetSessionRefAccess(),
|
|
NULL,
|
|
&nRows,
|
|
DBGUID_DEFAULT,
|
|
false);
|
|
}
|
|
|
|
ATLASSERT(nRows > 0);
|
|
if (hr == S_OK && nRows <= 0)
|
|
hr = E_UNEXPECTED;
|
|
return hr;
|
|
}
|
|
|
|
// If the session is expired and it's reference is 0,
|
|
// it can be deleted. SessionUnlock calls this function to
|
|
// unlock the session and delete it after we release a session
|
|
// lock. Note that our SQL command will only delete the session
|
|
// if it is expired and it's refcount is <= 0
|
|
HRESULT FreeSession() throw()
|
|
{
|
|
HRESULT hr = E_UNEXPECTED;
|
|
if (!m_szSessionName ||
|
|
m_szSessionName[0]==0)
|
|
return hr;
|
|
|
|
// Get the data connection for this thread.
|
|
CDataConnection dataconn;
|
|
hr = GetSessionConnection(&dataconn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
CCommand<CAccessor<CSessionRefUpdator> > updator;
|
|
|
|
// The SQL for this command only deletes the
|
|
// session reference from the references table if it's access
|
|
// count is 0 and it has expired.
|
|
return updator.Open(dataconn,
|
|
m_QueryObj.GetSessionRefDelete(),
|
|
NULL,
|
|
NULL,
|
|
DBGUID_DEFAULT,
|
|
false);
|
|
}
|
|
|
|
// Initialize is called each time a new session is created.
|
|
HRESULT Initialize( LPCSTR szSessionName,
|
|
IServiceProvider *pServiceProvider,
|
|
DWORD_PTR dwCookie,
|
|
PFN_GETPROVIDERINFO pfnInfo) throw()
|
|
{
|
|
if (!szSessionName)
|
|
return E_INVALIDARG;
|
|
|
|
if (!pServiceProvider)
|
|
return E_INVALIDARG;
|
|
|
|
if (!pfnInfo)
|
|
return E_INVALIDARG;
|
|
|
|
m_pfnInfo = pfnInfo;
|
|
m_dwProvCookie = dwCookie;
|
|
m_spServiceProvider = pServiceProvider;
|
|
|
|
_ATLTRY
|
|
{
|
|
CA2CT tcsSessionName(szSessionName);
|
|
if (_tcslen(tcsSessionName) < MAX_SESSION_KEY_LEN)
|
|
_tcscpy(m_szSessionName, tcsSessionName);
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
return E_OUTOFMEMORY;
|
|
}
|
|
return SessionLock();
|
|
}
|
|
|
|
// GetSessionConnection obtains the data connection used for session storage.
// It gets the connection string from the provider-info callback and cookie
// saved by Initialize. It then passes that string, ATL_DBSESSION_ID and pProv
// to GetDataSource.
//
// Returns E_INVALIDARG if pProv is NULL. Returns E_UNEXPECTED if no callback
// or cookie has been stored. Returns E_FAIL if the callback fails or yields no
// string. Otherwise it returns the result of GetDataSource.
HRESULT GetSessionConnection(CDataConnection *pConn,
							 IServiceProvider *pProv) throw()
{
	if (pProv == NULL)
		return E_INVALIDARG;

	if (m_pfnInfo == NULL || m_dwProvCookie == 0)
		return E_UNEXPECTED;

	wchar_t *wszProv = NULL;
	if (!m_pfnInfo(m_dwProvCookie, &wszProv) || wszProv == NULL)
		return E_FAIL;

	return GetDataSource(pProv, ATL_DBSESSION_ID, wszProv, pConn);
}
|
|
|
|
|
|
protected:
	// Session ID as a TCHAR string. Set by Initialize, which rejects names of
	// MAX_SESSION_KEY_LEN characters or more.
	TCHAR m_szSessionName[MAX_SESSION_KEY_LEN];
	// Timeout value handed to the session-reference creator in SessionLock.
	unsigned __int64 m_dwTimeout;
	// Service provider passed to GetSessionConnection / GetDataSource.
	CComPtr<IServiceProvider> m_spServiceProvider;
	// Opaque cookie passed back to m_pfnInfo. CDBSessionServiceImplT supplies
	// its own address here.
	DWORD_PTR m_dwProvCookie;
	// Callback that yields the connection string for m_dwProvCookie.
	PFN_GETPROVIDERINFO m_pfnInfo;
	// Supplies the SQL text for every session-reference command.
	DBQUERYCLASS_TYPE m_QueryObj;
|
|
}; // CDBSession
|
|
|
|
|
|
template <class TDBSession=CDBSession<> >
|
|
class CDBSessionServiceImplT
|
|
{
|
|
wchar_t m_szConnectionString[MAX_CONNECTION_STRING_LEN];
|
|
CComPtr<IServiceProvider> m_spServiceProvider;
|
|
TDBSession::DBQUERYCLASS_TYPE m_QueryObj;
|
|
public:
|
|
typedef const wchar_t* SERVICEIMPL_INITPARAM_TYPE;
|
|
CDBSessionServiceImplT() throw()
|
|
{
|
|
m_dwTimeout = ATL_SESSION_TIMEOUT;
|
|
m_szConnectionString[0] = '\0';
|
|
}
|
|
|
|
static bool GetProviderInfo(DWORD_PTR dwProvCookie, wchar_t **ppszProvInfo) throw()
|
|
{
|
|
if (dwProvCookie &&
|
|
ppszProvInfo)
|
|
{
|
|
CDBSessionServiceImplT<TDBSession> *pSvc =
|
|
reinterpret_cast<CDBSessionServiceImplT<TDBSession>*>(dwProvCookie);
|
|
*ppszProvInfo = pSvc->m_szConnectionString;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
HRESULT GetSessionConnection(CDataConnection *pConn,
|
|
IServiceProvider *pProv) throw()
|
|
{
|
|
if (!pProv)
|
|
return E_INVALIDARG;
|
|
|
|
if(!m_szConnectionString[0])
|
|
return E_UNEXPECTED;
|
|
|
|
return GetDataSource(pProv,
|
|
ATL_DBSESSION_ID,
|
|
m_szConnectionString,
|
|
pConn);
|
|
}
|
|
|
|
HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE pData,
|
|
IServiceProvider *pProvider,
|
|
unsigned __int64 dwInitialTimeout) throw()
|
|
{
|
|
if (!pData || !pProvider)
|
|
return E_INVALIDARG;
|
|
|
|
if (wcslen(pData) < MAX_CONNECTION_STRING_LEN)
|
|
{
|
|
wcscpy(m_szConnectionString, pData);
|
|
}
|
|
else
|
|
return E_OUTOFMEMORY;
|
|
|
|
m_dwTimeout = dwInitialTimeout;
|
|
m_spServiceProvider = pProvider;
|
|
return S_OK;
|
|
}
|
|
|
|
HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
|
|
{
|
|
HRESULT hr = E_FAIL;
|
|
CComObject<TDBSession> *pNewSession = NULL;
|
|
|
|
if (!pdwSize)
|
|
return E_INVALIDARG;
|
|
|
|
if (ppSession)
|
|
*ppSession = NULL;
|
|
else
|
|
return E_POINTER;
|
|
|
|
if (szNewID)
|
|
*szNewID = NULL;
|
|
else
|
|
return E_INVALIDARG;
|
|
|
|
|
|
// Create new session
|
|
CComObject<TDBSession>::CreateInstance(&pNewSession);
|
|
if (pNewSession == NULL)
|
|
return E_OUTOFMEMORY;
|
|
|
|
// Create a session name and initialize the object
|
|
hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize);
|
|
if (hr == S_OK)
|
|
{
|
|
hr = pNewSession->Initialize(szNewID,
|
|
m_spServiceProvider,
|
|
reinterpret_cast<DWORD_PTR>(this),
|
|
GetProviderInfo);
|
|
if (hr == S_OK)
|
|
{
|
|
// we don't hold a reference to the object
|
|
hr = pNewSession->QueryInterface(ppSession);
|
|
}
|
|
}
|
|
|
|
if (hr != S_OK)
|
|
delete pNewSession;
|
|
return hr;
|
|
}
|
|
|
|
HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw()
|
|
{
|
|
HRESULT hr = E_FAIL;
|
|
if (!szID)
|
|
return E_INVALIDARG;
|
|
|
|
if (ppSession)
|
|
*ppSession = NULL;
|
|
else
|
|
return E_POINTER;
|
|
|
|
CComObject<TDBSession> *pNewSession = NULL;
|
|
|
|
// Check the DB to see if the session ID is a valid session
|
|
_ATLTRY
|
|
{
|
|
CA2CT session(szID);
|
|
hr = IsValidSession(session);
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_OUTOFMEMORY;
|
|
}
|
|
if (hr == S_OK)
|
|
{
|
|
// Create new session object to represent this session
|
|
CComObject<TDBSession>::CreateInstance(&pNewSession);
|
|
if (pNewSession == NULL)
|
|
return E_OUTOFMEMORY;
|
|
|
|
hr = pNewSession->Initialize(szID,
|
|
m_spServiceProvider,
|
|
reinterpret_cast<DWORD_PTR>(this),
|
|
GetProviderInfo);
|
|
if (hr == S_OK)
|
|
{
|
|
// we don't hold a reference to the object
|
|
hr = pNewSession->QueryInterface(ppSession);
|
|
}
|
|
}
|
|
|
|
if (hr != S_OK && pNewSession)
|
|
delete pNewSession;
|
|
return hr;
|
|
}
|
|
|
|
HRESULT CloseSession(LPCSTR szID) throw()
|
|
{
|
|
if (!szID)
|
|
return E_INVALIDARG;
|
|
|
|
CDataConnection conn;
|
|
HRESULT hr = GetSessionConnection(&conn,
|
|
m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// set up accessors
|
|
CCommand<CAccessor<CSessionRefUpdator> > updator;
|
|
CCommand<CAccessor<CSessionDataDeleteAll> > command;
|
|
_ATLTRY
|
|
{
|
|
CA2CT session(szID);
|
|
hr = updator.Assign(session);
|
|
if (hr == S_OK)
|
|
hr = command.Assign(session);
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_OUTOFMEMORY;
|
|
}
|
|
|
|
if (hr == S_OK)
|
|
{
|
|
// delete all session variables
|
|
hr = command.Open(conn,
|
|
m_QueryObj.GetSessionVarDeleteAllVars(),
|
|
NULL,
|
|
NULL,
|
|
DBGUID_DEFAULT,
|
|
false);
|
|
if (hr == S_OK)
|
|
{
|
|
// delete references in the session references table
|
|
hr = updator.Open(conn,
|
|
m_QueryObj.GetSessionRefDeleteFinal(),
|
|
NULL,
|
|
NULL,
|
|
DBGUID_DEFAULT,
|
|
false);
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw()
|
|
{
|
|
// Get the data connection for this thread
|
|
CDataConnection conn;
|
|
|
|
HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// all sessions get the same timeout
|
|
CCommand<CAccessor<CSetAllTimeouts> > command;
|
|
hr = command.Assign(nTimeout);
|
|
if (hr == S_OK)
|
|
{
|
|
hr = command.Open(conn, m_QueryObj.GetSessionReferencesSet(),
|
|
NULL,
|
|
NULL,
|
|
DBGUID_DEFAULT,
|
|
false);
|
|
if (hr == S_OK)
|
|
{
|
|
m_dwTimeout = nTimeout;
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
|
|
HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw()
|
|
{
|
|
if (pnTimeout)
|
|
*pnTimeout = m_dwTimeout;
|
|
else
|
|
return E_INVALIDARG;
|
|
|
|
return S_OK;
|
|
}
|
|
|
|
HRESULT GetSessionCount(DWORD *pnCount) throw()
|
|
{
|
|
if (pnCount)
|
|
*pnCount = 0;
|
|
else
|
|
return E_INVALIDARG;
|
|
|
|
CCommand<CAccessor<CSessionRefCount> > command;
|
|
CDataConnection conn;
|
|
HRESULT hr = GetSessionConnection(&conn,
|
|
m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
hr = command.Open(conn,
|
|
m_QueryObj.GetSessionRefGetCount());
|
|
if (hr == S_OK)
|
|
{
|
|
hr = command.MoveFirst();
|
|
if (hr == S_OK)
|
|
{
|
|
*pnCount = (DWORD)command.m_nCount;
|
|
}
|
|
}
|
|
|
|
return hr;
|
|
}
|
|
|
|
void ReleaseAllSessions() throw()
|
|
{
|
|
// nothing to do
|
|
}
|
|
|
|
void SweepSessions() throw()
|
|
{
|
|
// nothing to do
|
|
}
|
|
|
|
|
|
// Helpers
|
|
HRESULT IsValidSession(LPCTSTR szID) throw()
|
|
{
|
|
if (!szID)
|
|
return E_INVALIDARG;
|
|
// Look in the sessionreferences table to see if there is an entry
|
|
// for this session.
|
|
if (m_szConnectionString[0] == 0)
|
|
return E_UNEXPECTED;
|
|
|
|
CDataConnection conn;
|
|
HRESULT hr = GetSessionConnection(&conn,
|
|
m_spServiceProvider);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// Check the session references table to see if
|
|
// this is a valid session
|
|
CCommand<CAccessor<CSessionRefSelector> > selector;
|
|
hr = selector.Assign(szID);
|
|
if (hr != S_OK)
|
|
return hr;
|
|
|
|
// The SQL for this command only deletes the
|
|
// session reference from the references table if it's access
|
|
// count is 0 and it has expired.
|
|
hr = selector.Open(conn,
|
|
m_QueryObj.GetSessionRefSelect(),
|
|
NULL,
|
|
NULL,
|
|
DBGUID_DEFAULT,
|
|
true);
|
|
if (hr == S_OK)
|
|
return selector.MoveFirst();
|
|
return hr;
|
|
}
|
|
|
|
CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names
|
|
unsigned __int64 m_dwTimeout;
|
|
}; // CDBSessionServiceImplT
|
|
|
|
// Default database-backed session service: CDBSessionServiceImplT with its
// default session class, CDBSession<>.
typedef CDBSessionServiceImplT<> CDBSessionServiceImpl;
|
|
|
|
|
|
|
|
|
|
|
|
//////////////////////////////////////////////////////////////////
|
|
//
|
|
// In-memory persisted session
|
|
//
|
|
//////////////////////////////////////////////////////////////////
|
|
|
|
// In-memory persisted session service keeps a pointer
// to the session object around in memory. The pointer is
// contained in a CComPtr, which is stored in a CAtlMap, so
// we have to have a CElementTraits class for that.
// NOTE(review): CMemSessionServiceImpl::SessMapType passes
// CElementTraitsBase<SESSIONPTRTYPE> explicitly as its value traits, so that
// map does not use the CElementTraits specialization that follows.
typedef CComPtr<ISession> SESSIONPTRTYPE;
|
|
|
|
template<>
|
|
class CElementTraits<SESSIONPTRTYPE> :
|
|
public CElementTraitsBase<SESSIONPTRTYPE>
|
|
{
|
|
public:
|
|
static ULONG Hash( INARGTYPE obj ) throw()
|
|
{
|
|
return( (ULONG)(ULONG_PTR)obj.p);
|
|
}
|
|
|
|
static BOOL CompareElements( OUTARGTYPE element1, OUTARGTYPE element2 ) throw()
|
|
{
|
|
return element1.IsEqualObject(element2.p) ? TRUE : FALSE;
|
|
}
|
|
|
|
static int CompareElementsOrdered( INARGTYPE , INARGTYPE ) throw()
|
|
{
|
|
ATLASSERT(0); // NOT IMPLEMENTED
|
|
return 0;
|
|
}
|
|
};
|
|
|
|
|
|
// CMemSession
|
|
// This session persistance class persists session variables in memory.
|
|
// Note that this type of persistance should only be used on single server
|
|
// web sites.
|
|
class CMemSession :
|
|
public ISession,
|
|
public CComObjectRootEx<CComGlobalsThreadModel>
|
|
{
|
|
public:
|
|
BEGIN_COM_MAP(CMemSession)
|
|
COM_INTERFACE_ENTRY(ISession)
|
|
END_COM_MAP()
|
|
|
|
CMemSession() throw(...)
|
|
{
|
|
}
|
|
|
|
STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw()
|
|
{
|
|
if (!szName)
|
|
return E_INVALIDARG;
|
|
|
|
if (pVal)
|
|
VariantInit(pVal);
|
|
else
|
|
return E_POINTER;
|
|
|
|
HRESULT hr = Access();
|
|
if (hr == S_OK)
|
|
{
|
|
CSLockType lock(m_cs, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
_ATLTRY
|
|
{
|
|
CComVariant val;
|
|
if (m_Variables.Lookup(szName, val))
|
|
{
|
|
hr = VariantCopy(pVal, &val);
|
|
}
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_UNEXPECTED;
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
STDMETHOD(SetVariable)(LPCSTR szName, VARIANT vNewVal) throw()
|
|
{
|
|
if (!szName)
|
|
return E_INVALIDARG;
|
|
|
|
HRESULT hr = Access();
|
|
if (hr == S_OK)
|
|
{
|
|
CSLockType lock(m_cs, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
_ATLTRY
|
|
{
|
|
hr = m_Variables.SetAt(szName, vNewVal) ? S_OK : E_FAIL;
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_UNEXPECTED;
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
STDMETHOD(RemoveVariable)(LPCSTR szName) throw()
|
|
{
|
|
if (!szName)
|
|
return E_INVALIDARG;
|
|
|
|
HRESULT hr = Access();
|
|
if (hr == S_OK)
|
|
{
|
|
CSLockType lock(m_cs, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
_ATLTRY
|
|
{
|
|
hr = m_Variables.RemoveKey(szName) ? S_OK : E_FAIL;
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_UNEXPECTED;
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
STDMETHOD(GetCount)(long *pnCount) throw()
|
|
{
|
|
if (pnCount)
|
|
return *pnCount = 0;
|
|
else
|
|
return E_POINTER;
|
|
|
|
HRESULT hr = Access();
|
|
if (hr == S_OK)
|
|
{
|
|
CSLockType lock(m_cs, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
*pnCount = (long) m_Variables.GetCount();
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
STDMETHOD(RemoveAllVariables)() throw()
|
|
{
|
|
HRESULT hr = Access();
|
|
if (hr == S_OK)
|
|
{
|
|
CSLockType lock(m_cs, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
m_Variables.RemoveAll();
|
|
}
|
|
|
|
return hr;
|
|
}
|
|
|
|
STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnumHandle, POSITION *pPOS) throw()
|
|
{
|
|
if (phEnumHandle)
|
|
*phEnumHandle = NULL;
|
|
else
|
|
return E_POINTER;
|
|
|
|
if (pPOS)
|
|
*pPOS = NULL;
|
|
else
|
|
return E_POINTER;
|
|
|
|
HRESULT hr = Access();
|
|
if (hr == S_OK)
|
|
{
|
|
CSLockType lock(m_cs, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
*pPOS = m_Variables.GetStartPosition();
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
STDMETHOD(GetNextVariable)(HSESSIONENUM /*hEnum*/,
|
|
POSITION *pPOS, LPSTR szName,
|
|
DWORD dwLen, VARIANT *pVal) throw()
|
|
{
|
|
if (!szName)
|
|
return E_INVALIDARG;
|
|
|
|
if (pVal)
|
|
VariantInit(pVal);
|
|
else
|
|
return E_POINTER;
|
|
|
|
if (!pPOS)
|
|
return E_POINTER;
|
|
|
|
CComVariant val;
|
|
POSITION pos = *pPOS;
|
|
HRESULT hr = Access();
|
|
if (hr == S_OK)
|
|
{
|
|
CSLockType lock(m_cs, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
|
|
_ATLTRY
|
|
{
|
|
CStringA strName = m_Variables.GetKeyAt(pos);
|
|
if (strName.GetLength())
|
|
{
|
|
if (dwLen > (DWORD)strName.GetLength())
|
|
strcpy(szName, strName);
|
|
else
|
|
hr = E_OUTOFMEMORY;
|
|
}
|
|
if (hr == S_OK)
|
|
{
|
|
val = m_Variables.GetNextValue(pos);
|
|
hr = VariantCopy(pVal, &val);
|
|
if (hr == S_OK)
|
|
*pPOS = pos;
|
|
}
|
|
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_UNEXPECTED;
|
|
}
|
|
}
|
|
|
|
return hr;
|
|
}
|
|
|
|
STDMETHOD(CloseEnum)(HSESSIONENUM /*hEnumHandle*/) throw()
|
|
{
|
|
return S_OK;
|
|
}
|
|
|
|
STDMETHOD(IsExpired)() throw()
|
|
{
|
|
CTime tmNow = CTime::GetCurrentTime();
|
|
CTimeSpan span = tmNow-m_tLastAccess;
|
|
if ((unsigned __int64)((span.GetTotalSeconds()*1000)) > m_dwTimeout)
|
|
return S_OK;
|
|
return S_FALSE;
|
|
}
|
|
|
|
HRESULT Access() throw()
|
|
{
|
|
// We lock here to protect against multiple threads
|
|
// updating the same member concurrently.
|
|
CSLockType lock(m_cs, false);
|
|
HRESULT hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
m_tLastAccess = CTime::GetCurrentTime();
|
|
return S_OK;
|
|
}
|
|
|
|
STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw()
|
|
{
|
|
// We lock here to protect against multiple threads
|
|
// updating the same member concurrently
|
|
CSLockType lock(m_cs, false);
|
|
HRESULT hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
m_dwTimeout = dwNewTimeout;
|
|
return S_OK;
|
|
}
|
|
|
|
HRESULT SessionLock() throw()
|
|
{
|
|
Access();
|
|
return S_OK;
|
|
}
|
|
|
|
HRESULT SessionUnlock() throw()
|
|
{
|
|
return S_OK;
|
|
}
|
|
|
|
protected:
|
|
typedef CAtlMap<CStringA,
|
|
CComVariant,
|
|
CStringElementTraits<CStringA> > VarMapType;
|
|
unsigned __int64 m_dwTimeout;
|
|
CTime m_tLastAccess;
|
|
VarMapType m_Variables;
|
|
CComAutoCriticalSection m_cs;
|
|
typedef CComCritSecLock<CComAutoCriticalSection> CSLockType;
|
|
}; // CMemSession
|
|
|
|
|
|
//
|
|
// CMemSessionServiceImpl
|
|
// Implements the service part of in-memory persisted session services.
|
|
//
|
|
class CMemSessionServiceImpl
|
|
{
|
|
public:
|
|
typedef void* SERVICEIMPL_INITPARAM_TYPE;
|
|
CMemSessionServiceImpl() throw()
|
|
{
|
|
m_dwTimeout = ATL_SESSION_TIMEOUT;
|
|
}
|
|
|
|
HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
|
|
{
|
|
HRESULT hr = E_FAIL;
|
|
CComObject<CMemSession> *pNewSession = NULL;
|
|
|
|
if (!szNewID)
|
|
return E_INVALIDARG;
|
|
|
|
if (!pdwSize)
|
|
return E_POINTER;
|
|
|
|
if (ppSession)
|
|
*ppSession = NULL;
|
|
else
|
|
return E_POINTER;
|
|
|
|
_ATLTRY
|
|
{
|
|
// Create new session
|
|
CComObject<CMemSession>::CreateInstance(&pNewSession);
|
|
if (pNewSession == NULL)
|
|
return E_OUTOFMEMORY;
|
|
|
|
// Initialize and add to list of CSessionData
|
|
hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize);
|
|
|
|
if (SUCCEEDED(hr))
|
|
{
|
|
CComPtr<ISession> spSession;
|
|
hr = pNewSession->QueryInterface(&spSession);
|
|
if (SUCCEEDED(hr))
|
|
{
|
|
pNewSession->SetTimeout(m_dwTimeout);
|
|
pNewSession->Access();
|
|
CSLockType lock(m_CritSec, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
m_Sessions.SetAt(szNewID, spSession);
|
|
*ppSession = spSession.Detach();
|
|
}
|
|
}
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_UNEXPECTED;
|
|
}
|
|
|
|
return hr;
|
|
|
|
}
|
|
|
|
HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw()
|
|
{
|
|
HRESULT hr = E_FAIL;
|
|
SessMapType::CPair *pPair = NULL;
|
|
|
|
if (ppSession)
|
|
*ppSession = NULL;
|
|
else
|
|
return E_POINTER;
|
|
|
|
if (!szID)
|
|
return E_INVALIDARG;
|
|
|
|
CSLockType lock(m_CritSec, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
_ATLTRY
|
|
{
|
|
pPair = m_Sessions.Lookup(szID);
|
|
if (pPair) // the session exists and is in our local map of sessions
|
|
{
|
|
hr = pPair->m_value.QueryInterface(ppSession);
|
|
}
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
return E_UNEXPECTED;
|
|
}
|
|
|
|
return hr;
|
|
}
|
|
|
|
HRESULT CloseSession(LPCSTR szID) throw()
|
|
{
|
|
if (!szID)
|
|
return E_INVALIDARG;
|
|
|
|
HRESULT hr = E_FAIL;
|
|
CSLockType lock(m_CritSec, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
_ATLTRY
|
|
{
|
|
hr = m_Sessions.RemoveKey(szID) ? S_OK : E_FAIL;
|
|
}
|
|
_ATLCATCHALL()
|
|
{
|
|
hr = E_UNEXPECTED;
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
void SweepSessions() throw()
|
|
{
|
|
POSITION posRemove = NULL;
|
|
const SessMapType::CPair *pPair = NULL;
|
|
POSITION pos = NULL;
|
|
|
|
CSLockType lock(m_CritSec, false);
|
|
if (FAILED(lock.Lock()))
|
|
return;
|
|
pos = m_Sessions.GetStartPosition();
|
|
while (pos)
|
|
{
|
|
posRemove = pos;
|
|
pPair = m_Sessions.GetNext(pos);
|
|
if (pPair)
|
|
{
|
|
|
|
if (pPair->m_value.p &&
|
|
S_OK == pPair->m_value->IsExpired())
|
|
{
|
|
// remove our reference on the session
|
|
m_Sessions.RemoveAtPos(posRemove);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw()
|
|
{
|
|
HRESULT hr = S_OK;
|
|
CComPtr<ISession> spSession;
|
|
m_dwTimeout = nTimeout;
|
|
POSITION pos = m_Sessions.GetStartPosition();
|
|
|
|
CSLockType lock(m_CritSec, false);
|
|
hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
|
|
while (pos)
|
|
{
|
|
SessMapType::CPair *pPair = const_cast<SessMapType::CPair*>(m_Sessions.GetNext(pos));
|
|
if (pPair)
|
|
{
|
|
spSession = pPair->m_value;
|
|
if (spSession)
|
|
{
|
|
// if we fail on any of the sets we will return the
|
|
// error code immediately
|
|
hr = spSession->SetTimeout(nTimeout);
|
|
spSession.Release();
|
|
if (hr != S_OK)
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
return hr;
|
|
}
|
|
|
|
HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw()
|
|
{
|
|
if (pnTimeout)
|
|
*pnTimeout = m_dwTimeout;
|
|
else
|
|
return E_POINTER;
|
|
|
|
return S_OK;
|
|
}
|
|
|
|
HRESULT GetSessionCount(DWORD *pnCount) throw()
|
|
{
|
|
if (pnCount)
|
|
*pnCount = 0;
|
|
else
|
|
return E_POINTER;
|
|
|
|
CSLockType lock(m_CritSec, false);
|
|
HRESULT hr = lock.Lock();
|
|
if (FAILED(hr))
|
|
return hr;
|
|
*pnCount = (DWORD)m_Sessions.GetCount();
|
|
|
|
return S_OK;
|
|
}
|
|
|
|
void ReleaseAllSessions() throw()
|
|
{
|
|
CSLockType lock(m_CritSec, false);
|
|
if (FAILED(lock.Lock()))
|
|
return;
|
|
m_Sessions.RemoveAll();
|
|
}
|
|
|
|
HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE,
|
|
IServiceProvider*,
|
|
unsigned __int64 dwNewTimeout) throw()
|
|
{
|
|
m_dwTimeout = dwNewTimeout;
|
|
return m_CritSec.Init();
|
|
}
|
|
|
|
typedef CAtlMap<CStringA,
|
|
SESSIONPTRTYPE,
|
|
CStringElementTraits<CStringA>,
|
|
CElementTraitsBase<SESSIONPTRTYPE> > SessMapType;
|
|
|
|
SessMapType m_Sessions; // map for holding sessions in memory
|
|
CComCriticalSection m_CritSec; // for synchronizing access to map
|
|
typedef CComCritSecLock<CComCriticalSection> CSLockType;
|
|
CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names
|
|
unsigned __int64 m_dwTimeout;
|
|
}; // CMemSessionServiceImpl
|
|
|
|
|
|
|
|
//
|
|
// CSessionStateService
|
|
// This class implements the session state service which can be
|
|
// exposed to request handlers.
|
|
//
|
|
// Template Parameters:
|
|
// CMonitorClass: Provides periodic sweeping services for the session service class.
|
|
// TServiceImplClass: The class that actually implements the methods of the
|
|
// ISessionStateService and ISessionStateControl interfaces.
|
|
template <class CMonitorClass, class TServiceImplClass >
|
|
class CSessionStateService :
|
|
public ISessionStateService,
|
|
public ISessionStateControl,
|
|
public IWorkerThreadClient,
|
|
public CComObjectRootEx<CComGlobalsThreadModel>
|
|
{
|
|
protected:
|
|
CMonitorClass m_Monitor;
|
|
HANDLE m_hTimer;
|
|
CComPtr<IServiceProvider> m_spServiceProvider;
|
|
TServiceImplClass m_SessionServiceImpl;
|
|
public:
|
|
// Construction/Initialization
|
|
CSessionStateService() throw() :
|
|
m_hTimer(NULL)
|
|
{
|
|
|
|
}
|
|
~CSessionStateService() throw()
|
|
{
|
|
ATLASSERT(m_hTimer == NULL);
|
|
}
|
|
BEGIN_COM_MAP(CSessionStateService)
|
|
COM_INTERFACE_ENTRY(ISessionStateService)
|
|
COM_INTERFACE_ENTRY(ISessionStateControl)
|
|
END_COM_MAP()
|
|
|
|
// ISessionStateServie methods
|
|
STDMETHOD(CreateNewSession)(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
|
|
{
|
|
return m_SessionServiceImpl.CreateNewSession(szNewID, pdwSize, ppSession);
|
|
}
|
|
|
|
STDMETHOD(GetSession)(LPCSTR szID, ISession **ppSession) throw()
|
|
{
|
|
return m_SessionServiceImpl.GetSession(szID, ppSession);
|
|
}
|
|
|
|
STDMETHOD(CloseSession)(LPCSTR szSessionID) throw()
|
|
{
|
|
return m_SessionServiceImpl.CloseSession(szSessionID);
|
|
}
|
|
|
|
STDMETHOD(SetSessionTimeout)(unsigned __int64 nTimeout) throw()
|
|
{
|
|
return m_SessionServiceImpl.SetSessionTimeout(nTimeout);
|
|
}
|
|
|
|
STDMETHOD(GetSessionTimeout)(unsigned __int64 *pnTimeout) throw()
|
|
{
|
|
return m_SessionServiceImpl.GetSessionTimeout(pnTimeout);
|
|
}
|
|
|
|
STDMETHOD(GetSessionCount)(DWORD *pnSessionCount) throw()
|
|
{
|
|
return m_SessionServiceImpl.GetSessionCount(pnSessionCount);
|
|
}
|
|
|
|
void SweepSessions() throw()
|
|
{
|
|
m_SessionServiceImpl.SweepSessions();
|
|
}
|
|
|
|
void ReleaseAllSessions() throw()
|
|
{
|
|
m_SessionServiceImpl.ReleaseAllSessions();
|
|
}
|
|
|
|
HRESULT Initialize(
|
|
IServiceProvider *pServiceProvider = NULL,
|
|
unsigned __int64 dwTimeout = ATL_SESSION_TIMEOUT,
|
|
TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw()
|
|
{
|
|
HRESULT hr = S_OK;
|
|
if (pServiceProvider)
|
|
m_spServiceProvider = pServiceProvider;
|
|
|
|
hr = m_SessionServiceImpl.Initialize(pInitData, pServiceProvider, dwTimeout);
|
|
|
|
return hr;
|
|
}
|
|
|
|
template <class ThreadTraits>
|
|
HRESULT Initialize(
|
|
CWorkerThread<ThreadTraits> *pWorker,
|
|
IServiceProvider *pServiceProvider = NULL,
|
|
unsigned __int64 dwTimeout = ATL_SESSION_TIMEOUT,
|
|
TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw()
|
|
{
|
|
if (!pWorker)
|
|
return E_INVALIDARG;
|
|
|
|
HRESULT hr = Initialize(pServiceProvider, dwTimeout, pInitData);
|
|
if (hr == S_OK)
|
|
{
|
|
hr = m_Monitor.Initialize(pWorker);
|
|
if (hr == S_OK)
|
|
{
|
|
//sweep every 500ms
|
|
hr = m_Monitor.AddTimer(ATL_SESSION_SWEEPER_TIMEOUT, this, 0, &m_hTimer);
|
|
}
|
|
}
|
|
return hr;
|
|
}
|
|
|
|
HRESULT Execute(DWORD_PTR /*dwParam*/, HANDLE /*hObject*/) throw()
|
|
{
|
|
SweepSessions();
|
|
return S_OK;
|
|
}
|
|
|
|
HRESULT CloseHandle(HANDLE hHandle) throw()
|
|
{
|
|
::CloseHandle(hHandle);
|
|
m_hTimer = NULL;
|
|
return S_OK;
|
|
}
|
|
|
|
void Shutdown() throw()
|
|
{
|
|
if (m_hTimer)
|
|
{
|
|
m_Monitor.RemoveHandle(m_hTimer);
|
|
m_hTimer = NULL;
|
|
}
|
|
ReleaseAllSessions();
|
|
}
|
|
}; // CSessionStateService
|
|
|
|
} // namespace ATL
|
|
|
|
#pragma warning(pop)
|
|
#endif // __ATLSESSION_H__
|