|
|
//========= Copyright Valve Corporation, All rights reserved. ============//
//
// Purpose:
//
// $NoKeywords: $
//=============================================================================//
//#define PARANOID
#if defined( PARANOID )
#include <stdlib.h>
#include <crtdbg.h>
#endif
#include <winsock2.h>
#include <mswsock.h>
#include "tcpsocket.h"
#include "tier1/utllinkedlist.h"
#include <stdio.h>
#include "threadhelpers.h"
#include "tier0/dbg.h"
#error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead."
// Socket helpers declared extern here; their definitions live in another translation unit.
extern TIMEVAL SetupTimeVal( double flTimeout );
extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut );
extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut );

// Special values that travel in the int slot normally holding a packet's length.
#define SENTINEL_DISCONNECT			-1
#define SENTINEL_KEEPALIVE			-2

#define KEEPALIVE_INTERVAL_MS		3000	// keepalives are sent every N MS
#define KEEPALIVE_TIMEOUT_SECONDS	15.0	// connections timeout after this long

// Global switch for the receive timeout; see TCPSocket_EnableTimeout().
static bool g_bEnableTCPTimeout = true;
// One fully received packet. It is allocated with malloc() at
// sizeof( CRecvData ) - 1 + payload size, so m_Data runs past its declared length.
class CRecvData
{
public:
	int				m_Count;	// Payload size in bytes.
	unsigned char	m_Data[1];	// Variable-length payload.
};
// Creates an overlapped-capable TCP socket and binds it to *pAddr.
// Returns the bound socket, or INVALID_SOCKET if either socket creation or bind() fails.
SOCKET TCPBind( const CIPAddr *pAddr )
{
	SOCKET hSocket = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED );
	if ( hSocket == INVALID_SOCKET )
	{
		Assert( false );
		return INVALID_SOCKET;
	}

	sockaddr_in localAddr;
	IPAddrToSockAddr( pAddr, &localAddr );

	if ( bind( hSocket, (sockaddr*)&localAddr, sizeof( localAddr ) ) != 0 )
	{
		// Don't leak the handle when the bind fails.
		closesocket( hSocket );
		return INVALID_SOCKET;
	}

	return hSocket;
}
// ---------------------------------------------------------------------------------------- //
// TCP sockets.
// ---------------------------------------------------------------------------------------- //
// Tags stored in COverlappedPlus::m_OPType so the IO thread can tell a receive
// completion from a send completion.
enum
{
	OP_RECV = 111,
	OP_SEND			// 112
};
// We use this for all OVERLAPPED structures.
class COverlappedPlus : public WSAOVERLAPPED { public: COverlappedPlus() { memset( this, 0, sizeof( WSAOVERLAPPED ) ); }
int m_OPType; // One of the OP_ defines.
};
// One outgoing packet. It is allocated with malloc() at sizeof( SendBuf_t ) - 1 + payload size,
// so m_Data runs past its declared length.
struct SendBuf_t
{
	COverlappedPlus	m_Overlapped;	// Kept first: the IO thread casts the OVERLAPPED pointer back to SendBuf_t.
	int				m_Index;		// Index into m_SendBufs.
	int				m_DataLength;	// Number of valid bytes in m_Data.
	char			m_Data[1];		// Variable-length payload.
};
typedef SendBuf_t SendBuf_s;
// Registry for the keepalive thread, which calls SendKeepalive() on every registered socket.
// Don't call AddGlobalTCPSocket() until the socket is ready to have SendKeepalive() called on it.
class CTCPSocket;
void AddGlobalTCPSocket( CTCPSocket *pSocket );
void RemoveGlobalTCPSocket( CTCPSocket *pSocket );
// ------------------------------------------------------------------------------------------ //
// CTCPSocket implementation.
// ------------------------------------------------------------------------------------------ //
class CTCPSocket : public ITCPSocket { friend class CTCPListenSocket;
public:
CTCPSocket() { m_Socket = INVALID_SOCKET; m_bConnected = false; m_hIOCP = NULL; m_bShouldExitThreads = false; m_bConnectionLost = false; m_nSizeBytesReceived = 0;
m_pIncomingData = NULL;
memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) ); m_RecvOverlapped.m_OPType = OP_RECV;
m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL ); m_RecvStage = -1;
m_MainThreadID = GetCurrentThreadId(); } virtual ~CTCPSocket() { Term(); CloseHandle( m_hRecvSignal ); }
void Term() { Assert( GetCurrentThreadId() == m_MainThreadID );
RemoveGlobalTCPSocket( this );
if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost ) { SendDisconnectSentinel(); // Give the sends a second to complete. SO_LINGER is having trouble for some reason.
WaitForSendsToComplete( 1 ); }
StopThreads();
if ( m_Socket != INVALID_SOCKET ) { closesocket( m_Socket ); m_Socket = INVALID_SOCKET; }
if ( m_hIOCP ) { CloseHandle( m_hIOCP ); m_hIOCP = NULL; }
m_bConnected = false; m_bConnectionLost = true; m_RecvStage = -1; FOR_EACH_LL( m_SendBufs, i ) { SendBuf_t *pSendBuf = m_SendBufs[i]; ParanoidMemoryCheck( pSendBuf ); free( pSendBuf ); } m_SendBufs.Purge();
FOR_EACH_LL( m_RecvDatas, j ) { CRecvData *pRecvData = m_RecvDatas[j]; ParanoidMemoryCheck( pRecvData ); free( pRecvData ); } m_RecvDatas.Purge();
if ( m_pIncomingData ) { ParanoidMemoryCheck( m_pIncomingData ); free( m_pIncomingData ); m_pIncomingData = 0; } }
virtual void Release() { delete this; }
void ParanoidMemoryCheck( void *ptr = NULL ) { #if defined( PARANOID )
Assert( _CrtIsValidHeapPointer( this ) );
if ( ptr ) { Assert( _CrtIsValidHeapPointer( ptr ) ); }
Assert( _CrtCheckMemory() == TRUE ); #endif
}
virtual bool BindToAny( const unsigned short port ) { Term();
CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
m_Socket = TCPBind( &addr ); if ( m_Socket == INVALID_SOCKET ) { return false; } else { SetInitialSocketOptions(); return true; } }
// Set the initial socket options that we want.
void SetInitialSocketOptions() { // Set nodelay to improve latency.
BOOL val = TRUE; setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) );
// Make it linger for 3 seconds when it exits.
LINGER linger; linger.l_onoff = 1; linger.l_linger = 3; setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) ); }
// Called only by main thread interface functions.
// Returns true if the connection is lost.
bool CheckConnectionLost() { Assert( GetCurrentThreadId() == m_MainThreadID );
if ( m_Socket == SOCKET_ERROR ) return true;
// Have we timed out?
if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) ) { SetConnectionLost( "Connection timed out." ); }
// Has any thread posted that the connection has been lost?
CCriticalSectionLock postLock( &m_ConnectionLostCS ); postLock.Lock(); if ( m_bConnectionLost ) { Term(); return true; } else { return false; } }
// Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost.
void SetConnectionLost( const char *pErrorString, int err = -1 ) { CCriticalSectionLock postLock( &m_ConnectionLostCS ); postLock.Lock(); m_bConnectionLost = true; postLock.Unlock();
// Handle it right away if we're in the main thread. If we're in an IO thread,
// it has to wait until the next interface function calls CheckConnectionLost().
if ( GetCurrentThreadId() == m_MainThreadID ) { Term(); } if ( pErrorString ) { m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 ); } else { char *lpMsgBuf; FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
(LPTSTR) &lpMsgBuf, 0, NULL );
m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 ); LocalFree( lpMsgBuf ); } }
// -------------------------------------------------------------------------------------------------- //
// The receive code.
// -------------------------------------------------------------------------------------------------- //
virtual bool StartWaitingForSize( bool bFresh ) { Assert( m_Socket != INVALID_SOCKET ); Assert( m_bConnected );
m_RecvStage = 0; m_RecvDataSize = -1; if ( bFresh ) m_nSizeBytesReceived = 0;
DWORD dwNumBytesReceived = 0; WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived }; DWORD dwFlags = 0;
int status = WSARecv( m_Socket, &buf, 1, &dwNumBytesReceived, &dwFlags, &m_RecvOverlapped, NULL );
int err = -1; if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) { SetConnectionLost( NULL, err ); return false; } else { return true; } }
bool PostNextDataPart() { DWORD dwNumBytesReceived = 0; WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived }; DWORD dwFlags = 0;
int status = WSARecv( m_Socket, &buf, 1, &dwNumBytesReceived, &dwFlags, &m_RecvOverlapped, NULL );
int err = -1; if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) { SetConnectionLost( NULL, err ); return false; } else { return true; } }
bool StartWaitingForData() { Assert( m_Socket != INVALID_SOCKET ); Assert( m_RecvStage == 0 ); Assert( m_bConnected ); Assert( m_RecvDataSize > 0 );
m_RecvStage = 1;
// Add a CRecvData element.
ParanoidMemoryCheck(); m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize ); if ( !m_pIncomingData ) { char str[512]; _snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize ); SetConnectionLost( str ); return false; }
m_pIncomingData->m_Count = m_RecvDataSize;
m_AmountReceived = 0;
return PostNextDataPart(); }
virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout ) { if ( CheckConnectionLost() ) return false;
// Wait in 50ms chunks, checking for disconnections along the way.
bool bGotData = false; DWORD msToWait = (DWORD)( flTimeout * 1000.0 ); do { DWORD curWaitTime = min( msToWait, 50 ); DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime ); if ( ret == WAIT_OBJECT_0 ) { bGotData = true; break; }
// Did the connection timeout?
if ( CheckConnectionLost() ) return false;
msToWait -= curWaitTime; } while ( msToWait ); // If we never got a WAIT_OBJECT_0, then we never received anything.
if ( !bGotData ) return false; CCriticalSectionLock csLock( &m_RecvDataCS ); csLock.Lock();
// Pickup the head m_RecvDatas element.
CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ]; data.CopyArray( pRecvData->m_Data, pRecvData->m_Count );
// Now free it.
m_RecvDatas.Remove( m_RecvDatas.Head() ); ParanoidMemoryCheck( pRecvData ); free( pRecvData );
// Set the event again for the next time around, if there is more data waiting.
if ( m_RecvDatas.Count() > 0 ) SetEvent( m_hRecvSignal );
return true; }
// INSIDE IO THREAD.
void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) { if ( dwNumBytes == 0 ) { SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" ); return; }
m_LastRecvTime = Plat_FloatTime(); if ( m_RecvStage == 0 ) { m_nSizeBytesReceived += dwNumBytes; if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) ) { // Size of -1 means the other size is breaking the connection.
if ( m_RecvDataSize == SENTINEL_DISCONNECT ) { SetConnectionLost( "Got a graceful disconnect message." ); return; } else if ( m_RecvDataSize == SENTINEL_KEEPALIVE ) { // No data follows this. Just let m_LastRecvTime get updated.
StartWaitingForSize( true ); return; }
StartWaitingForData(); } else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) ) { // Handle the case where we only got some of the data (maybe one of the clients got disconnected).
StartWaitingForSize( false ); } else { // This case should never ever happen!
#if defined( _DEBUG )
__asm int 3; #endif
SetConnectionLost( "Received too much data in a packet!" ); return; } } else if ( m_RecvStage == 1 ) { // Got the data, make sure we got it all.
m_AmountReceived += dwNumBytes;
// Sanity check.
#if defined( _DEBUG )
Assert( m_RecvDataSize == m_pIncomingData->m_Count ); Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads.
#endif
if ( m_AmountReceived == m_RecvDataSize ) { m_RecvStage = 2; // Add the data to the list of packets waiting to be picked up.
CCriticalSectionLock csLock( &m_RecvDataCS ); csLock.Lock(); m_RecvDatas.AddToTail( m_pIncomingData ); m_pIncomingData = NULL;
if ( m_RecvDatas.Count() == 1 ) SetEvent( m_hRecvSignal ); // Notify the Recv() function.
StartWaitingForSize( true ); } else { PostNextDataPart(); } } else { Assert( false ); } }
// -------------------------------------------------------------------------------------------------- //
// The send code.
// -------------------------------------------------------------------------------------------------- //
virtual void WaitForSendsToComplete( double flTimeout ) { CWaitTimer waitTimer( flTimeout ); while ( 1 ) { CCriticalSectionLock sendBufLock( &m_SendCS ); sendBufLock.Lock(); if( m_SendBufs.Count() == 0 ) return; sendBufLock.Unlock();
if ( waitTimer.ShouldKeepWaiting() ) Sleep( 10 ); else break; } }
// This is called in the keepalive thread.
void SendKeepalive() { // Send a message saying we're exiting.
ParanoidMemoryCheck(); SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); if ( !pBuf ) { SetConnectionLost( "malloc() in SendKeepalive() failed." ); return; }
pBuf->m_DataLength = sizeof( int ); *((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE; InternalSendDataBuf( pBuf ); }
void SendDisconnectSentinel() { // Send a message saying we're exiting.
ParanoidMemoryCheck(); SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); if ( pBuf ) { pBuf->m_DataLength = sizeof( int ); *((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting.
InternalSendDataBuf( pBuf ); } }
virtual bool Send( const void *pData, int len ) { const void *pChunks[1] = { pData }; int chunkLengths[1] = { len }; return SendChunks( pChunks, chunkLengths, 1 ); }
virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks ) { if ( CheckConnectionLost() ) return false; CChunkWalker walker( pChunks, pChunkLengths, nChunks ); int totalLength = walker.GetTotalLength();
if ( !totalLength ) return true;
// Create a buffer to hold the data and copy the data in.
ParanoidMemoryCheck(); SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) ); if ( !pBuf ) { char str[512]; _snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength ); SetConnectionLost( str ); return false; }
pBuf->m_DataLength = totalLength + sizeof( int );
int *pByteCountPos = (int*)pBuf->m_Data; *pByteCountPos = totalLength;
char *pDataPos = &pBuf->m_Data[ sizeof( int ) ]; walker.CopyTo( pDataPos, totalLength );
int status = InternalSendDataBuf( pBuf ); int err = -1; if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) { SetConnectionLost( NULL, err ); return false; } else { return true; } }
int InternalSendDataBuf( SendBuf_t *pBuf ) { // Protect against interference from the keepalive thread.
CCriticalSectionLock csLock( &m_SendCS ); csLock.Lock();
pBuf->m_Overlapped.m_OPType = OP_SEND; pBuf->m_Overlapped.hEvent = NULL;
// Add it to our list of buffers.
pBuf->m_Index = m_SendBufs.AddToTail( pBuf );
// Tell Winsock to send it.
WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data }; DWORD dwNumBytesSent = 0; return WSASend( m_Socket, &buf, 1, &dwNumBytesSent, 0, &pBuf->m_Overlapped, NULL ); }
// INSIDE IO THREAD.
void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) { if ( dwNumBytes == 0 ) { SetConnectionLost( "0 bytes in HandleSendCompletion." ); return; }
// Just free the buffer.
SendBuf_t *pBuf = (SendBuf_t*)pInfo; Assert( dwNumBytes == (DWORD)pBuf->m_DataLength );
CCriticalSectionLock sendBufLock( &m_SendCS ); sendBufLock.Lock(); m_SendBufs.Remove( pBuf->m_Index ); sendBufLock.Unlock();
ParanoidMemoryCheck( pBuf ); free( pBuf ); }
// -------------------------------------------------------------------------------------------------- //
// The connect code.
// -------------------------------------------------------------------------------------------------- //
virtual bool BeginConnect( const CIPAddr &inputAddr ) { sockaddr_in addr; IPAddrToSockAddr( &inputAddr, &addr );
m_bConnected = false; int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) ); ret=ret;
return true; }
virtual bool UpdateConnect() { // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
fd_set writeSet; writeSet.fd_count = 1; writeSet.fd_array[0] = m_Socket; TIMEVAL timeVal = SetupTimeVal( 0 );
// See if it has a packet waiting.
int status = select( 0, NULL, &writeSet, NULL, &timeVal ); if ( status > 0 ) { SetupConnected(); return true; }
return false; }
void SetupConnected() { m_bConnected = true; m_bConnectionLost = false; m_LastRecvTime = Plat_FloatTime();
CreateThreads(); StartWaitingForSize( true ); AddGlobalTCPSocket( this ); }
virtual bool IsConnected() { CheckConnectionLost(); return m_bConnected; }
virtual void GetDisconnectReason( CUtlVector<char> &reason ) { reason = m_ErrorString; }
// -------------------------------------------------------------------------------------------------- //
// Threads code.
// -------------------------------------------------------------------------------------------------- //
// Create our IO Completion Port threads.
bool CreateThreads() { int nThreads = 1; SetShouldExitThreads( false );
// Create our IO completion port and hook it to our socket.
m_hIOCP = CreateIoCompletionPort( INVALID_HANDLE_VALUE, NULL, 0, 0);
m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads );
for ( int i=0; i < nThreads; i++ ) { DWORD dwThreadID = 0; HANDLE hThread = CreateThread( NULL, 0, &CTCPSocket::StaticThreadFn, this, 0, &dwThreadID );
if ( hThread ) { SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL ); m_Threads.AddToTail( hThread ); } else { StopThreads(); return false; } } return true; }
void StopThreads() { // Tell the threads to exit, then wait for them to do so.
SetShouldExitThreads( true ); WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE );
for ( int i=0; i < m_Threads.Count(); i++ ) { CloseHandle( m_Threads[i] ); } m_Threads.Purge(); }
void SetShouldExitThreads( bool bShouldExit ) { CCriticalSectionLock lock( &m_ThreadsCS ); lock.Lock(); m_bShouldExitThreads = bShouldExit; }
bool ShouldExitThreads() { CCriticalSectionLock lock( &m_ThreadsCS ); lock.Lock();
bool bRet = m_bShouldExitThreads; return bRet; }
DWORD ThreadFn() { while ( 1 ) { DWORD dwNumBytes = 0; unsigned long pInputTCPSocket; LPOVERLAPPED pOverlapped;
if ( GetQueuedCompletionStatus( m_hIOCP, // the port we're listening on
&dwNumBytes, // # bytes received on the port
&pInputTCPSocket,// "completion key" = CTCPSocket*
&pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend.
100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted.
) ) { COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped; ParanoidMemoryCheck( pInfo ); if ( pInfo->m_OPType == OP_RECV ) { Assert( pInfo == &m_RecvOverlapped ); HandleRecvCompletion( pInfo, dwNumBytes ); } else { Assert( pInfo->m_OPType == OP_SEND ); HandleSendCompletion( pInfo, dwNumBytes ); } } if ( ShouldExitThreads() ) break; }
return 0; }
static DWORD WINAPI StaticThreadFn( LPVOID pParameter ) { return ((CTCPSocket*)pParameter)->ThreadFn(); }
private:
SOCKET m_Socket; bool m_bConnected;
// m_RecvOverlapped is setup to first wait for the size, then the data.
// Then it is not posted until the app grabs the data.
HANDLE m_hRecvSignal; // Tells Recv() when we have data.
COverlappedPlus m_RecvOverlapped; int m_RecvStage; // -1 = not initialized
// 0 = waiting for size
// 1 = waiting for data
// 2 = waiting for app to pickup the data
CUtlLinkedList<CRecvData*,int> m_RecvDatas; // The head element is the next one to be picked up.
CRecvData *m_pIncomingData; // The packet we're currently receiving.
CCriticalSection m_RecvDataCS; // This protects adds and removes in the list.
// These reference the element at the tail of m_RecvData. It is the current one getting
volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet?
int m_RecvDataSize; // this is received over the network
int m_AmountReceived; // How much we've received so far.
// Last time we received anything from this connection. Used to determine if the connection is
// still active.
double m_LastRecvTime;
// Outgoing send buffers.
CUtlLinkedList<SendBuf_t*,int> m_SendBufs; CCriticalSection m_SendCS;
// All the threads waiting for IO.
CUtlVector<HANDLE> m_Threads; HANDLE m_hIOCP;
// Used during shutdown.
volatile bool m_bShouldExitThreads; CCriticalSection m_ThreadsCS;
// For debugging.
DWORD m_MainThreadID;
// Set by the main thread or IO threads to signal connection lost.
bool m_bConnectionLost; CCriticalSection m_ConnectionLostCS;
// This is set when we get disconnected.
CUtlVector<char> m_ErrorString; };
// ------------------------------------------------------------------------------------------ //
// ITCPListenSocket implementation.
// ------------------------------------------------------------------------------------------ //
class CTCPListenSocket : public ITCPListenSocket
{
public:
	CTCPListenSocket()
	{
		m_Socket = INVALID_SOCKET;
	}

