|
|
//========= Copyright Valve Corporation, All rights reserved. ============//
//
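// Purpose: Emulates the polling-style ITCPSocket / ITCPListenSocket interfaces
//          on top of the IThreadedTCPSocket interface.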
//
// $NoKeywords: $
//=============================================================================//
#include <windows.h>
#include "tcpsocket.h"
#include "IThreadedTCPSocket.h"
#include "ThreadedTCPSocketEmu.h"
#include "ThreadHelpers.h"
// ---------------------------------------------------------------------------------------- //
// CThreadedTCPSocketEmu. This uses IThreadedTCPSocket to emulate the polling-type interface
// in ITCPSocket.
// ---------------------------------------------------------------------------------------- //
// This class uses the IThreadedTCPSocket interface to emulate the old ITCPSocket.
class CThreadedTCPSocketEmu : public ITCPSocket, public ITCPSocketHandler, public IHandlerCreator { public: CThreadedTCPSocketEmu() { m_pSocket = NULL; m_LocalPort = 0xFFFF; m_pConnectSocket = NULL; m_RecvPacketsEvent.Init( false, false ); m_bError = false; }
virtual ~CThreadedTCPSocketEmu() { Term(); }
void Init( IThreadedTCPSocket *pSocket ) { m_pSocket = pSocket; }
void Term() { if ( m_pSocket ) { m_pSocket->Release(); m_pSocket = NULL; } if ( m_pConnectSocket ) { m_pConnectSocket->Release(); m_pConnectSocket = NULL; } }
// ITCPSocketHandler implementation.
private:
virtual void OnPacketReceived( CTCPPacket *pPacket ) { CCriticalSectionLock csLock( &m_RecvPacketsCS ); csLock.Lock();
m_RecvPackets.AddToTail( pPacket ); m_RecvPacketsEvent.SetEvent(); }
virtual void OnError( int errorCode, const char *pErrorString ) { CCriticalSectionLock csLock( &m_ErrorStringCS ); csLock.Lock();
m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 ); m_bError = true; }
// IHandlerCreator implementation.
public: // This is used for connecting.
virtual ITCPSocketHandler* CreateNewHandler() { return this; }
// ITCPSocket implementation.
public:
virtual void Release() { delete this; }
virtual bool BindToAny( const unsigned short port ) { m_LocalPort = port; return true; }
virtual bool BeginConnect( const CIPAddr &addr ) { // They should have "bound" to a port before trying to connect.
Assert( m_LocalPort != 0xFFFF ); if ( m_pConnectSocket ) m_pConnectSocket->Release();
m_pConnectSocket = ThreadedTCP_CreateConnector( addr, CIPAddr( 0, 0, 0, 0, m_LocalPort ), this );
return m_pConnectSocket != 0; }
virtual bool UpdateConnect() { Assert( !m_pSocket ); if ( !m_pConnectSocket ) return false;
if ( m_pConnectSocket->Update( &m_pSocket ) ) { if ( m_pSocket ) { // Ok, we're connected now.
m_pConnectSocket->Release(); m_pConnectSocket = NULL; return true; } else { return false; } } else { Assert( false ); m_pConnectSocket->Release(); m_pConnectSocket = NULL; return false; } }
virtual bool IsConnected() { if ( m_bError ) { Term(); return false; } else { return m_pSocket != NULL; } }
virtual void GetDisconnectReason( CUtlVector<char> &reason ) { CCriticalSectionLock csLock( &m_ErrorStringCS ); csLock.Lock();
reason = m_ErrorString; }
virtual bool Send( const void *pData, int size ) { Assert( m_pSocket ); if ( !m_pSocket ) return false;
return m_pSocket->Send( pData, size ); }
virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks ) { Assert( m_pSocket ); if ( !m_pSocket || !m_pSocket->IsValid() ) return false;
return m_pSocket->SendChunks( pChunks, pChunkLengths, nChunks ); }
virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout ) { // Use our m_RecvPacketsEvent event to determine if there is data to receive yet.
DWORD nMilliseconds = (DWORD)( flTimeout * 1000.0f ); DWORD ret = WaitForSingleObject( m_RecvPacketsEvent.GetEventHandle(), nMilliseconds ); if ( ret == WAIT_OBJECT_0 ) { // Ok, there's a packet.
CCriticalSectionLock csLock( &m_RecvPacketsCS ); csLock.Lock(); Assert( m_RecvPackets.Count() > 0 ); int iHead = m_RecvPackets.Head(); CTCPPacket *pPacket = m_RecvPackets[ iHead ]; data.CopyArray( (const unsigned char*)pPacket->GetData(), pPacket->GetLen() ); pPacket->Release(); m_RecvPackets.Remove( iHead ); // Re-set the event if there are more packets left to receive.
if ( m_RecvPackets.Count() > 0 ) { m_RecvPacketsEvent.SetEvent(); }
return true; } else { return false; } }
private: IThreadedTCPSocket *m_pSocket; unsigned short m_LocalPort; // The port we bind to when we want to connect.
ITCPConnectSocket *m_pConnectSocket;
// All the received data is stored in here.
CEvent m_RecvPacketsEvent; CCriticalSection m_RecvPacketsCS; CUtlLinkedList<CTCPPacket*, int> m_RecvPackets;
CCriticalSection m_ErrorStringCS; CUtlVector<char> m_ErrorString; bool m_bError; // Set to true when there's an error. Next chance we get in the main thread, we'll close the socket.
};
// Factory: allocates a fresh, unconnected CThreadedTCPSocketEmu and returns it
// through its ITCPSocket interface.
ITCPSocket* CreateTCPSocketEmu()
{
	CThreadedTCPSocketEmu *pEmu = new CThreadedTCPSocketEmu;
	return pEmu;
}
// ---------------------------------------------------------------------------------------- //
// CThreadedTCPListenSocketEmu implementation.
// ---------------------------------------------------------------------------------------- //
// Emulates the polling-style ITCPListenSocket on top of a threaded listener.
//
// The listener asks this object (as IHandlerCreator) for a handler; the handler it
// gets is a new CThreadedTCPSocketEmu, which UpdateListen() later pairs with the
// accepted IThreadedTCPSocket and returns to the caller.
class CThreadedTCPListenSocketEmu : public ITCPListenSocket, public IHandlerCreator
{
public:
	CThreadedTCPListenSocketEmu()
	{
		m_pListener = NULL;
		m_pLastCreatedSocket = NULL;
	}

	virtual ~CThreadedTCPListenSocketEmu()
	{
		if ( m_pListener )
			m_pListener->Release();
	}

	// Creates the underlying listener. Returns false if it could not be created.
	bool StartListening( const unsigned short port, int nQueueLength )
	{
		m_pListener = ThreadedTCP_CreateListener( this, port, nQueueLength );
		return m_pListener != 0;
	}

// ITCPListenSocket implementation.
private:
	virtual void Release()
	{
		delete this;
	}

	// Polls the listener for a newly accepted connection.
	// On success, returns the emulated socket for it and (if pAddr is non-null) writes
	// the remote address to *pAddr. Returns NULL when nothing was accepted.
	virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
	{
		if ( !m_pListener )
			return NULL;

		// Start from NULL so the check below never reads an indeterminate pointer.
		IThreadedTCPSocket *pSocket = NULL;
		if ( !m_pListener->Update( &pSocket ) || !pSocket )
			return NULL;

		// This is pretty hacky, but this stuff is just around for test code.
		// It relies on CreateNewHandler() having been called for this connection.
		CThreadedTCPSocketEmu *pLast = m_pLastCreatedSocket;
		m_pLastCreatedSocket = NULL;
		if ( !pLast )
		{
			// No handler is waiting to be paired with this socket; drop the connection
			// rather than dereferencing a null pointer.
			pSocket->Release();
			return NULL;
		}

		if ( pAddr )
			*pAddr = pSocket->GetRemoteAddr();

		pLast->Init( pSocket );
		return pLast;
	}

// IHandlerCreator implementation.
private:
	// Creates the emulated socket that will receive the new connection's callbacks and
	// remembers it so UpdateListen() can hand it out.
	virtual ITCPSocketHandler* CreateNewHandler()
	{
		m_pLastCreatedSocket = new CThreadedTCPSocketEmu;
		return m_pLastCreatedSocket;
	}

private:
	ITCPConnectSocket *m_pListener;
	CThreadedTCPSocketEmu *m_pLastCreatedSocket;	// Set by CreateNewHandler(), consumed by UpdateListen().
};
// Factory: creates a listen-socket emulator and starts it listening on the given
// port. Returns NULL (after destroying the emulator) if listening could not start.
ITCPListenSocket* CreateTCPListenSocketEmu( const unsigned short port, int nQueueLength )
{
	CThreadedTCPListenSocketEmu *pListenSocket = new CThreadedTCPListenSocketEmu;

	if ( !pListenSocket->StartListening( port, nQueueLength ) )
	{
		delete pListenSocket;
		return NULL;
	}

	return pListenSocket;
}
|