// This is a part of the Active Template Library.
// Copyright (C) 1996-2001 Microsoft Corporation
// All rights reserved.
//
// This source code is only intended as a supplement to the
// Active Template Library Reference and related
// electronic documentation provided with the library.
// See these sources for detailed information regarding the
// Active Template Library product.

#ifndef __ATLSESSION_H__
#define __ATLSESSION_H__

#pragma once
#pragma warning(push)
#pragma warning(disable: 4702) // unreachable code

#include <atldbcli.h>
#include <atlcom.h>
#include <atlstr.h>
#include <stdio.h>
#include <atlcoll.h>
#include <atltime.h>
#include <atlcrypt.h>
#include <atlenc.h>
#include <atlutil.h>
#include <atlcache.h>
#include <atlspriv.h>
#include <atlsiface.h>

#ifndef SESSION_KEY_LENGTH 
	#define SESSION_KEY_LENGTH 37
#endif

#ifndef MAX_SESSION_KEY_LEN 
	#define MAX_SESSION_KEY_LEN 128
#endif

#ifndef MAX_VARIABLE_NAME_LENGTH 
	#define MAX_VARIABLE_NAME_LENGTH 50
#endif

#ifndef MAX_VARIABLE_VALUE_LENGTH 
	#define MAX_VARIABLE_VALUE_LENGTH 128
#endif

#ifndef DEFAULT_SQL_LEN
	#define DEFAULT_SQL_LEN 1024
#endif

#ifndef MAX_CONNECTION_STRING_LEN
	#define MAX_CONNECTION_STRING_LEN 2048
#endif

#ifndef SESSION_COOKIE_NAME
	#define SESSION_COOKIE_NAME "SESSIONID"
#endif

#ifndef ATL_SESSION_TIMEOUT
	#define ATL_SESSION_TIMEOUT 600000 //10 min
#endif

#ifndef ATL_SESSION_SWEEPER_TIMEOUT
	#define ATL_SESSION_SWEEPER_TIMEOUT 1000 // 1sec
#endif

#define INVALID_DB_SESSION_POS 0x0
#define ATL_DBSESSION_ID _T("__ATL_SESSION_DB_CONNECTION")


namespace ATL {

// CSessionNameGenerator
// This is a helper class that generates random data for session key
// names. This class tries to use the CryptoApi to generate random
// bytes for the session key name. If the CryptoApi isn't available
// then the CRT rand() is used to generate the random bytes. This
// class's GetNewSessionName member function is used to actually
// generate the session name.
class CSessionNameGenerator :
	public CCryptProv
{
public:
	// true when the crypto provider context could not be initialized;
	// GenerateRandomName then falls back to the CRT rand().
	bool m_bCryptNotAvailable;
	// Smallest output buffer (in bytes) GetNewSessionName will accept.
	enum {MIN_SESSION_KEY_LEN=5};

	// Tries to acquire a verify-only crypto context and records whether
	// that succeeded.
	CSessionNameGenerator() throw() :
		m_bCryptNotAvailable(false)
	{
		// Note that the crypto api is being
		// initialized with no private key
		// information
		HRESULT hr = InitVerifyContext();
		m_bCryptNotAvailable = FAILED(hr) ? true : false;
	}

	// Creates a new random session name and base64 encodes it into szNewID.
	//
	// szNewID  - caller supplied output buffer.
	// pdwSize  - in: size of szNewID in bytes, must be in the range
	//            [MIN_SESSION_KEY_LEN, MAX_SESSION_KEY_LEN] because the
	//            temporary key buffer is stack allocated.
	//            out: on success, the encoded length including the
	//            terminating NULL.
	//
	// Returns S_OK on success, E_POINTER / E_INVALIDARG for bad arguments,
	// E_FAIL if encoding fails, and E_OUTOFMEMORY (with *pdwSize set to the
	// base64 length that was attempted) when the random data could not be
	// generated.
	HRESULT GetNewSessionName(LPSTR szNewID, DWORD *pdwSize) throw()
	{
		HRESULT hr = E_FAIL;

		if (!pdwSize)
			return E_POINTER;

		if (*pdwSize < MIN_SESSION_KEY_LEN ||
			*pdwSize > MAX_SESSION_KEY_LEN)
			return E_INVALIDARG;

		if (!szNewID)
			return E_POINTER;
			
		BYTE key[MAX_SESSION_KEY_LEN] = {0x0};


		// calculate the number of bytes that will fit in the
		// buffer we've been passed
		DWORD dwDataSize = CalcMaxInputSize(*pdwSize);

		if (dwDataSize && *pdwSize >= (DWORD)(Base64EncodeGetRequiredLength(dwDataSize,
			ATL_BASE64_FLAG_NOCRLF)))
		{
			int dwKeySize = *pdwSize;
			hr = GenerateRandomName(key, dwDataSize);
			if (SUCCEEDED(hr))
			{
				if( Base64Encode(key,
								dwDataSize,
								szNewID,
                                &dwKeySize,
								ATL_BASE64_FLAG_NOCRLF) )
				{
					//null terminate
					szNewID[dwKeySize]=0;
					*pdwSize = dwKeySize+1;
				}
				else
					hr = E_FAIL;
			}
			else
			{
				*pdwSize = (DWORD)(Base64EncodeGetRequiredLength(dwDataSize,
					ATL_BASE64_FLAG_NOCRLF));
				return E_OUTOFMEMORY;
			}
		}
		return hr;
	}

	// Returns the largest number of input bytes whose base64 encoding
	// (including padding) plus a NULL terminator fits in a buffer of
	// nOutputSize bytes, or 0 if nOutputSize is below MIN_SESSION_KEY_LEN.
	DWORD CalcMaxInputSize(DWORD nOutputSize) throw()
	{
		if (nOutputSize < (DWORD)MIN_SESSION_KEY_LEN)
			return 0;
		// subtract one from the output size to make room
		// for the NULL terminator in the output then
		// calculate the biggest number of input bytes that
		// when base64 encoded will fit in a buffer of size
		// nOutputSize (including base64 padding)
		int nInputSize = ((nOutputSize-1)*3)/4;
		int factor = ((nInputSize*4)/3)%4;
		if (factor)
			nInputSize -= factor;
		return nInputSize;
	}


	// Fills pBuff with dwBuffSize random bytes.
	//
	// Uses the crypto provider when it is available; otherwise falls back
	// to the CRT rand(), reseeded on every call.
	// Returns E_POINTER for a NULL buffer, E_UNEXPECTED for a zero size,
	// the result of GenRandom on the crypto path, and S_OK on the rand() path.
	HRESULT GenerateRandomName(BYTE *pBuff, DWORD dwBuffSize) throw()
	{
		if (!pBuff)
			return E_POINTER;

		if (!dwBuffSize)
			return E_UNEXPECTED;

		if (!m_bCryptNotAvailable && GetHandle())
		{
			// Use the crypto api to generate random data.
			return GenRandom(dwBuffSize, pBuff);
		}

		// CryptoApi isn't available so we generate
		// random data using rand. We seed the random
		// number generator with a seed that is a combination
		// of bytes from an arbitrary number and the system
		// time which changes every millisecond so it will
		// be different for every call to this function.
		FILETIME ft;
		GetSystemTimeAsFileTime(&ft);
		static DWORD dwVal = 0x21;
		// Take the two counter values in separate statements: incrementing
		// dwVal twice inside a single expression is unsequenced and therefore
		// undefined behavior.
		const DWORD dwHighPart = dwVal++;
		const DWORD dwLowPart = dwVal++;
		DWORD dwSeed = (dwHighPart << 0x18) | (ft.dwLowDateTime & 0x00ffff00) | (dwLowPart & 0x000000ff);
		srand(dwSeed);
		BYTE *pCurr = pBuff;
		// fill buffer with random bytes
		for (DWORD i = 0; i < dwBuffSize; i++)
		{
			*pCurr = (BYTE) (rand() & 0x000000ff);
			pCurr++;
		}
		return S_OK;
	}
};


//
// CDefaultQueryClass
// returns Query strings for use in SQL queries used 
// by the database persisted session service.
class CDefaultQueryClass
{
public:
	// --- SessionReferences table ---------------------------------------

	// Deletes a session reference row, but only when it is unreferenced
	// (RefCount <= 0) and its timeout has elapsed. Parameter: SessionID.
	LPCTSTR GetSessionRefDelete() throw()
	{
		LPCTSTR szSql =
			_T("DELETE FROM SessionReferences WHERE SessionID=? ")
			_T("AND RefCount <= 0 AND DATEDIFF(millisecond,  LastAccess, getdate()) > TimeoutMs");
		return szSql;
	}

	// Selects the session ID back if (and only if) the session's timeout
	// has elapsed. Parameter: SessionID.
	LPCTSTR GetSessionRefIsExpired() throw()
	{
		LPCTSTR szSql =
			_T("SELECT SessionID FROM SessionReferences WHERE (SessionID=?) ")
			_T("AND (DATEDIFF(millisecond,  LastAccess, getdate()) > TimeoutMs)");
		return szSql;
	}

	// Unconditionally deletes a session reference row. Parameter: SessionID.
	LPCTSTR GetSessionRefDeleteFinal() throw()
	{
		LPCTSTR szSql = _T("DELETE FROM SessionReferences WHERE SessionID=?");
		return szSql;
	}

	// Inserts a new session reference row with RefCount 1 and the current
	// server time. Parameters, in order: SessionID, TimeoutMs.
	LPCTSTR GetSessionRefCreate() throw()
	{
		LPCTSTR szSql =
			_T("INSERT INTO SessionReferences (SessionID, LastAccess, RefCount, TimeoutMs) ")
			_T("VALUES (?, getdate(), 1, ?)");
		return szSql;
	}

	// Changes the timeout of one session.
	// Parameters, in order: TimeoutMs, SessionID.
	LPCTSTR GetSessionRefUpdateTimeout() throw()
	{
		LPCTSTR szSql = _T("UPDATE SessionReferences SET TimeoutMs=? WHERE SessionID=?");
		return szSql;
	}

	// Increments the reference count and refreshes the last access time.
	// Parameter: SessionID.
	LPCTSTR GetSessionRefAddRef() throw()
	{
		LPCTSTR szSql =
			_T("UPDATE SessionReferences SET RefCount=RefCount+1, LastAccess=getdate() ")
			_T("WHERE SessionID=?");
		return szSql;
	}