	virtual ~CTCPListenSocket()
	{
		if ( m_Socket != INVALID_SOCKET )
			closesocket( m_Socket );
	}

	// The main function to create one of these suckers.
	// Binds INADDR_ANY:port and listens. nQueueLength == -1 means SOMAXCONN.
	// Returns NULL if the socket can't be bound or put into the listening state.
	static ITCPListenSocket* Create( const unsigned short port, int nQueueLength )
	{
		CTCPListenSocket *pRet = new CTCPListenSocket;
		if ( !pRet )
			return NULL;

		// Bind it to a socket and start listening.
		CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
		pRet->m_Socket = TCPBind( &addr );

		int backlog = ( nQueueLength == -1 ) ? SOMAXCONN : nQueueLength;
		if ( pRet->m_Socket == INVALID_SOCKET || listen( pRet->m_Socket, backlog ) != 0 )
		{
			// The destructor closes the socket if it was created.
			pRet->Release();
			return NULL;
		}

		return pRet;
	}

	virtual void Release()
	{
		delete this;
	}

	// Non-blocking check for an incoming connection. Returns a new, fully set up CTCPSocket
	// (and fills *pAddr with the peer's address if pAddr is non-NULL), or NULL if nobody is waiting.
	virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
	{
		// Poll with a zero timeout. A listening socket becomes readable when a connection can be accepted.
		fd_set readSet;
		FD_ZERO( &readSet );
		FD_SET( m_Socket, &readSet );
		TIMEVAL timeVal = SetupTimeVal( 0 );

		int status = select( 0, &readSet, NULL, NULL, &timeVal );
		if ( status <= 0 )
			return NULL;

		// Now accept the final connection.
		sockaddr_in addr;
		int addrSize = sizeof( addr );
		SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize );
		if ( newSock == INVALID_SOCKET )
		{
			Assert( false );
			return NULL;
		}