	// Decrements the reference count and refreshes the last access time.
	// Parameter: SessionID.
	LPCTSTR GetSessionRefRemoveRef() throw()
	{
		LPCTSTR szSql =
			_T("UPDATE SessionReferences SET RefCount=RefCount-1, LastAccess=getdate() ")
			_T("WHERE SessionID=?");
		return szSql;
	}

	// Refreshes the last access time only. Parameter: SessionID.
	LPCTSTR GetSessionRefAccess() throw()
	{
		LPCTSTR szSql =
			_T("UPDATE SessionReferences SET LastAccess=getdate() ")
			_T("WHERE SessionID=?");
		return szSql;
	}

	// Selects every column of one session reference row.
	// Parameter: SessionID.
	LPCTSTR GetSessionRefSelect() throw()
	{
		LPCTSTR szSql = _T("SELECT * FROM SessionReferences WHERE SessionID=?");
		return szSql;
	}

	// Counts all rows in the session references table. No parameters.
	LPCTSTR GetSessionRefGetCount() throw()
	{
		LPCTSTR szSql = _T("SELECT COUNT(*) FROM SessionReferences");
		return szSql;
	}

	// --- SessionVariables table ----------------------------------------

	// Counts the variables stored for one session. Parameter: SessionID.
	LPCTSTR GetSessionVarCount() throw()
	{
		LPCTSTR szSql =
			_T("SELECT COUNT(*) FROM SessionVariables ")
			_T("WHERE SessionID=?");
		return szSql;
	}

	// Inserts a new session variable.
	// Parameters, in order: VariableValue, SessionID, VariableName.
	LPCTSTR GetSessionVarInsert() throw()
	{
		LPCTSTR szSql =
			_T("INSERT INTO SessionVariables (VariableValue, SessionID, VariableName) ")
			_T("VALUES (?, ?, ?)");
		return szSql;
	}

	// Updates the value of an existing session variable.
	// Parameters, in order: VariableValue, SessionID, VariableName.
	LPCTSTR GetSessionVarUpdate() throw()
	{
		LPCTSTR szSql =
			_T("UPDATE SessionVariables SET VariableValue=? ")
			_T("WHERE SessionID=? AND VariableName=?");
		return szSql;
	}

	// Deletes one session variable.
	// Parameters, in order: SessionID, VariableName.
	LPCTSTR GetSessionVarDeleteVar() throw()
	{
		LPCTSTR szSql =
			_T("DELETE FROM SessionVariables WHERE SessionID=? ")
			_T("AND VariableName=?");
		return szSql;
	}

	// Deletes every variable of one session. Parameter: SessionID.
	LPCTSTR GetSessionVarDeleteAllVars() throw()
	{
		LPCTSTR szSql =
			_T("DELETE FROM SessionVariables ")
			_T("WHERE (SessionID=?)");
		return szSql;
	}

	// Selects one session variable.
	// Parameters, in order: SessionID, VariableName.
	LPCTSTR GetSessionVarSelectVar()throw()
	{
		LPCTSTR szSql =
			_T("SELECT SessionID, VariableName, VariableValue FROM SessionVariables ")
			_T("WHERE SessionID=? AND VariableName=?");
		return szSql;
	}

	// Selects every variable of one session. Parameter: SessionID.
	LPCTSTR GetSessionVarSelectAllVars() throw()
	{
		LPCTSTR szSql =
			_T("SELECT SessionID, VariableName, VariableValue FROM SessionVariables ")
			_T("WHERE SessionID=?");
		return szSql;
	}

	// Sets the timeout of every row in the session references table.
	// Parameter: TimeoutMs.
	LPCTSTR GetSessionReferencesSet() throw()
	{
		LPCTSTR szSql =
			_T("UPDATE SessionReferences ")
			_T("SET TimeoutMs=?");
		return szSql;
	}
};


// Contains the data for the session variable accessors
class CSessionDataBase
{
public:
	TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];		// session key name
	TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH];	// variable name
	BYTE m_VariableValue[MAX_VARIABLE_VALUE_LENGTH];	// serialized variant bytes
	DWORD m_VariableLen;							// number of valid bytes in m_VariableValue

	// Starts with empty strings and a zero-length value.
	CSessionDataBase() throw()
	{
		m_szSessionID[0] = '\0';
		m_VariableName[0] = '\0';
		m_VariableValue[0] = '\0';
		m_VariableLen = 0;
	}

	// Copies the session ID, the optional variable name and the optional
	// variant value into the fixed-size member buffers.
	//
	// szSessionID - required; E_INVALIDARG if NULL.
	// szVarName   - optional; left untouched if NULL.
	// pVal        - optional; when given it is serialized through
	//               CVariantStream and copied into m_VariableValue.
	//
	// Returns S_OK on success, E_OUTOFMEMORY if the session ID or variable
	// name does not fit in its buffer, the InsertVariant result if
	// serialization does not return S_OK, and E_UNEXPECTED if the serialized
	// value is empty or does not fit in m_VariableValue.
	HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName, VARIANT *pVal) throw()
	{
		if (!szSessionID)
			return E_INVALIDARG;

		// Fail immediately on an over-long string so that the error cannot
		// be overwritten by the result of a later step.
		if (_tcslen(szSessionID) >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;
		_tcscpy(m_szSessionID, szSessionID);

		if (szVarName)
		{
			if (_tcslen(szVarName) >= MAX_VARIABLE_NAME_LENGTH)
				return E_OUTOFMEMORY;
			_tcscpy(m_VariableName, szVarName);
		}

		if (!pVal)
			return S_OK;

		CVariantStream stream;
		HRESULT hr = stream.InsertVariant(pVal);
		if (hr != S_OK)
			return hr;

		BYTE *pBytes = stream.m_stream;
		size_t size = stream.GetVariantSize();
		if (!pBytes || !size || size >= MAX_VARIABLE_VALUE_LENGTH)
			return E_UNEXPECTED;

		memcpy(m_VariableValue, pBytes, size);
		m_VariableLen = (DWORD)size;
		return S_OK;
	}
};

// Use to select a session variable given the name
// of a session and the name of a variable.
class CSessionDataSelector : public CSessionDataBase
{
public:
	// Result columns, by ordinal: session ID, variable name, and the
	// variable value together with its length.
	BEGIN_COLUMN_MAP(CSessionDataSelector) 
		COLUMN_ENTRY(1, m_szSessionID)
		COLUMN_ENTRY(2, m_VariableName)
		COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen)
	END_COLUMN_MAP()
	// Input parameters, by ordinal: session ID, variable name.
	BEGIN_PARAM_MAP(CSessionDataSelector) 
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_szSessionID)
		COLUMN_ENTRY(2, m_VariableName)
	END_PARAM_MAP()
};

// Use to select all session variables given the name
// of a session.
class CAllSessionDataSelector : public CSessionDataBase
{
public:
	// Result columns, by ordinal: session ID, variable name, and the
	// variable value together with its length.
	BEGIN_COLUMN_MAP(CAllSessionDataSelector) 
		COLUMN_ENTRY(1, m_szSessionID)
		COLUMN_ENTRY(2, m_VariableName)
		COLUMN_ENTRY_LENGTH(3, m_VariableValue, m_VariableLen)
	END_COLUMN_MAP()
	// Single input parameter: the session ID.
	BEGIN_PARAM_MAP(CAllSessionDataSelector) 
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_szSessionID)
	END_PARAM_MAP()
};

// Use to update the value of a session variable
class CSessionDataUpdator : public CSessionDataBase
{
public:
	// Input parameters, by ordinal: variable value (with length),
	// session ID, variable name. This matches the placeholder order of
	// both the update and the insert statements in CDefaultQueryClass.
	BEGIN_PARAM_MAP(CSessionDataUpdator) 
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY_LENGTH(1, m_VariableValue, m_VariableLen)
		COLUMN_ENTRY(2, m_szSessionID)
		COLUMN_ENTRY(3, m_VariableName)
	END_PARAM_MAP()
};

// Use to delete a session variable given the
// session name and the name of the variable
class CSessionDataDeletor
{
public:
	// Both buffers start out as empty strings.
	CSessionDataDeletor()
	{
		m_VariableName[0] = '\0';
		m_szSessionID[0] = '\0';
	}

	TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];
	TCHAR m_VariableName[MAX_VARIABLE_NAME_LENGTH];

	// Copies the given strings into the parameter buffers. A NULL argument
	// leaves the corresponding buffer unchanged. Returns E_OUTOFMEMORY if
	// a string does not fit in its buffer, S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID, LPCTSTR szVarName) throw()
	{
		if (szSessionID != NULL)
		{
			if (_tcslen(szSessionID) >= MAX_SESSION_KEY_LEN)
				return E_OUTOFMEMORY;
			_tcscpy(m_szSessionID, szSessionID);
		}

		if (szVarName != NULL)
		{
			if (_tcslen(szVarName) >= MAX_VARIABLE_NAME_LENGTH)
				return E_OUTOFMEMORY;
			_tcscpy(m_VariableName, szVarName);
		}

		return S_OK;
	}

	// Input parameters, by ordinal: session ID, variable name.
	BEGIN_PARAM_MAP(CSessionDataDeletor) 
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_szSessionID)
		COLUMN_ENTRY(2, m_VariableName)
	END_PARAM_MAP()
};

// Parameter accessor carrying only a session ID.
class CSessionDataDeleteAll
{
public:
	TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];

	// Stores the session ID. Returns E_INVALIDARG for NULL,
	// E_OUTOFMEMORY if the ID does not fit in the buffer, S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID) throw()
	{
		if (szSessionID == NULL)
			return E_INVALIDARG;

		const size_t cchSessionID = _tcslen(szSessionID);
		if (cchSessionID >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;

		_tcscpy(m_szSessionID, szSessionID);
		return S_OK;
	}

	// Single input parameter: the session ID.
	BEGIN_PARAM_MAP(CSessionDataDeleteAll) 
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_szSessionID)
	END_PARAM_MAP()
};

// Used for retrieving the count of session variables for
// a given session ID.
class CCountAccessor
{
public:
	LONG m_nCount;
	TCHAR m_szSessionID[MAX_SESSION_KEY_LEN];

	// Starts with a zero count and an empty session ID.
	CCountAccessor() throw()
	{
		m_nCount = 0;
		m_szSessionID[0] = '\0';
	}