		CTCPSocket *pRet = new CTCPSocket;
		if ( !pRet )
		{
			closesocket( newSock );
			return NULL;
		}

		pRet->m_Socket = newSock;
		pRet->SetInitialSocketOptions();
		pRet->SetupConnected();

		// Report the address..
		if ( pAddr )
			SockAddrToIPAddr( &addr, pAddr );

		return pRet;
	}

private:
	SOCKET m_Socket;
};
// Public factory. Returns a socket listening on INADDR_ANY:port, or NULL if it couldn't be set up.
ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength )
{
	ITCPListenSocket *pListener = CTCPListenSocket::Create( port, nQueueLength );
	return pListener;
}
// Public factory for an unconnected socket. It has no OS socket until BindToAny() is called.
ITCPSocket* CreateTCPSocket()
{
	CTCPSocket *pSocket = new CTCPSocket;
	return pSocket;
}
// Turns the KEEPALIVE_TIMEOUT_SECONDS receive timeout on or off for every CTCPSocket.
void TCPSocket_EnableTimeout( bool bEnable )
{
	g_bEnableTCPTimeout = bEnable;
}
// --------------------------------------------------------------------------------- //
// This thread sends keepalives on all active TCP sockets.
// --------------------------------------------------------------------------------- //
// Keepalive thread state. AddGlobalTCPSocket() creates the handles and RemoveGlobalTCPSocket()
// closes them. g_TCPSocketsCS guards g_TCPSockets.
HANDLE g_hKeepaliveThread;
HANDLE g_hKeepaliveThreadSignal;	// Set to ask the keepalive thread to exit.
HANDLE g_hKeepaliveThreadReply;		// Set by the keepalive thread as it exits.
CUtlLinkedList<CTCPSocket*,int> g_TCPSockets;
CCriticalSection g_TCPSocketsCS;
// Keepalive thread body. Every KEEPALIVE_INTERVAL_MS it sends a keepalive on every registered
// socket, until g_hKeepaliveThreadSignal is set. It sets g_hKeepaliveThreadReply on the way out.
DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter )
{
	for ( ;; )
	{
		DWORD waitResult = WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS );
		if ( waitResult == WAIT_OBJECT_0 )
			break;

		// Hold the registry lock so sockets can't be removed mid-iteration.
		CCriticalSectionLock csLock( &g_TCPSocketsCS );
		csLock.Lock();

		FOR_EACH_LL( g_TCPSockets, i )
		{
			g_TCPSockets[i]->SendKeepalive();
		}
	}

	SetEvent( g_hKeepaliveThreadReply );
	return 0;
}
// Registers pSocket for periodic keepalives. The first registration starts the keepalive thread.
void AddGlobalTCPSocket( CTCPSocket *pSocket )
{
	CCriticalSectionLock csLock( &g_TCPSocketsCS );
	csLock.Lock();

	Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() );
	g_TCPSockets.AddToTail( pSocket );

	// The keepalive thread only exists while at least one socket is registered.
	bool bFirstSocket = ( g_TCPSockets.Count() == 1 );
	if ( !bFirstSocket )
		return;

	g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL );
	g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL );

	DWORD dwThreadID = 0;
	g_hKeepaliveThread = CreateThread( NULL, 0, TCPKeepaliveThread, NULL, 0, &dwThreadID );
}
// Unregisters pSocket from the keepalive thread. Removing the last socket stops the thread
// and closes its handles. Sockets that were never registered are ignored.
void RemoveGlobalTCPSocket( CTCPSocket *pSocket )
{
	CCriticalSectionLock csLock( &g_TCPSocketsCS );
	csLock.Lock();

	int index = g_TCPSockets.Find( pSocket );
	if ( index == g_TCPSockets.InvalidIndex() )
	{
		csLock.Unlock();
		return;
	}

	g_TCPSockets.Remove( index );
	if ( g_TCPSockets.Count() != 0 )
	{
		csLock.Unlock();
		return;
	}

	// This was the last one, so delete the thread. Release the lock first because the thread
	// takes it each time it wakes up to send keepalives.
	// NOTE(review): a concurrent AddGlobalTCPSocket() from another thread between here and the
	// handle cleanup below would race on the globals. Callers appear to be single-threaded.
	csLock.Unlock();

	// Only query the handle here, where it is known to belong to the live keepalive thread.
	DWORD dwExitCode = 0;
	if ( g_hKeepaliveThread && GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) && dwExitCode == STILL_ACTIVE )
	{
		SetEvent( g_hKeepaliveThreadSignal );
		WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE );
	}

	// Null the globals after closing so later calls never touch a closed (possibly recycled) handle.
	if ( g_hKeepaliveThreadSignal )
	{
		CloseHandle( g_hKeepaliveThreadSignal );
		g_hKeepaliveThreadSignal = NULL;
	}
	if ( g_hKeepaliveThreadReply )
	{
		CloseHandle( g_hKeepaliveThreadReply );
		g_hKeepaliveThreadReply = NULL;
	}
	if ( g_hKeepaliveThread )
	{
		CloseHandle( g_hKeepaliveThread );
		g_hKeepaliveThread = NULL;
	}
}
|