	// Stores the session ID to count variables for. Returns E_INVALIDARG
	// for NULL, E_OUTOFMEMORY if the ID does not fit, S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID) throw()
	{
		if (szSessionID == NULL)
			return E_INVALIDARG;

		const size_t cchSessionID = _tcslen(szSessionID);
		if (cchSessionID >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;

		_tcscpy(m_szSessionID, szSessionID);
		return S_OK;
	}

	// Result column 1 receives the count.
	BEGIN_COLUMN_MAP(CCountAccessor)
		COLUMN_ENTRY(1, m_nCount)
	END_COLUMN_MAP()
	// Single input parameter: the session ID.
	BEGIN_PARAM_MAP(CCountAccessor)
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_szSessionID)
	END_PARAM_MAP()
};


// Used for updating entries in the session
// references table, given a session ID
class CSessionRefUpdator
{
public:
	TCHAR m_SessionID[MAX_SESSION_KEY_LEN];

	// Stores the session ID used as the statement's only parameter.
	// Returns E_INVALIDARG for NULL, E_OUTOFMEMORY if the ID does not
	// fit in the buffer, S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID) throw()
	{
		if (szSessionID == NULL)
			return E_INVALIDARG;

		const size_t cchSessionID = _tcslen(szSessionID);
		if (cchSessionID >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;

		_tcscpy(m_SessionID, szSessionID);
		return S_OK;
	}

	// Single input parameter: the session ID.
	BEGIN_PARAM_MAP(CSessionRefUpdator)
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_SessionID)
	END_PARAM_MAP()
};

// Accessor for the "is this session expired" query: the session ID goes
// in as a parameter and is read back as a result column.
class CSessionRefIsExpired
{
public:
	TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
	TCHAR m_SessionIDOut[MAX_SESSION_KEY_LEN];

	// Clears the output buffer, then stores the session ID to test.
	// Returns E_INVALIDARG for NULL, E_OUTOFMEMORY if the ID does not
	// fit in the buffer, S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID) throw()
	{
		// The output is reset before any validation takes place.
		m_SessionIDOut[0]=0;

		if (szSessionID == NULL)
			return E_INVALIDARG;

		const size_t cchSessionID = _tcslen(szSessionID);
		if (cchSessionID >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;

		_tcscpy(m_SessionID, szSessionID);
		return S_OK;
	}

	// Result column 1: the session ID returned by the query.
	BEGIN_COLUMN_MAP(CSessionRefIsExpired)
		COLUMN_ENTRY(1, m_SessionIDOut)
	END_COLUMN_MAP()
	// Single input parameter: the session ID.
	BEGIN_PARAM_MAP(CSessionRefIsExpired)
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_SessionID)
	END_PARAM_MAP()
};

// Parameter accessor carrying a single timeout value.
class CSetAllTimeouts
{
public:
	unsigned __int64 m_dwNewTimeout;
	// Stores the timeout value to send as the parameter. Always returns S_OK.
	HRESULT Assign(unsigned __int64 dwNewValue)
	{
		m_dwNewTimeout = dwNewValue;
		return S_OK;
	}
	// Single input parameter: the new timeout.
	BEGIN_PARAM_MAP(CSetAllTimeouts)
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_dwNewTimeout)
	END_PARAM_MAP()
};

// Parameter accessor for changing the timeout of a single session.
class CSessionRefUpdateTimeout
{
public:
	TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
	unsigned __int64 m_nNewTimeout;

	// Stores the session ID and the new timeout. Returns E_INVALIDARG
	// for a NULL ID, E_OUTOFMEMORY if the ID does not fit in the buffer
	// (the timeout is not stored in either case), S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID, unsigned __int64 nNewTimeout) throw()
	{
		if (szSessionID == NULL)
			return E_INVALIDARG;

		const size_t cchSessionID = _tcslen(szSessionID);
		if (cchSessionID >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;

		_tcscpy(m_SessionID, szSessionID);
		m_nNewTimeout = nNewTimeout;
		return S_OK;
	}

	// Input parameters, by ordinal: new timeout, session ID.
	BEGIN_PARAM_MAP(CSessionRefUpdateTimeout)
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_nNewTimeout)
		COLUMN_ENTRY(2, m_SessionID)
	END_PARAM_MAP()
};

// Accessor for reading a session reference row by session ID.
class CSessionRefSelector
{
public:
	TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
	int m_RefCount;

	// Stores the session ID to look up. Returns E_INVALIDARG for NULL,
	// E_OUTOFMEMORY if the ID does not fit in the buffer, S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID) throw()
	{
		if (szSessionID == NULL)
			return E_INVALIDARG;

		const size_t cchSessionID = _tcslen(szSessionID);
		if (cchSessionID >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;

		_tcscpy(m_SessionID, szSessionID);
		return S_OK;
	}

	// Result columns: ordinal 1 is the session ID, ordinal 3 the
	// reference count (ordinal 2 is not bound).
	BEGIN_COLUMN_MAP(CSessionRefSelector)
		COLUMN_ENTRY(1, m_SessionID)
		COLUMN_ENTRY(3, m_RefCount)
	END_COLUMN_MAP()
	// Single input parameter: the session ID.
	BEGIN_PARAM_MAP(CSessionRefSelector)
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_SessionID)
	END_PARAM_MAP()
};

// Accessor that receives a single count value; it has no parameters.
class CSessionRefCount
{
public:
	LONG m_nCount;
	// Result column 1 receives the count.
	BEGIN_COLUMN_MAP(CSessionRefCount)
		COLUMN_ENTRY(1, m_nCount)
	END_COLUMN_MAP()
};

// Used for creating new entries in the session
// references table.
class CSessionRefCreator
{
public:
	TCHAR m_SessionID[MAX_SESSION_KEY_LEN];
	unsigned __int64 m_TimeoutMs;

	// Stores the ID and timeout of the session to create. Returns
	// E_INVALIDARG for a NULL ID, E_OUTOFMEMORY if the ID does not fit in
	// the buffer (nothing is stored in either case), S_OK otherwise.
	HRESULT Assign(LPCTSTR szSessionID, unsigned __int64 timeout) throw()
	{
		if (szSessionID == NULL)
			return E_INVALIDARG;

		const size_t cchSessionID = _tcslen(szSessionID);
		if (cchSessionID >= MAX_SESSION_KEY_LEN)
			return E_OUTOFMEMORY;

		_tcscpy(m_SessionID, szSessionID);
		m_TimeoutMs = timeout;
		return S_OK;
	}

	// Input parameters, by ordinal: session ID, timeout in milliseconds.
	BEGIN_PARAM_MAP(CSessionRefCreator)
		SET_PARAM_TYPE(DBPARAMIO_INPUT)
		COLUMN_ENTRY(1, m_SessionID)
		COLUMN_ENTRY(2, m_TimeoutMs)
	END_PARAM_MAP()
};


// CDBSession
// This session persistence class persists session variables to
// an OLEDB datasource. The following tables give a general description
// of the schema for the tables this class uses.
//
// TableName: SessionVariables
// Column		Name			Type							Description
// 1			SessionID		char[MAX_SESSION_KEY_LEN]		Session Key name
// 2			VariableName	char[MAX_VARIABLE_NAME_LENGTH]	Variable Name
// 3			VariableValue	varbinary[MAX_VARIABLE_VALUE_LENGTH]	Variable Value

//
// TableName: SessionReferences
// Column		Name			Type							Description
// 1			SessionID		char[MAX_SESSION_KEY_LEN]		Session Key Name.
// 2			LastAccess		datetime						Date and time of last access to this session.
// 3			RefCount		int								Current references on this session.
// 4			TimeoutMS		int								Timeout value for the session in milli seconds

// Callback through which a CDBSession asks its owner for the provider
// (connection) string: it receives the cookie given to
// CDBSession::Initialize, stores a wide string pointer in the second
// argument and returns true on success.
typedef bool (*PFN_GETPROVIDERINFO)(DWORD_PTR, wchar_t **);

template <class QueryClass=CDefaultQueryClass>
class CDBSession:
	public ISession,
	public CComObjectRootEx<CComGlobalsThreadModel>

{
	typedef CCommand<CAccessor<CAllSessionDataSelector> >  iterator_accessor;
public:
	typedef QueryClass DBQUERYCLASS_TYPE;
	BEGIN_COM_MAP(CDBSession)
		COM_INTERFACE_ENTRY(ISession)
	END_COM_MAP()

	CDBSession() throw():
		m_dwTimeout(ATL_SESSION_TIMEOUT)
	{
		m_szSessionName[0] = '\0';
	}

	// Nothing to do here; the session lock is released in FinalRelease.
	~CDBSession() throw()
	{
	}

	// Releases this object's reference on the session by calling
	// SessionUnlock. The result of SessionUnlock is ignored.
	void FinalRelease()throw()
	{
		SessionUnlock();
	}

	// Stores Val under the name szName for this session. An UPDATE is
	// attempted first; if it fails or touches no rows, an INSERT is
	// attempted instead.
	// Returns E_INVALIDARG for a NULL name, E_UNEXPECTED if the insert
	// succeeds but reports no affected rows, otherwise the first failing
	// HRESULT or S_OK.
	STDMETHOD(SetVariable)(LPCSTR szName, VARIANT Val) throw()
	{
		if (szName == NULL)
			return E_INVALIDARG;

		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// Touch the session's last access time.
		hr = Access();
		if (hr != S_OK)
			return hr;

		// Fill in the statement parameters (value, session ID, name).
		CCommand<CAccessor<CSessionDataUpdator> > cmd;
		_ATLTRY
		{
			CA2CT tszName(szName);
			hr = cmd.Assign(m_szSessionName, tszName, &Val);
		}
		_ATLCATCHALL()
		{
			hr = E_OUTOFMEMORY;
		}
		if (hr != S_OK)
			return hr;

		// First choice: update an existing row.
		LONG cRows = 0;
		hr = cmd.Open(conn,
					m_QueryObj.GetSessionVarUpdate(),
					NULL, &cRows, DBGUID_DEFAULT, false);
		if (hr == S_OK && cRows > 0)
			return S_OK;

		// The variable is not there yet (or the update failed): insert it.
		hr = cmd.Open(conn,
					m_QueryObj.GetSessionVarInsert(),
					NULL, &cRows, DBGUID_DEFAULT, false);
		if (hr == S_OK && cRows <= 0)
			hr = E_UNEXPECTED;

		return hr;
	}

	// Warning: For string data types, depending on the configuration of
	// your database, strings might be returned with trailing white space.
	// Reads the session variable szName into *pVal.
	// *pVal is cleared with VariantClear before anything else happens.
	// Returns E_INVALIDARG for a NULL name, E_POINTER for a NULL pVal,
	// otherwise the HRESULT of the first step that does not succeed.
	STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw()
	{
		if (szName == NULL)
			return E_INVALIDARG;
		if (pVal == NULL)
			return E_POINTER;
		VariantClear(pVal);

		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// Touch the session's last access time.
		hr = Access();
		if (hr != S_OK)
			return hr;

		// Fill in the query parameters (session ID, variable name).
		CCommand<CAccessor<CSessionDataSelector> > cmd;
		_ATLTRY
		{
			CA2CT tszName(szName);
			hr = cmd.Assign(m_szSessionName, tszName, NULL);
		}
		_ATLCATCHALL()
		{
			hr = E_OUTOFMEMORY;
		}
		if (hr != S_OK)
			return hr;

		hr = cmd.Open(conn, m_QueryObj.GetSessionVarSelectVar());
		if (FAILED(hr))
			return hr;

		hr = cmd.MoveFirst();
		if (hr != S_OK)
			return hr;

		// Rebuild the variant from the bytes stored in the row.
		CStreamOnByteArray valueStream(cmd.m_VariableValue);
		CComVariant varResult;
		hr = varResult.ReadFromStream(static_cast<IStream*>(&valueStream));
		if (hr != S_OK)
			return hr;

		return varResult.Detach(pVal);
	}

	// Deletes the session variable szName.
	// Returns E_INVALIDARG for a NULL name and E_UNEXPECTED when the
	// delete succeeds but reports no affected rows.
	STDMETHOD(RemoveVariable)(LPCSTR szName) throw()
	{
		if (szName == NULL)
			return E_INVALIDARG;

		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// Touch the session's last access time.
		hr = Access();
		if (hr != S_OK)
			return hr;

		// Fill in the statement parameters (session ID, variable name).
		CCommand<CAccessor<CSessionDataDeletor> > cmd;
		_ATLTRY
		{
			CA2CT tszName(szName);
			hr = cmd.Assign(m_szSessionName, tszName);
		}
		_ATLCATCHALL()
		{
			return E_OUTOFMEMORY;
		}
		if (hr != S_OK)
			return hr;

		// Run the delete and make sure it actually removed something.
		long cRows = 0;
		hr = cmd.Open(conn, m_QueryObj.GetSessionVarDeleteVar(),
					NULL, &cRows, DBGUID_DEFAULT, false);
		if (hr == S_OK && cRows <= 0)
			return E_UNEXPECTED;
		return hr;
	}

	// Gives the count of rows in the table for this session ID.
	// Stores in *pnCount the number of variables held for this session.
	// *pnCount is zeroed first and only overwritten on success.
	// Returns E_POINTER for a NULL pnCount.
	STDMETHOD(GetCount)(long *pnCount) throw()
	{
		if (pnCount == NULL)
			return E_POINTER;
		*pnCount = 0;

		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// Touch the session's last access time.
		hr = Access();
		if (hr != S_OK)
			return hr;

		CCommand<CAccessor<CCountAccessor> > cmd;
		hr = cmd.Assign(m_szSessionName);
		if (hr != S_OK)
			return hr;

		hr = cmd.Open(conn, m_QueryObj.GetSessionVarCount());
		if (hr != S_OK)
			return hr;

		hr = cmd.MoveFirst();
		if (hr != S_OK)
			return hr;

		*pnCount = cmd.m_nCount;
		return S_OK;
	}

	// Deletes every variable stored for this session. Unlike most other
	// operations, this one does not refresh the session's access time.
	STDMETHOD(RemoveAllVariables)() throw()
	{
		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		CCommand<CAccessor<CSessionDataDeleteAll> > cmd;
		hr = cmd.Assign(m_szSessionName);
		if (hr != S_OK)
			return hr;

		// No row count is requested for this statement.
		return cmd.Open(conn,
						m_QueryObj.GetSessionVarDeleteAllVars(),
						NULL, NULL, DBGUID_DEFAULT, false);
	}

	// Iteration of variables works by taking a snapshot
	// of the sessions at the point in time BeginVariableEnum
	// is called, and then keeping an index variable that you use to
	// move through the snapshot rowset. It is important to know
	// that the handle returned in phEnum is not thread safe. It
	// should only be used by the calling thread.
	// Starts an enumeration over this session's variables.
	//
	// phEnum - receives the enumeration handle (a heap allocated iterator
	//          accessor). It must later be released with CloseEnum.
	// pPOS   - receives the starting position.
	//
	// On any result other than S_OK after the accessor has been allocated,
	// the accessor is freed, *phEnum is NULL and *pPOS is
	// INVALID_DB_SESSION_POS. Returns E_POINTER if either out pointer is
	// NULL and E_OUTOFMEMORY if the accessor cannot be allocated.
	STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnum, POSITION *pPOS) throw()
	{
		HRESULT hr = E_FAIL;
		if (!pPOS)
			return E_POINTER;

		if (phEnum)
			*phEnum = NULL;
		else
			return E_POINTER;

		// Get the data connection for this thread.
		CDataConnection dataconn;
		hr = GetSessionConnection(&dataconn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// Update the last access time for this session.
		hr = Access();
		if (hr != S_OK)
			return hr;

		// Allocate a new iterator accessor and initialize its input parameters.
		iterator_accessor *pIteratorAccessor = NULL;
		ATLTRYALLOC(pIteratorAccessor = new iterator_accessor);
		if (!pIteratorAccessor)
			return E_OUTOFMEMORY;

		hr = pIteratorAccessor->Assign(m_szSessionName, NULL, NULL);
		if (hr == S_OK)
		{
			// execute the command and move to the first row of the recordset.
			hr = pIteratorAccessor->Open(dataconn, 
								m_QueryObj.GetSessionVarSelectAllVars());
			if (hr == S_OK)
				hr = pIteratorAccessor->MoveFirst();
		}

		if (hr != S_OK)
		{
			// Clean up on every non-S_OK path, including a failed Assign,
			// so the accessor is never leaked.
			*pPOS = INVALID_DB_SESSION_POS;
			*phEnum = NULL;
			delete pIteratorAccessor;
			return hr;
		}

		*pPOS = (POSITION) INVALID_DB_SESSION_POS + 1;
		*phEnum = reinterpret_cast<HSESSIONENUM>(pIteratorAccessor);
		return hr;
	}

	// The values for hEnum and pPos must have been initialized in a previous
	// call to BeginVariableEnum. On success, the out variant will hold the next
	// variable
	// Returns the variable at the enumeration's current row and advances
	// to the next row.
	//
	// hEnum  - handle obtained from BeginVariableEnum.
	// pPOS   - in/out position; set to INVALID_DB_SESSION_POS once the end
	//          of the rowset has been reached.
	// szName - optional buffer that receives the variable name.
	// dwLen  - size of szName in bytes, including room for the terminator.
	// pVal   - receives the variable value; initialized with VariantInit.
	//
	// Returns E_INVALIDARG for a NULL pPOS, E_POINTER for a NULL pVal,
	// E_UNEXPECTED for a NULL handle or an invalid position, and
	// E_OUTOFMEMORY if szName is too small for the name.
	STDMETHOD(GetNextVariable)(HSESSIONENUM hEnum, POSITION *pPOS, LPSTR szName, DWORD dwLen, VARIANT *pVal) throw()
	{
		if (!pPOS)
			return E_INVALIDARG;

		if (pVal)
			VariantInit(pVal);
		else
			return E_POINTER;

		if (!hEnum)
			return E_UNEXPECTED;

		if (*pPOS <= INVALID_DB_SESSION_POS)
			return E_UNEXPECTED;

		iterator_accessor *pIteratorAccessor = reinterpret_cast<iterator_accessor*>(hEnum);

		// update the last access time.
		HRESULT hr = Access();

		POSITION posCurrent = *pPOS;
		
		if (szName)
		{
			// caller wants entry name
			_ATLTRY
			{
				CT2CA szVarName(pIteratorAccessor->m_VariableName);
				LPCSTR pszVarName = szVarName;
				// Check the length of the converted string: the number of
				// bytes after conversion can differ from the TCHAR count
				// of the stored name.
				if (dwLen > strlen(pszVarName))
					strcpy(szName, pszVarName);
				else
					hr = E_OUTOFMEMORY; // buffer not big enough
			}
			_ATLCATCHALL()
			{
				hr = E_OUTOFMEMORY;
			}
		}

		if (hr == S_OK)
		{
			// Rebuild the variant from the bytes stored in the current row.
			CStreamOnByteArray stream(pIteratorAccessor->m_VariableValue);
			CComVariant vOut;
			hr = vOut.ReadFromStream(static_cast<IStream*>(&stream));
			if (hr == S_OK)
				vOut.Detach(pVal);
			else
				return hr;
		}
		else
			return hr;

		hr = pIteratorAccessor->MoveNext();
		*pPOS = ++posCurrent;

		if (hr == DB_S_ENDOFROWSET)
		{
			// We're done iterating, reset everything
			*pPOS = INVALID_DB_SESSION_POS;
			hr = S_OK;
		}

		if (hr != S_OK)
		{
			// Do not hand back a value together with an error.
			VariantClear(pVal);
		}
		return hr;
	}

	// CloseEnum frees up any resources allocated by the iterator
	// Closes and deletes the iterator accessor behind hEnum.
	// Returns E_INVALIDARG for a NULL handle, S_OK otherwise.
	STDMETHOD(CloseEnum)(HSESSIONENUM hEnum) throw()
	{
		if (!hEnum)
			return E_INVALIDARG;

		iterator_accessor *pEnum = reinterpret_cast<iterator_accessor*>(hEnum);
		pEnum->Close();
		delete pEnum;
		return S_OK;
	}

	//
	// Returns S_FALSE if it's not expired
	// S_OK if it is expired and an error HRESULT
	// if an error occurred.
	// Asks the database whether this session has timed out.
	// S_OK means expired, S_FALSE means not expired; any other value is
	// the HRESULT of the step that did not return S_OK.
	STDMETHOD(IsExpired)() throw()
	{
		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		CCommand<CAccessor<CSessionRefIsExpired> > cmd;
		hr = cmd.Assign(m_szSessionName);
		if (hr != S_OK)
			return hr;

		hr = cmd.Open(conn, m_QueryObj.GetSessionRefIsExpired(),
					NULL, NULL, DBGUID_DEFAULT, true);
		if (hr != S_OK)
			return hr;

		// The query only yields a row for an expired session; the session
		// is reported as expired when that row carries our own ID.
		if (cmd.MoveFirst() == S_OK &&
			_tcscmp(cmd.m_SessionIDOut, m_szSessionName) == 0)
			return S_OK;

		return S_FALSE;
	}

	// Writes a new timeout for this session to the session references
	// table. The m_dwTimeout member is not modified.
	STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw()
	{
		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// Fill in the statement parameters (timeout, session ID).
		CCommand<CAccessor<CSessionRefUpdateTimeout> > cmd;
		hr = cmd.Assign(m_szSessionName, dwNewTimeout);
		if (hr != S_OK)
			return hr;

		return cmd.Open(conn, m_QueryObj.GetSessionRefUpdateTimeout(),
						NULL, NULL, DBGUID_DEFAULT, false);
	}

	// SessionLock increments the session reference count for this session.
	// If there is not a session by this name in the session references table,
	// a new session entry is created in the table.
	// Adds a reference to this session, creating its row in the session
	// references table when the add-ref statement fails or touches no row.
	// Returns E_UNEXPECTED when there is no session name, or when the
	// statements succeed but report no affected rows.
	HRESULT SessionLock() throw()
	{
		// m_szSessionName is a member array, so only emptiness needs checking.
		if (m_szSessionName[0] == 0)
			return E_UNEXPECTED; // no session to lock.

		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		LONG cRows = 0;
		CCommand<CAccessor<CSessionRefUpdator> > addRefCmd;
		if (addRefCmd.Assign(m_szSessionName) == S_OK)
		{
			// First choice: bump the reference count of an existing row.
			hr = addRefCmd.Open(conn, m_QueryObj.GetSessionRefAddRef(),
								NULL, &cRows, DBGUID_DEFAULT, false);

			const bool bNeedCreate = (hr != S_OK) || (cRows == 0);
			if (bNeedCreate)
			{
				// No row was updated: create the session reference.
				CCommand<CAccessor<CSessionRefCreator> > createCmd;
				hr = createCmd.Assign(m_szSessionName, m_dwTimeout);
				if (hr == S_OK)
				{
					hr = createCmd.Open(conn, m_QueryObj.GetSessionRefCreate(),
										NULL, &cRows, DBGUID_DEFAULT, false);
				}
			}
		}

		// One of the two statements should have affected a row.
		ATLASSERT(cRows > 0);
		if (hr == S_OK && cRows <= 0)
			hr = E_UNEXPECTED;

		return hr;
	}

	// SessionUnlock decrements the session RefCount for this session.
	// Sessions cannot be removed from the database unless the session
	// refcount is 0
	// Drops one reference on this session and then calls FreeSession,
	// which removes the session row if it is unreferenced and expired.
	// Returns E_UNEXPECTED when there is no session name.
	HRESULT SessionUnlock() throw()
	{
		// m_szSessionName is a member array, so only emptiness needs checking.
		if (m_szSessionName[0] == 0)
			return E_UNEXPECTED;

		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// The session row is expected to exist already, so a plain
		// update statement is enough here.
		LONG cRows = 0;
		CCommand<CAccessor<CSessionRefUpdator> > releaseCmd;
		hr = releaseCmd.Assign(m_szSessionName);
		if (hr != S_OK)
			return hr;

		hr = releaseCmd.Open(conn,
							m_QueryObj.GetSessionRefRemoveRef(),
							NULL,
							&cRows,
							DBGUID_DEFAULT,
							false);
		if (hr != S_OK)
			return hr;

		// Remove the session if nobody else uses it and it has expired.
		return FreeSession();
	}

	// Access updates the last access time for the session. The access
	// time for sessions is updated using the SQL GETDATE function on the
	// database server so that all clients will be using the same clock
	// to compare access times against.
	// Refreshes the last access time of this session's row.
	// Returns E_UNEXPECTED when there is no session name or when the
	// update succeeds without touching a row.
	HRESULT Access() throw()
	{
		// m_szSessionName is a member array, so only emptiness needs checking.
		if (m_szSessionName[0] == 0)
			return E_UNEXPECTED; // no session to access

		// Connection for the calling thread.
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// The session row must already exist, so a plain update
		// statement is used.
		CCommand<CAccessor<CSessionRefUpdator> > touchCmd;
		LONG cRows = 0;

		hr = touchCmd.Assign(m_szSessionName);
		if (hr == S_OK)
		{
			hr = touchCmd.Open(conn,
							m_QueryObj.GetSessionRefAccess(),
							NULL,
							&cRows,
							DBGUID_DEFAULT,
							false);
		}

		ATLASSERT(cRows > 0);
		return (hr == S_OK && cRows <= 0) ? E_UNEXPECTED : hr;
	}

	// If the session is expired and its reference count is 0,
	// it can be deleted. SessionUnlock calls this function to
	// delete the session after it has released a session
	// lock. Note that our SQL command will only delete the session
	// if it is expired and its refcount is <= 0
	// Deletes this session's row from the session references table. The
	// SQL itself restricts the delete to a row whose reference count is
	// <= 0 and whose timeout has elapsed.
	// Returns E_UNEXPECTED when there is no session name, otherwise the
	// HRESULT of the first step that does not return S_OK, or the result
	// of executing the delete statement.
	HRESULT FreeSession() throw()
	{
		HRESULT hr = E_UNEXPECTED;
		if (!m_szSessionName ||
			m_szSessionName[0]==0)
			return hr;

		// Get the data connection for this thread.
		CDataConnection dataconn;
		hr = GetSessionConnection(&dataconn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		CCommand<CAccessor<CSessionRefUpdator> > updator;

		// The statement has a SessionID parameter, so the accessor must be
		// given this session's ID before the command is executed; otherwise
		// the parameter buffer would hold indeterminate data.
		hr = updator.Assign(m_szSessionName);
		if (hr != S_OK)
			return hr;

		// The SQL for this command only deletes the
		// session reference from the references table if its access
		// count is 0 and it has expired.
		return updator.Open(dataconn,
							m_QueryObj.GetSessionRefDelete(),
							NULL,
							NULL,
							DBGUID_DEFAULT,
							false);
	}

	// Initialize is called each time a new session is created.
	// Binds this object to a session name and locks the session.
	//
	// szSessionName    - name (key) of the session.
	// pServiceProvider - service provider used to obtain data connections.
	// dwCookie         - value passed back to pfnInfo.
	// pfnInfo          - callback that yields the provider string.
	//
	// Returns E_INVALIDARG if szSessionName, pServiceProvider or pfnInfo is
	// NULL, E_OUTOFMEMORY if the name does not fit in the name buffer or the
	// conversion throws, otherwise the result of SessionLock.
	HRESULT Initialize( LPCSTR szSessionName, 
						IServiceProvider *pServiceProvider,
						DWORD_PTR dwCookie,
						PFN_GETPROVIDERINFO pfnInfo) throw()
	{
		if (szSessionName == NULL || pServiceProvider == NULL || pfnInfo == NULL)
			return E_INVALIDARG;

		// The provider information is stored before the name is validated.
		m_spServiceProvider = pServiceProvider;
		m_dwProvCookie = dwCookie;
		m_pfnInfo = pfnInfo;

		_ATLTRY
		{
			CA2CT tszName(szSessionName);
			if (_tcslen(tszName) >= MAX_SESSION_KEY_LEN)
				return E_OUTOFMEMORY;
			_tcscpy(m_szSessionName, tszName);
		}
		_ATLCATCHALL()
		{
			return E_OUTOFMEMORY;
		}

		return SessionLock();
	}

	// Obtains the session database connection via GetDataSource, using
	// the provider string supplied by the m_pfnInfo callback.
	// Returns E_INVALIDARG for a NULL pProv, E_UNEXPECTED when no callback
	// or cookie has been set, and E_FAIL when the callback does not
	// deliver a provider string.
	HRESULT GetSessionConnection(CDataConnection *pConn,
								 IServiceProvider *pProv) throw()
	{
		if (pProv == NULL)
			return E_INVALIDARG;

		if (m_pfnInfo == NULL || m_dwProvCookie == 0)
			return E_UNEXPECTED;

		wchar_t *wszProvider = NULL;
		const bool bHaveInfo = m_pfnInfo(m_dwProvCookie, &wszProvider);
		if (!bHaveInfo || wszProvider == NULL)
			return E_FAIL;

		return GetDataSource(pProv,
					ATL_DBSESSION_ID,
					wszProvider,
					pConn);
	}


protected:
	// Name (key) of the session this object represents; empty until
	// Initialize has stored a name.
	TCHAR m_szSessionName[MAX_SESSION_KEY_LEN];
	// Timeout written to the database when SessionLock creates the
	// session row; defaults to ATL_SESSION_TIMEOUT.
	unsigned __int64 m_dwTimeout;
	// Service provider handed to GetSessionConnection.
	CComPtr<IServiceProvider> m_spServiceProvider;
	// Cookie passed back to m_pfnInfo.
	DWORD_PTR m_dwProvCookie;
	// Callback that yields the provider string.
	PFN_GETPROVIDERINFO m_pfnInfo;
	// Source of the SQL statement text.
	DBQUERYCLASS_TYPE m_QueryObj;
}; // CDBSession


// CDBSessionServiceImplT
// Implements the service part of database persisted session services.
// The service keeps the connection string it was initialized with and
// creates TDBSession objects on demand; it does not hold references to
// the session objects it hands out.
//
// Template Parameters:
// TDBSession: The session class to instantiate. It must expose a
//             DBQUERYCLASS_TYPE type that supplies the SQL text and an
//             Initialize method taking (name, provider, cookie, callback).
template <class TDBSession=CDBSession<> >
class CDBSessionServiceImplT
{
	wchar_t m_szConnectionString[MAX_CONNECTION_STRING_LEN];
	CComPtr<IServiceProvider> m_spServiceProvider;
	// DBQUERYCLASS_TYPE is a dependent name, so it needs 'typename'.
	typename TDBSession::DBQUERYCLASS_TYPE m_QueryObj;
public:
	typedef const wchar_t* SERVICEIMPL_INITPARAM_TYPE;
	CDBSessionServiceImplT() throw()
	{
		m_dwTimeout = ATL_SESSION_TIMEOUT;
		m_szConnectionString[0] = L'\0';
	}

	// Callback handed to every session object so that it can retrieve the
	// connection string from this service. dwProvCookie is the 'this'
	// pointer that was passed to the session's Initialize.
	// Returns false if either argument is NULL/zero.
	static bool GetProviderInfo(DWORD_PTR dwProvCookie, wchar_t **ppszProvInfo) throw()
	{
		if (dwProvCookie &&
			ppszProvInfo)
		{
			CDBSessionServiceImplT<TDBSession> *pSvc =
				reinterpret_cast<CDBSessionServiceImplT<TDBSession>*>(dwProvCookie);
			*ppszProvInfo = pSvc->m_szConnectionString;
			return true;
		}
		return false;
	}

	// Fills in *pConn through GetDataSource using the stored connection
	// string. Returns E_INVALIDARG for a NULL provider and E_UNEXPECTED
	// when no connection string has been stored yet.
	HRESULT GetSessionConnection(CDataConnection *pConn,
								 IServiceProvider *pProv) throw()
	{
		if (!pProv)
			return E_INVALIDARG;

		if(!m_szConnectionString[0])
			return E_UNEXPECTED;

		return GetDataSource(pProv,
					ATL_DBSESSION_ID,
					m_szConnectionString,
					pConn);
	}

	// Stores the connection string (pData), the service provider and the
	// initial timeout. The string must fit, terminator included, into
	// MAX_CONNECTION_STRING_LEN characters; otherwise E_OUTOFMEMORY.
	HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE pData,
						IServiceProvider *pProvider,
						unsigned __int64 dwInitialTimeout) throw()
	{
		if (!pData || !pProvider)
			return E_INVALIDARG;

		if (wcslen(pData) < MAX_CONNECTION_STRING_LEN)
		{
			wcscpy(m_szConnectionString, pData);
		}
		else
			return E_OUTOFMEMORY;

		m_dwTimeout = dwInitialTimeout;
		m_spServiceProvider = pProvider;
		return S_OK;
	}

	// Creates a new session object, generates a name for it into szNewID
	// (buffer size passed in/out through pdwSize) and returns the session
	// through ppSession. On any failure the new object is destroyed.
	HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
	{
		HRESULT hr = E_FAIL;
		CComObject<TDBSession> *pNewSession = NULL;

		if (!pdwSize)
			return E_INVALIDARG;

		if (ppSession)
			*ppSession = NULL;
		else
			return E_POINTER;

		if (szNewID)
			*szNewID = '\0';	// start with an empty ID string
		else
			return E_INVALIDARG;


		// Create new session
		CComObject<TDBSession>::CreateInstance(&pNewSession);
		if (pNewSession == NULL)
			return E_OUTOFMEMORY;

		// Create a session name and initialize the object
		hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize);
		if (hr == S_OK)
		{
			hr = pNewSession->Initialize(szNewID,
										m_spServiceProvider,
										reinterpret_cast<DWORD_PTR>(this),
										GetProviderInfo);
			if (hr == S_OK)
			{
				// we don't hold a reference to the object
				hr = pNewSession->QueryInterface(ppSession);
			}
		}

		// Nobody received a reference, so the object is still ours to delete.
		if (hr != S_OK)
			delete pNewSession;
		return hr;
	}

	// Returns a session object for an existing session ID. The ID is
	// first checked with IsValidSession; only when that returns S_OK is a
	// session object created and initialized.
	HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw()
	{
		HRESULT hr = E_FAIL;
		if (!szID)
			return E_INVALIDARG;

		if (ppSession)
			*ppSession = NULL;
		else
			return E_POINTER;

		CComObject<TDBSession> *pNewSession = NULL;

		// Check the DB to see if the session ID is a valid session
		_ATLTRY
		{
			CA2CT session(szID);
			hr = IsValidSession(session);
		}
		_ATLCATCHALL()
		{
			hr = E_OUTOFMEMORY;
		}
		if (hr == S_OK)
		{
			// Create new session object to represent this session
			CComObject<TDBSession>::CreateInstance(&pNewSession);
			if (pNewSession == NULL)
				return E_OUTOFMEMORY;

			hr = pNewSession->Initialize(szID,
										m_spServiceProvider,
										reinterpret_cast<DWORD_PTR>(this),
										GetProviderInfo);
			if (hr == S_OK)
			{
				// we don't hold a reference to the object
				hr = pNewSession->QueryInterface(ppSession);
			}
		}

		if (hr != S_OK && pNewSession)
			delete pNewSession;
		return hr;
	}

	// Removes a session from the database: first all of its variables,
	// then (only if that succeeded) its entry in the references table.
	HRESULT CloseSession(LPCSTR szID) throw()
	{
		if (!szID)
			return E_INVALIDARG;

		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn,
										  m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// set up accessors
		CCommand<CAccessor<CSessionRefUpdator> > updator;
		CCommand<CAccessor<CSessionDataDeleteAll> > command;
		_ATLTRY
		{
			CA2CT session(szID);
			hr = updator.Assign(session);
			if (hr == S_OK)
				hr = command.Assign(session);
		}
		_ATLCATCHALL()
		{
			hr = E_OUTOFMEMORY;
		}

		if (hr == S_OK)
		{
			// delete all session variables
			hr = command.Open(conn,
								m_QueryObj.GetSessionVarDeleteAllVars(),
								NULL,
								NULL,
								DBGUID_DEFAULT,
								false);
			if (hr == S_OK)
			{
				// delete references in the session references table
				hr = updator.Open(conn,
							m_QueryObj.GetSessionRefDeleteFinal(),
							NULL,
							NULL,
							DBGUID_DEFAULT,
							false);
			}
		}
		return hr;
	}

	// Applies nTimeout to the sessions in the database and, if the
	// command succeeds, remembers it as the service's timeout.
	HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw()
	{
		// Get the data connection for this thread
		CDataConnection conn;

		HRESULT hr = GetSessionConnection(&conn, m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// all sessions get the same timeout
		CCommand<CAccessor<CSetAllTimeouts> > command;
		hr = command.Assign(nTimeout);
		if (hr == S_OK)
		{
			hr = command.Open(conn, m_QueryObj.GetSessionReferencesSet(),
						NULL,
						NULL,
						DBGUID_DEFAULT,
						false);
			if (hr == S_OK)
			{
				m_dwTimeout = nTimeout;
			}
		}
		return hr;
	}


	// Returns the timeout value currently remembered by the service.
	HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw()
	{
		if (pnTimeout)
			*pnTimeout = m_dwTimeout;
		else
			return E_INVALIDARG;

		return S_OK;
	}

	// Queries the database for the number of sessions. *pnCount is 0 on
	// any failure.
	HRESULT GetSessionCount(DWORD *pnCount) throw()
	{
		if (pnCount)
			*pnCount = 0;
		else
			return E_INVALIDARG;

		CCommand<CAccessor<CSessionRefCount> > command;
		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn,
											m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		hr = command.Open(conn,
							m_QueryObj.GetSessionRefGetCount());
		if (hr == S_OK)
		{
			hr = command.MoveFirst();
			if (hr == S_OK)
			{
				*pnCount = (DWORD)command.m_nCount;
			}
		}

		return hr;
	}

	void ReleaseAllSessions() throw()
	{
		// nothing to do
	}

	void SweepSessions() throw()
	{
		// nothing to do
	}


	// Helpers

	// Checks whether szID has an entry in the session references table.
	// Returns S_OK when the select command opens and MoveFirst succeeds;
	// otherwise the failing (or non-S_OK) HRESULT is returned unchanged.
	HRESULT IsValidSession(LPCTSTR szID) throw()
	{
		if (!szID)
			return E_INVALIDARG;
		// Look in the sessionreferences table to see if there is an entry
		// for this session.
		if (m_szConnectionString[0] == 0)
			return E_UNEXPECTED;

		CDataConnection conn;
		HRESULT hr = GetSessionConnection(&conn,
											m_spServiceProvider);
		if (hr != S_OK)
			return hr;

		// Check the session references table to see if
		// this is a valid session
		CCommand<CAccessor<CSessionRefSelector> > selector;
		hr = selector.Assign(szID);
		if (hr != S_OK)
			return hr;

		// Run the select for this session ID; the session is considered
		// valid only if the resulting rowset has a first row.
		hr = selector.Open(conn,
							m_QueryObj.GetSessionRefSelect(),
							NULL,
							NULL,
							DBGUID_DEFAULT,
							true);
		if (hr == S_OK)
			return selector.MoveFirst();
		return hr;
	}

	CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names
	unsigned __int64 m_dwTimeout;
}; // CDBSessionServiceImplT

// Database persisted session service implementation using the default
// template argument of CDBSessionServiceImplT (CDBSession<>).
typedef CDBSessionServiceImplT<> CDBSessionServiceImpl;





//////////////////////////////////////////////////////////////////
//
// In-memory persisted session
//
//////////////////////////////////////////////////////////////////

// The in-memory persisted session service keeps a pointer
// to the session object around in memory. The pointer is
// contained in a CComPtr, which is stored in a CAtlMap, so
// we have to have a CElementTraits class for that.
typedef CComPtr<ISession> SESSIONPTRTYPE;

// Element traits for session smart pointers: hashing uses the raw
// interface pointer value, equality uses CComPtr::IsEqualObject, and
// ordered comparison is not supported.
template<>
class CElementTraits<SESSIONPTRTYPE> :
	public CElementTraitsBase<SESSIONPTRTYPE>
{
public:
	// Hash on the address held by the smart pointer.
	static ULONG Hash( INARGTYPE obj ) throw()
	{
		const ULONG_PTR nAddress = (ULONG_PTR)obj.p;
		return (ULONG)nAddress;
	}

	// TRUE when both smart pointers refer to the same object.
	static BOOL CompareElements( OUTARGTYPE element1, OUTARGTYPE element2 ) throw()
	{
		if (element1.IsEqualObject(element2.p))
			return TRUE;
		return FALSE;
	}

	// Sessions have no ordering; asserts and reports "equal".
	static int CompareElementsOrdered( INARGTYPE , INARGTYPE ) throw()
	{
		ATLASSERT(0); // NOT IMPLEMENTED
		return 0;
	}
};


// CMemSession
// This session persistence class persists session variables in memory.
// Note that this type of persistence should only be used on single server
// web sites.
class CMemSession :
	public ISession,
	public CComObjectRootEx<CComGlobalsThreadModel>
{
public:
	BEGIN_COM_MAP(CMemSession)
		COM_INTERFACE_ENTRY(ISession)
	END_COM_MAP()

	// Start out with the default timeout so that IsExpired never reads
	// an uninitialized value when SetTimeout has not been called yet.
	CMemSession() throw(...) :
		m_dwTimeout(ATL_SESSION_TIMEOUT)
	{
	}

	// Copies the value of the variable szName into *pVal.
	// If no such variable exists, *pVal is left empty (VariantInit) and
	// the result of Access() is returned.
	STDMETHOD(GetVariable)(LPCSTR szName, VARIANT *pVal) throw()
	{
		if (!szName)
			return E_INVALIDARG;

		if (pVal)
			VariantInit(pVal);
		else
			return E_POINTER;

		HRESULT hr = Access();
		if (hr == S_OK)
		{
			CSLockType lock(m_cs, false);
			hr = lock.Lock();
			if (FAILED(hr))
				return hr;
			_ATLTRY
			{
				CComVariant val;
				if (m_Variables.Lookup(szName, val))
				{
					hr = VariantCopy(pVal, &val);
				}
			}
			_ATLCATCHALL()
			{
				hr = E_UNEXPECTED;
			}
		}
		return hr;
	}

	// Stores vNewVal under the name szName, replacing any previous value.
	STDMETHOD(SetVariable)(LPCSTR szName, VARIANT vNewVal) throw()
	{
		if (!szName)
			return E_INVALIDARG;

		HRESULT hr = Access();
		if (hr == S_OK)
		{
			CSLockType lock(m_cs, false);
			hr = lock.Lock();
			if (FAILED(hr))
				return hr;
			_ATLTRY
			{
				hr = m_Variables.SetAt(szName, vNewVal) ? S_OK : E_FAIL;
			}
			_ATLCATCHALL()
			{
				hr = E_UNEXPECTED;
			}
		}
		return hr;
	}

	// Removes the variable szName; E_FAIL if it was not present.
	STDMETHOD(RemoveVariable)(LPCSTR szName) throw()
	{
		if (!szName)
			return E_INVALIDARG;

		HRESULT hr = Access();
		if (hr == S_OK)
		{
			CSLockType lock(m_cs, false);
			hr = lock.Lock();
			if (FAILED(hr))
				return hr;
			_ATLTRY
			{
				hr = m_Variables.RemoveKey(szName) ? S_OK : E_FAIL;
			}
			_ATLCATCHALL()
			{
				hr = E_UNEXPECTED;
			}
		}
		return hr;
	}

	// Returns the number of variables stored in this session.
	STDMETHOD(GetCount)(long *pnCount) throw()
	{
		// Zero the out parameter first. This must not return here:
		// the actual count is filled in below.
		if (pnCount)
			*pnCount = 0;
		else
			return E_POINTER;

		HRESULT hr = Access();
		if (hr == S_OK)
		{
			CSLockType lock(m_cs, false);
			hr = lock.Lock();
			if (FAILED(hr))
				return hr;
			*pnCount = (long) m_Variables.GetCount();
		}
		return hr;
	}

	// Removes every variable from this session.
	STDMETHOD(RemoveAllVariables)() throw()
	{
		HRESULT hr = Access();
		if (hr == S_OK)
		{
			CSLockType lock(m_cs, false);
			hr = lock.Lock();
			if (FAILED(hr))
				return hr;
			m_Variables.RemoveAll();
		}

		return hr;
	}

	// Starts an enumeration of the session variables. *pPOS receives the
	// map's start position (NULL when there are no variables); the enum
	// handle is not used by this implementation and is set to NULL.
	STDMETHOD(BeginVariableEnum)(HSESSIONENUM *phEnumHandle,  POSITION *pPOS) throw()
	{
		if (phEnumHandle)
			*phEnumHandle = NULL;
		else
			return E_POINTER;

		if (pPOS)
			*pPOS = NULL;
		else
			return E_POINTER;

		HRESULT hr = Access();
		if (hr == S_OK)
		{
			CSLockType lock(m_cs, false);
			hr = lock.Lock();
			if (FAILED(hr))
				return hr;
			*pPOS = m_Variables.GetStartPosition();
		}
		return hr;
	}

	// Retrieves the variable at *pPOS and advances *pPOS.
	// szName/dwLen describe the caller's name buffer; it must be larger
	// than the name's length or E_OUTOFMEMORY is returned. *pPOS is only
	// advanced when the value was copied successfully.
	STDMETHOD(GetNextVariable)(HSESSIONENUM /*hEnum*/,
							   POSITION *pPOS, LPSTR szName,
							   DWORD dwLen, VARIANT *pVal) throw()
	{
		if (!szName)
			return E_INVALIDARG;

		if (pVal)
			VariantInit(pVal);
		else
			return E_POINTER;

		if (!pPOS)
			return E_POINTER;

		CComVariant val;
		POSITION pos = *pPOS;

		// A NULL position means the enumeration is finished (or was never
		// started); the map must not be asked for an element at NULL.
		if (pos == NULL)
			return E_INVALIDARG;

		HRESULT hr = Access();
		if (hr == S_OK)
		{
			CSLockType lock(m_cs, false);
			hr = lock.Lock();
			if (FAILED(hr))
				return hr;

			_ATLTRY
			{
				CStringA strName = m_Variables.GetKeyAt(pos);
				if (strName.GetLength())
				{
					if (dwLen > (DWORD)strName.GetLength())
						strcpy(szName, strName);
					else
						hr = E_OUTOFMEMORY;
				}
				if (hr == S_OK)
				{
					val = m_Variables.GetNextValue(pos);
					hr = VariantCopy(pVal, &val);
					if (hr == S_OK)
						*pPOS = pos;
				}

			}
			_ATLCATCHALL()
			{
				hr = E_UNEXPECTED;
			}
		}

		return hr;
	}

	// Nothing to clean up for an in-memory enumeration.
	STDMETHOD(CloseEnum)(HSESSIONENUM /*hEnumHandle*/) throw()
	{
		return S_OK;
	}

	// Returns S_OK when more than m_dwTimeout milliseconds have passed
	// since the last access (measured with one second granularity),
	// S_FALSE otherwise.
	STDMETHOD(IsExpired)() throw()
	{
		CTime tmNow = CTime::GetCurrentTime();
		CTimeSpan span = tmNow-m_tLastAccess;
		if ((unsigned __int64)((span.GetTotalSeconds()*1000)) > m_dwTimeout)
			return S_OK;
		return S_FALSE;
	}

	// Records the current time as the time of last access.
	HRESULT Access() throw()
	{
		// We lock here to protect against multiple threads
		// updating the same member concurrently.
		CSLockType lock(m_cs, false);
		HRESULT hr = lock.Lock();
		if (FAILED(hr))
			return hr;
		m_tLastAccess = CTime::GetCurrentTime();
		return S_OK;
	}

	// Sets the timeout (in milliseconds, see IsExpired) for this session.
	STDMETHOD(SetTimeout)(unsigned __int64 dwNewTimeout) throw()
	{
		// We lock here to protect against multiple threads
		// updating the same member concurrently
		CSLockType lock(m_cs, false);
		HRESULT hr = lock.Lock();
		if (FAILED(hr))
			return hr;
		m_dwTimeout = dwNewTimeout;
		return S_OK;
	}

	// For in-memory sessions locking only refreshes the last access
	// time; the result of Access() is intentionally not reported.
	HRESULT SessionLock() throw()
	{
		Access();
		return S_OK;
	}

	// Nothing to release for in-memory sessions.
	HRESULT SessionUnlock() throw()
	{
		return S_OK;
	}

protected:
	typedef CAtlMap<CStringA,
					CComVariant,
					CStringElementTraits<CStringA> > VarMapType;
	unsigned __int64 m_dwTimeout;	// timeout compared against elapsed milliseconds
	CTime m_tLastAccess;			// updated by Access()
	VarMapType m_Variables;			// variable name -> value
	CComAutoCriticalSection m_cs;	// guards the members above
	typedef CComCritSecLock<CComAutoCriticalSection> CSLockType;
}; // CMemSession


//
// CMemSessionServiceImpl
// Implements the service part of in-memory persisted session services.
//
class CMemSessionServiceImpl
{
public:
	typedef void* SERVICEIMPL_INITPARAM_TYPE;
	CMemSessionServiceImpl() throw()
	{
		m_dwTimeout = ATL_SESSION_TIMEOUT;
	}

	// Creates a new in-memory session, generates its name into szNewID
	// (buffer size passed in/out through pdwSize), stores it in the map
	// and returns it through ppSession.
	HRESULT CreateNewSession(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
	{
		HRESULT hr = E_FAIL;
		CComObject<CMemSession> *pNewSession = NULL;

		if (!szNewID)
			return E_INVALIDARG;

		if (!pdwSize)
			return E_POINTER;

		if (ppSession)
			*ppSession = NULL;
		else
			return E_POINTER;

		_ATLTRY
		{
			// Create new session
			CComObject<CMemSession>::CreateInstance(&pNewSession);
			if (pNewSession == NULL)
				return E_OUTOFMEMORY;

			// CreateInstance hands back an object with a reference count
			// of zero. Take a reference immediately so that every failure
			// path below destroys the object instead of leaking it.
			CComPtr<ISession> spSession(static_cast<ISession*>(pNewSession));

			// Initialize and add to list of CSessionData
			hr = m_SessionNameGenerator.GetNewSessionName(szNewID, pdwSize);

			if (SUCCEEDED(hr))
			{
				// Take the service lock before reading m_dwTimeout so that
				// a concurrent SetSessionTimeout cannot be missed.
				CSLockType lock(m_CritSec, false);
				hr = lock.Lock();
				if (FAILED(hr))
					return hr;
				pNewSession->SetTimeout(m_dwTimeout);
				pNewSession->Access();
				m_Sessions.SetAt(szNewID, spSession);
				*ppSession = spSession.Detach();
			}
		}
		_ATLCATCHALL()
		{
			hr = E_UNEXPECTED;
		}

		return hr;

	}

	// Looks up the session szID in the map and returns it through
	// ppSession. Returns E_FAIL when the session is not in the map.
	HRESULT GetSession(LPCSTR szID, ISession **ppSession) throw()
	{
		HRESULT hr = E_FAIL;
		SessMapType::CPair *pPair = NULL;

		if (ppSession)
			*ppSession = NULL;
		else
			return E_POINTER;

		if (!szID)
			return E_INVALIDARG;

		CSLockType lock(m_CritSec, false);
		hr = lock.Lock();
		if (FAILED(hr))
			return hr;

		// Lock() succeeded, so reset the result: "not found" must be
		// reported as a failure.
		hr = E_FAIL;
		_ATLTRY
		{
			pPair = m_Sessions.Lookup(szID);
			if (pPair) // the session exists and is in our local map of sessions
			{
				hr = pPair->m_value.QueryInterface(ppSession);
			}
		}
		_ATLCATCHALL()
		{
			return E_UNEXPECTED;
		}

		return hr;
	}

	// Removes the session szID from the map, releasing the service's
	// reference on it. E_FAIL if the session was not in the map.
	HRESULT CloseSession(LPCSTR szID) throw()
	{
		if (!szID)
			return E_INVALIDARG;

		HRESULT hr = E_FAIL;
		CSLockType lock(m_CritSec, false);
		hr = lock.Lock();
		if (FAILED(hr))
			return hr;
		_ATLTRY
		{
			hr = m_Sessions.RemoveKey(szID) ? S_OK : E_FAIL;
		}
		_ATLCATCHALL()
		{
			hr = E_UNEXPECTED;
		}
		return hr;
	}

	// Removes every session whose IsExpired() returns S_OK from the map.
	void SweepSessions() throw()
	{
		POSITION posRemove = NULL;
		const SessMapType::CPair *pPair = NULL;
		POSITION pos = NULL;

		CSLockType lock(m_CritSec, false);
		if (FAILED(lock.Lock()))
			return;
		pos = m_Sessions.GetStartPosition();
		while (pos)
		{
			// GetNext advances pos, so remember the current element's
			// position in case it has to be removed.
			posRemove = pos;
			pPair = m_Sessions.GetNext(pos);
			if (pPair)
			{

				if (pPair->m_value.p &&
					S_OK == pPair->m_value->IsExpired())
				{
					// remove our reference on the session
					m_Sessions.RemoveAtPos(posRemove);
				}
			}
		}
	}

	// Remembers nTimeout for future sessions and applies it to every
	// session currently in the map. Stops at, and returns, the first
	// result from a session that is not S_OK.
	HRESULT SetSessionTimeout(unsigned __int64 nTimeout) throw()
	{
		CComPtr<ISession> spSession;

		// Take the lock before touching m_dwTimeout or the map; reading
		// the start position without it races with concurrent changes.
		CSLockType lock(m_CritSec, false);
		HRESULT hr = lock.Lock();
		if (FAILED(hr))
			return hr;

		m_dwTimeout = nTimeout;
		POSITION pos = m_Sessions.GetStartPosition();

		while (pos)
		{
			SessMapType::CPair *pPair = const_cast<SessMapType::CPair*>(m_Sessions.GetNext(pos));
			if (pPair)
			{
				spSession = pPair->m_value;
				if (spSession)
				{
					// if we fail on any of the sets we will return the
					// error code immediately
					hr = spSession->SetTimeout(nTimeout);
					spSession.Release();
					if (hr != S_OK)
						break;
				}
			}
		}

		return hr;
	}

	// Returns the timeout that newly created sessions receive.
	HRESULT GetSessionTimeout(unsigned __int64* pnTimeout) throw()
	{
		if (pnTimeout)
			*pnTimeout = m_dwTimeout;
		else
			return E_POINTER;

		return S_OK;
	}

	// Returns the number of sessions currently held in the map.
	HRESULT GetSessionCount(DWORD *pnCount) throw()
	{
		if (pnCount)
			*pnCount = 0;
		else
			return E_POINTER;

		CSLockType lock(m_CritSec, false);
		HRESULT hr = lock.Lock();
		if (FAILED(hr))
			return hr;
		*pnCount = (DWORD)m_Sessions.GetCount();

		return S_OK;
	}

	// Drops the service's references on all sessions.
	void ReleaseAllSessions() throw()
	{
		CSLockType lock(m_CritSec, false);
		if (FAILED(lock.Lock()))
			return;
		m_Sessions.RemoveAll();
	}

	// Stores the initial timeout and initializes the critical section.
	// The init parameter and the service provider are not used.
	HRESULT Initialize(SERVICEIMPL_INITPARAM_TYPE,
					   IServiceProvider*,
					   unsigned __int64 dwNewTimeout) throw()
	{
		m_dwTimeout = dwNewTimeout;
		return m_CritSec.Init();
	}

	typedef CAtlMap<CStringA,
					SESSIONPTRTYPE,
					CStringElementTraits<CStringA>,
					CElementTraitsBase<SESSIONPTRTYPE> > SessMapType;

	SessMapType m_Sessions; // map for holding sessions in memory
	CComCriticalSection m_CritSec; // for synchronizing access to map
	typedef CComCritSecLock<CComCriticalSection> CSLockType;
	CSessionNameGenerator m_SessionNameGenerator; // Object for generating session names
	unsigned __int64 m_dwTimeout;
}; // CMemSessionServiceImpl



//
// CSessionStateService
// This class implements the session state service which can be
// exposed to request handlers.
//
// Template Parameters:
// CMonitorClass: Provides periodic sweeping services for the session service class.
// TServiceImplClass: The class that actually implements the methods of the
//                    ISessionStateService and ISessionStateControl interfaces.
template <class CMonitorClass, class TServiceImplClass >
class CSessionStateService :
	public ISessionStateService,
	public ISessionStateControl,
	public IWorkerThreadClient,
	public CComObjectRootEx<CComGlobalsThreadModel>
{
protected:
	CMonitorClass m_Monitor;
	HANDLE m_hTimer;	// sweep timer handle; NULL when no timer is registered
	CComPtr<IServiceProvider> m_spServiceProvider;
	TServiceImplClass m_SessionServiceImpl;
public:
	// Construction/Initialization
	CSessionStateService() throw() :
	  m_hTimer(NULL)
	  {

	  }
	// Shutdown() is expected to have removed the timer by now.
	~CSessionStateService() throw()
	{
		ATLASSERT(m_hTimer == NULL);
	}
	BEGIN_COM_MAP(CSessionStateService)
		COM_INTERFACE_ENTRY(ISessionStateService)
		COM_INTERFACE_ENTRY(ISessionStateControl)
	END_COM_MAP()

// ISessionStateService methods
// Each of these forwards directly to the service implementation class.
	STDMETHOD(CreateNewSession)(LPSTR szNewID, DWORD *pdwSize, ISession** ppSession) throw()
	{
		return m_SessionServiceImpl.CreateNewSession(szNewID, pdwSize, ppSession);
	}

	STDMETHOD(GetSession)(LPCSTR szID, ISession **ppSession) throw()
	{
		return m_SessionServiceImpl.GetSession(szID, ppSession);
	}

	STDMETHOD(CloseSession)(LPCSTR szSessionID) throw()
	{
		return m_SessionServiceImpl.CloseSession(szSessionID);
	}

	STDMETHOD(SetSessionTimeout)(unsigned __int64 nTimeout) throw()
	{
		return m_SessionServiceImpl.SetSessionTimeout(nTimeout);
	}

	STDMETHOD(GetSessionTimeout)(unsigned __int64 *pnTimeout) throw()
	{
		return m_SessionServiceImpl.GetSessionTimeout(pnTimeout);
	}

	STDMETHOD(GetSessionCount)(DWORD *pnSessionCount) throw()
	{
		return m_SessionServiceImpl.GetSessionCount(pnSessionCount);
	}

	void SweepSessions() throw()
	{
		m_SessionServiceImpl.SweepSessions();
	}

	void ReleaseAllSessions() throw()
	{
		m_SessionServiceImpl.ReleaseAllSessions();
	}

	// Initializes the service implementation without registering a
	// sweep timer. A non-NULL pServiceProvider is also remembered by
	// this object.
	// Note: SERVICEIMPL_INITPARAM_TYPE is a dependent name and therefore
	// needs 'typename'.
	HRESULT Initialize(
		IServiceProvider *pServiceProvider = NULL,
		unsigned __int64 dwTimeout = ATL_SESSION_TIMEOUT,
		typename TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw()
	{
		HRESULT hr = S_OK;
		if (pServiceProvider)
			m_spServiceProvider = pServiceProvider;

		hr = m_SessionServiceImpl.Initialize(pInitData, pServiceProvider, dwTimeout);

		return hr;
	}

	// Initializes the service implementation and, on success, registers
	// a timer with the monitor on the given worker thread so that
	// Execute (and with it SweepSessions) is called periodically.
	template <class ThreadTraits>
	HRESULT Initialize(
		CWorkerThread<ThreadTraits> *pWorker,
		IServiceProvider *pServiceProvider = NULL,
		unsigned __int64 dwTimeout = ATL_SESSION_TIMEOUT,
		typename TServiceImplClass::SERVICEIMPL_INITPARAM_TYPE pInitData = NULL) throw()
	{
		if (!pWorker)
			return E_INVALIDARG;

		HRESULT hr = Initialize(pServiceProvider, dwTimeout, pInitData);
		if (hr == S_OK)
		{
			hr = m_Monitor.Initialize(pWorker);
			if (hr == S_OK)
			{
				// sweep at the interval given by ATL_SESSION_SWEEPER_TIMEOUT
				hr = m_Monitor.AddTimer(ATL_SESSION_SWEEPER_TIMEOUT, this, 0, &m_hTimer);
			}
		}
		return hr;
	}

	// IWorkerThreadClient: called when the sweep timer fires.
	HRESULT Execute(DWORD_PTR /*dwParam*/, HANDLE /*hObject*/) throw()
	{
		SweepSessions();
		return S_OK;
	}

	// IWorkerThreadClient: closes the handle and forgets the timer.
	HRESULT CloseHandle(HANDLE hHandle) throw()
	{
		::CloseHandle(hHandle);
		m_hTimer = NULL;
		return S_OK;
	}

	// Removes the sweep timer (if any) from the monitor and releases
	// all sessions held by the service implementation.
	void Shutdown() throw()
	{
		if (m_hTimer)
		{
			m_Monitor.RemoveHandle(m_hTimer);
			m_hTimer = NULL;
		}
		ReleaseAllSessions();
	}
}; // CSessionStateService

} // namespace ATL

#pragma warning(pop)
#endif // __ATLSESSION_H__