Team Fortress 2 Source Code as on 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1178 lines
26 KiB

  1. //========= Copyright Valve Corporation, All rights reserved. ============//
  2. //
  3. // Purpose:
  4. //
  5. // $NoKeywords: $
  6. //=============================================================================//
  7. //#define PARANOID
  8. #if defined( PARANOID )
  9. #include <stdlib.h>
  10. #include <crtdbg.h>
  11. #endif
  12. #include <winsock2.h>
  13. #include <mswsock.h>
  14. #include "tcpsocket.h"
  15. #include "tier1/utllinkedlist.h"
  16. #include <stdio.h>
  17. #include "threadhelpers.h"
  18. #include "tier0/dbg.h"
  19. #error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead."
  20. extern TIMEVAL SetupTimeVal( double flTimeout );
  21. extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut );
  22. extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut );
  23. #define SENTINEL_DISCONNECT -1
  24. #define SENTINEL_KEEPALIVE -2
  25. #define KEEPALIVE_INTERVAL_MS 3000 // keepalives are sent every N MS
  26. #define KEEPALIVE_TIMEOUT_SECONDS 15.0 // connections timeout after this long
  27. static bool g_bEnableTCPTimeout = true;
  28. class CRecvData
  29. {
  30. public:
  31. int m_Count;
  32. unsigned char m_Data[1];
  33. };
  34. SOCKET TCPBind( const CIPAddr *pAddr )
  35. {
  36. // Create a socket to send and receive through.
  37. SOCKET sock = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED );
  38. if ( sock == INVALID_SOCKET )
  39. {
  40. Assert( false );
  41. return INVALID_SOCKET;
  42. }
  43. // bind to it!
  44. sockaddr_in addr;
  45. IPAddrToSockAddr( pAddr, &addr );
  46. int status = bind( sock, (sockaddr*)&addr, sizeof(addr) );
  47. if ( status == 0 )
  48. {
  49. return sock;
  50. }
  51. else
  52. {
  53. closesocket( sock );
  54. return INVALID_SOCKET;
  55. }
  56. }
  57. // ---------------------------------------------------------------------------------------- //
  58. // TCP sockets.
  59. // ---------------------------------------------------------------------------------------- //
  60. enum
  61. {
  62. OP_RECV=111,
  63. OP_SEND
  64. };
  65. // We use this for all OVERLAPPED structures.
  66. class COverlappedPlus : public WSAOVERLAPPED
  67. {
  68. public:
  69. COverlappedPlus()
  70. {
  71. memset( this, 0, sizeof( WSAOVERLAPPED ) );
  72. }
  73. int m_OPType; // One of the OP_ defines.
  74. };
  75. typedef struct SendBuf_t
  76. {
  77. COverlappedPlus m_Overlapped;
  78. int m_Index; // Index into m_SendBufs.
  79. int m_DataLength;
  80. char m_Data[1];
  81. } SendBuf_s;
  82. // These manage a thread that calls SendKeepalive() on all TCPSockets.
  83. // AddGlobalTCPSocket shouldn't be called until you're ready for SendKeepalive() to be called.
  84. class CTCPSocket;
  85. void AddGlobalTCPSocket( CTCPSocket *pSocket );
  86. void RemoveGlobalTCPSocket( CTCPSocket *pSocket );
  87. // ------------------------------------------------------------------------------------------ //
  88. // CTCPSocket implementation.
  89. // ------------------------------------------------------------------------------------------ //
  90. class CTCPSocket : public ITCPSocket
  91. {
  92. friend class CTCPListenSocket;
  93. public:
  94. CTCPSocket()
  95. {
  96. m_Socket = INVALID_SOCKET;
  97. m_bConnected = false;
  98. m_hIOCP = NULL;
  99. m_bShouldExitThreads = false;
  100. m_bConnectionLost = false;
  101. m_nSizeBytesReceived = 0;
  102. m_pIncomingData = NULL;
  103. memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) );
  104. m_RecvOverlapped.m_OPType = OP_RECV;
  105. m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL );
  106. m_RecvStage = -1;
  107. m_MainThreadID = GetCurrentThreadId();
  108. }
  109. virtual ~CTCPSocket()
  110. {
  111. Term();
  112. CloseHandle( m_hRecvSignal );
  113. }
  114. void Term()
  115. {
  116. Assert( GetCurrentThreadId() == m_MainThreadID );
  117. RemoveGlobalTCPSocket( this );
  118. if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost )
  119. {
  120. SendDisconnectSentinel();
  121. // Give the sends a second to complete. SO_LINGER is having trouble for some reason.
  122. WaitForSendsToComplete( 1 );
  123. }
  124. StopThreads();
  125. if ( m_Socket != INVALID_SOCKET )
  126. {
  127. closesocket( m_Socket );
  128. m_Socket = INVALID_SOCKET;
  129. }
  130. if ( m_hIOCP )
  131. {
  132. CloseHandle( m_hIOCP );
  133. m_hIOCP = NULL;
  134. }
  135. m_bConnected = false;
  136. m_bConnectionLost = true;
  137. m_RecvStage = -1;
  138. FOR_EACH_LL( m_SendBufs, i )
  139. {
  140. SendBuf_t *pSendBuf = m_SendBufs[i];
  141. ParanoidMemoryCheck( pSendBuf );
  142. free( pSendBuf );
  143. }
  144. m_SendBufs.Purge();
  145. FOR_EACH_LL( m_RecvDatas, j )
  146. {
  147. CRecvData *pRecvData = m_RecvDatas[j];
  148. ParanoidMemoryCheck( pRecvData );
  149. free( pRecvData );
  150. }
  151. m_RecvDatas.Purge();
  152. if ( m_pIncomingData )
  153. {
  154. ParanoidMemoryCheck( m_pIncomingData );
  155. free( m_pIncomingData );
  156. m_pIncomingData = 0;
  157. }
  158. }
  159. virtual void Release()
  160. {
  161. delete this;
  162. }
  163. void ParanoidMemoryCheck( void *ptr = NULL )
  164. {
  165. #if defined( PARANOID )
  166. Assert( _CrtIsValidHeapPointer( this ) );
  167. if ( ptr )
  168. {
  169. Assert( _CrtIsValidHeapPointer( ptr ) );
  170. }
  171. Assert( _CrtCheckMemory() == TRUE );
  172. #endif
  173. }
  174. virtual bool BindToAny( const unsigned short port )
  175. {
  176. Term();
  177. CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
  178. m_Socket = TCPBind( &addr );
  179. if ( m_Socket == INVALID_SOCKET )
  180. {
  181. return false;
  182. }
  183. else
  184. {
  185. SetInitialSocketOptions();
  186. return true;
  187. }
  188. }
  189. // Set the initial socket options that we want.
  190. void SetInitialSocketOptions()
  191. {
  192. // Set nodelay to improve latency.
  193. BOOL val = TRUE;
  194. setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) );
  195. // Make it linger for 3 seconds when it exits.
  196. LINGER linger;
  197. linger.l_onoff = 1;
  198. linger.l_linger = 3;
  199. setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) );
  200. }
  201. // Called only by main thread interface functions.
  202. // Returns true if the connection is lost.
  203. bool CheckConnectionLost()
  204. {
  205. Assert( GetCurrentThreadId() == m_MainThreadID );
  206. if ( m_Socket == SOCKET_ERROR )
  207. return true;
  208. // Have we timed out?
  209. if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) )
  210. {
  211. SetConnectionLost( "Connection timed out." );
  212. }
  213. // Has any thread posted that the connection has been lost?
  214. CCriticalSectionLock postLock( &m_ConnectionLostCS );
  215. postLock.Lock();
  216. if ( m_bConnectionLost )
  217. {
  218. Term();
  219. return true;
  220. }
  221. else
  222. {
  223. return false;
  224. }
  225. }
  226. // Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost.
  227. void SetConnectionLost( const char *pErrorString, int err = -1 )
  228. {
  229. CCriticalSectionLock postLock( &m_ConnectionLostCS );
  230. postLock.Lock();
  231. m_bConnectionLost = true;
  232. postLock.Unlock();
  233. // Handle it right away if we're in the main thread. If we're in an IO thread,
  234. // it has to wait until the next interface function calls CheckConnectionLost().
  235. if ( GetCurrentThreadId() == m_MainThreadID )
  236. {
  237. Term();
  238. }
  239. if ( pErrorString )
  240. {
  241. m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 );
  242. }
  243. else
  244. {
  245. char *lpMsgBuf;
  246. FormatMessage(
  247. FORMAT_MESSAGE_ALLOCATE_BUFFER |
  248. FORMAT_MESSAGE_FROM_SYSTEM |
  249. FORMAT_MESSAGE_IGNORE_INSERTS,
  250. NULL,
  251. err,
  252. MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
  253. (LPTSTR) &lpMsgBuf,
  254. 0,
  255. NULL
  256. );
  257. m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 );
  258. LocalFree( lpMsgBuf );
  259. }
  260. }
  261. // -------------------------------------------------------------------------------------------------- //
  262. // The receive code.
  263. // -------------------------------------------------------------------------------------------------- //
  264. virtual bool StartWaitingForSize( bool bFresh )
  265. {
  266. Assert( m_Socket != INVALID_SOCKET );
  267. Assert( m_bConnected );
  268. m_RecvStage = 0;
  269. m_RecvDataSize = -1;
  270. if ( bFresh )
  271. m_nSizeBytesReceived = 0;
  272. DWORD dwNumBytesReceived = 0;
  273. WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived };
  274. DWORD dwFlags = 0;
  275. int status = WSARecv(
  276. m_Socket,
  277. &buf,
  278. 1,
  279. &dwNumBytesReceived,
  280. &dwFlags,
  281. &m_RecvOverlapped,
  282. NULL );
  283. int err = -1;
  284. if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
  285. {
  286. SetConnectionLost( NULL, err );
  287. return false;
  288. }
  289. else
  290. {
  291. return true;
  292. }
  293. }
  294. bool PostNextDataPart()
  295. {
  296. DWORD dwNumBytesReceived = 0;
  297. WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived };
  298. DWORD dwFlags = 0;
  299. int status = WSARecv(
  300. m_Socket,
  301. &buf,
  302. 1,
  303. &dwNumBytesReceived,
  304. &dwFlags,
  305. &m_RecvOverlapped,
  306. NULL );
  307. int err = -1;
  308. if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
  309. {
  310. SetConnectionLost( NULL, err );
  311. return false;
  312. }
  313. else
  314. {
  315. return true;
  316. }
  317. }
  318. bool StartWaitingForData()
  319. {
  320. Assert( m_Socket != INVALID_SOCKET );
  321. Assert( m_RecvStage == 0 );
  322. Assert( m_bConnected );
  323. Assert( m_RecvDataSize > 0 );
  324. m_RecvStage = 1;
  325. // Add a CRecvData element.
  326. ParanoidMemoryCheck();
  327. m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize );
  328. if ( !m_pIncomingData )
  329. {
  330. char str[512];
  331. _snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize );
  332. SetConnectionLost( str );
  333. return false;
  334. }
  335. m_pIncomingData->m_Count = m_RecvDataSize;
  336. m_AmountReceived = 0;
  337. return PostNextDataPart();
  338. }
  339. virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout )
  340. {
  341. if ( CheckConnectionLost() )
  342. return false;
  343. // Wait in 50ms chunks, checking for disconnections along the way.
  344. bool bGotData = false;
  345. DWORD msToWait = (DWORD)( flTimeout * 1000.0 );
  346. do
  347. {
  348. DWORD curWaitTime = min( msToWait, 50 );
  349. DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime );
  350. if ( ret == WAIT_OBJECT_0 )
  351. {
  352. bGotData = true;
  353. break;
  354. }
  355. // Did the connection timeout?
  356. if ( CheckConnectionLost() )
  357. return false;
  358. msToWait -= curWaitTime;
  359. } while ( msToWait );
  360. // If we never got a WAIT_OBJECT_0, then we never received anything.
  361. if ( !bGotData )
  362. return false;
  363. CCriticalSectionLock csLock( &m_RecvDataCS );
  364. csLock.Lock();
  365. // Pickup the head m_RecvDatas element.
  366. CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ];
  367. data.CopyArray( pRecvData->m_Data, pRecvData->m_Count );
  368. // Now free it.
  369. m_RecvDatas.Remove( m_RecvDatas.Head() );
  370. ParanoidMemoryCheck( pRecvData );
  371. free( pRecvData );
  372. // Set the event again for the next time around, if there is more data waiting.
  373. if ( m_RecvDatas.Count() > 0 )
  374. SetEvent( m_hRecvSignal );
  375. return true;
  376. }
  377. // INSIDE IO THREAD.
  378. void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
  379. {
  380. if ( dwNumBytes == 0 )
  381. {
  382. SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" );
  383. return;
  384. }
  385. m_LastRecvTime = Plat_FloatTime();
  386. if ( m_RecvStage == 0 )
  387. {
  388. m_nSizeBytesReceived += dwNumBytes;
  389. if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) )
  390. {
  391. // Size of -1 means the other size is breaking the connection.
  392. if ( m_RecvDataSize == SENTINEL_DISCONNECT )
  393. {
  394. SetConnectionLost( "Got a graceful disconnect message." );
  395. return;
  396. }
  397. else if ( m_RecvDataSize == SENTINEL_KEEPALIVE )
  398. {
  399. // No data follows this. Just let m_LastRecvTime get updated.
  400. StartWaitingForSize( true );
  401. return;
  402. }
  403. StartWaitingForData();
  404. }
  405. else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) )
  406. {
  407. // Handle the case where we only got some of the data (maybe one of the clients got disconnected).
  408. StartWaitingForSize( false );
  409. }
  410. else
  411. {
  412. // This case should never ever happen!
  413. #if defined( _DEBUG )
  414. __asm int 3;
  415. #endif
  416. SetConnectionLost( "Received too much data in a packet!" );
  417. return;
  418. }
  419. }
  420. else if ( m_RecvStage == 1 )
  421. {
  422. // Got the data, make sure we got it all.
  423. m_AmountReceived += dwNumBytes;
  424. // Sanity check.
  425. #if defined( _DEBUG )
  426. Assert( m_RecvDataSize == m_pIncomingData->m_Count );
  427. Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads.
  428. #endif
  429. if ( m_AmountReceived == m_RecvDataSize )
  430. {
  431. m_RecvStage = 2;
  432. // Add the data to the list of packets waiting to be picked up.
  433. CCriticalSectionLock csLock( &m_RecvDataCS );
  434. csLock.Lock();
  435. m_RecvDatas.AddToTail( m_pIncomingData );
  436. m_pIncomingData = NULL;
  437. if ( m_RecvDatas.Count() == 1 )
  438. SetEvent( m_hRecvSignal ); // Notify the Recv() function.
  439. StartWaitingForSize( true );
  440. }
  441. else
  442. {
  443. PostNextDataPart();
  444. }
  445. }
  446. else
  447. {
  448. Assert( false );
  449. }
  450. }
  451. // -------------------------------------------------------------------------------------------------- //
  452. // The send code.
  453. // -------------------------------------------------------------------------------------------------- //
  454. virtual void WaitForSendsToComplete( double flTimeout )
  455. {
  456. CWaitTimer waitTimer( flTimeout );
  457. while ( 1 )
  458. {
  459. CCriticalSectionLock sendBufLock( &m_SendCS );
  460. sendBufLock.Lock();
  461. if( m_SendBufs.Count() == 0 )
  462. return;
  463. sendBufLock.Unlock();
  464. if ( waitTimer.ShouldKeepWaiting() )
  465. Sleep( 10 );
  466. else
  467. break;
  468. }
  469. }
  470. // This is called in the keepalive thread.
  471. void SendKeepalive()
  472. {
  473. // Send a message saying we're exiting.
  474. ParanoidMemoryCheck();
  475. SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
  476. if ( !pBuf )
  477. {
  478. SetConnectionLost( "malloc() in SendKeepalive() failed." );
  479. return;
  480. }
  481. pBuf->m_DataLength = sizeof( int );
  482. *((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE;
  483. InternalSendDataBuf( pBuf );
  484. }
  485. void SendDisconnectSentinel()
  486. {
  487. // Send a message saying we're exiting.
  488. ParanoidMemoryCheck();
  489. SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
  490. if ( pBuf )
  491. {
  492. pBuf->m_DataLength = sizeof( int );
  493. *((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting.
  494. InternalSendDataBuf( pBuf );
  495. }
  496. }
  497. virtual bool Send( const void *pData, int len )
  498. {
  499. const void *pChunks[1] = { pData };
  500. int chunkLengths[1] = { len };
  501. return SendChunks( pChunks, chunkLengths, 1 );
  502. }
  503. virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks )
  504. {
  505. if ( CheckConnectionLost() )
  506. return false;
  507. CChunkWalker walker( pChunks, pChunkLengths, nChunks );
  508. int totalLength = walker.GetTotalLength();
  509. if ( !totalLength )
  510. return true;
  511. // Create a buffer to hold the data and copy the data in.
  512. ParanoidMemoryCheck();
  513. SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) );
  514. if ( !pBuf )
  515. {
  516. char str[512];
  517. _snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength );
  518. SetConnectionLost( str );
  519. return false;
  520. }
  521. pBuf->m_DataLength = totalLength + sizeof( int );
  522. int *pByteCountPos = (int*)pBuf->m_Data;
  523. *pByteCountPos = totalLength;
  524. char *pDataPos = &pBuf->m_Data[ sizeof( int ) ];
  525. walker.CopyTo( pDataPos, totalLength );
  526. int status = InternalSendDataBuf( pBuf );
  527. int err = -1;
  528. if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
  529. {
  530. SetConnectionLost( NULL, err );
  531. return false;
  532. }
  533. else
  534. {
  535. return true;
  536. }
  537. }
  538. int InternalSendDataBuf( SendBuf_t *pBuf )
  539. {
  540. // Protect against interference from the keepalive thread.
  541. CCriticalSectionLock csLock( &m_SendCS );
  542. csLock.Lock();
  543. pBuf->m_Overlapped.m_OPType = OP_SEND;
  544. pBuf->m_Overlapped.hEvent = NULL;
  545. // Add it to our list of buffers.
  546. pBuf->m_Index = m_SendBufs.AddToTail( pBuf );
  547. // Tell Winsock to send it.
  548. WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data };
  549. DWORD dwNumBytesSent = 0;
  550. return WSASend(
  551. m_Socket,
  552. &buf,
  553. 1,
  554. &dwNumBytesSent,
  555. 0,
  556. &pBuf->m_Overlapped,
  557. NULL );
  558. }
  559. // INSIDE IO THREAD.
  560. void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
  561. {
  562. if ( dwNumBytes == 0 )
  563. {
  564. SetConnectionLost( "0 bytes in HandleSendCompletion." );
  565. return;
  566. }
  567. // Just free the buffer.
  568. SendBuf_t *pBuf = (SendBuf_t*)pInfo;
  569. Assert( dwNumBytes == (DWORD)pBuf->m_DataLength );
  570. CCriticalSectionLock sendBufLock( &m_SendCS );
  571. sendBufLock.Lock();
  572. m_SendBufs.Remove( pBuf->m_Index );
  573. sendBufLock.Unlock();
  574. ParanoidMemoryCheck( pBuf );
  575. free( pBuf );
  576. }
  577. // -------------------------------------------------------------------------------------------------- //
  578. // The connect code.
  579. // -------------------------------------------------------------------------------------------------- //
  580. virtual bool BeginConnect( const CIPAddr &inputAddr )
  581. {
  582. sockaddr_in addr;
  583. IPAddrToSockAddr( &inputAddr, &addr );
  584. m_bConnected = false;
  585. int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) );
  586. ret=ret;
  587. return true;
  588. }
  589. virtual bool UpdateConnect()
  590. {
  591. // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
  592. fd_set writeSet;
  593. writeSet.fd_count = 1;
  594. writeSet.fd_array[0] = m_Socket;
  595. TIMEVAL timeVal = SetupTimeVal( 0 );
  596. // See if it has a packet waiting.
  597. int status = select( 0, NULL, &writeSet, NULL, &timeVal );
  598. if ( status > 0 )
  599. {
  600. SetupConnected();
  601. return true;
  602. }
  603. return false;
  604. }
  605. void SetupConnected()
  606. {
  607. m_bConnected = true;
  608. m_bConnectionLost = false;
  609. m_LastRecvTime = Plat_FloatTime();
  610. CreateThreads();
  611. StartWaitingForSize( true );
  612. AddGlobalTCPSocket( this );
  613. }
  614. virtual bool IsConnected()
  615. {
  616. CheckConnectionLost();
  617. return m_bConnected;
  618. }
  619. virtual void GetDisconnectReason( CUtlVector<char> &reason )
  620. {
  621. reason = m_ErrorString;
  622. }
  623. // -------------------------------------------------------------------------------------------------- //
  624. // Threads code.
  625. // -------------------------------------------------------------------------------------------------- //
  626. // Create our IO Completion Port threads.
  627. bool CreateThreads()
  628. {
  629. int nThreads = 1;
  630. SetShouldExitThreads( false );
  631. // Create our IO completion port and hook it to our socket.
  632. m_hIOCP = CreateIoCompletionPort(
  633. INVALID_HANDLE_VALUE, NULL, 0, 0);
  634. m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads );
  635. for ( int i=0; i < nThreads; i++ )
  636. {
  637. DWORD dwThreadID = 0;
  638. HANDLE hThread = CreateThread(
  639. NULL,
  640. 0,
  641. &CTCPSocket::StaticThreadFn,
  642. this,
  643. 0,
  644. &dwThreadID );
  645. if ( hThread )
  646. {
  647. SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL );
  648. m_Threads.AddToTail( hThread );
  649. }
  650. else
  651. {
  652. StopThreads();
  653. return false;
  654. }
  655. }
  656. return true;
  657. }
  658. void StopThreads()
  659. {
  660. // Tell the threads to exit, then wait for them to do so.
  661. SetShouldExitThreads( true );
  662. WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE );
  663. for ( int i=0; i < m_Threads.Count(); i++ )
  664. {
  665. CloseHandle( m_Threads[i] );
  666. }
  667. m_Threads.Purge();
  668. }
  669. void SetShouldExitThreads( bool bShouldExit )
  670. {
  671. CCriticalSectionLock lock( &m_ThreadsCS );
  672. lock.Lock();
  673. m_bShouldExitThreads = bShouldExit;
  674. }
  675. bool ShouldExitThreads()
  676. {
  677. CCriticalSectionLock lock( &m_ThreadsCS );
  678. lock.Lock();
  679. bool bRet = m_bShouldExitThreads;
  680. return bRet;
  681. }
  682. DWORD ThreadFn()
  683. {
  684. while ( 1 )
  685. {
  686. DWORD dwNumBytes = 0;
  687. unsigned long pInputTCPSocket;
  688. LPOVERLAPPED pOverlapped;
  689. if ( GetQueuedCompletionStatus(
  690. m_hIOCP, // the port we're listening on
  691. &dwNumBytes, // # bytes received on the port
  692. &pInputTCPSocket,// "completion key" = CTCPSocket*
  693. &pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend.
  694. 100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted.
  695. ) )
  696. {
  697. COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped;
  698. ParanoidMemoryCheck( pInfo );
  699. if ( pInfo->m_OPType == OP_RECV )
  700. {
  701. Assert( pInfo == &m_RecvOverlapped );
  702. HandleRecvCompletion( pInfo, dwNumBytes );
  703. }
  704. else
  705. {
  706. Assert( pInfo->m_OPType == OP_SEND );
  707. HandleSendCompletion( pInfo, dwNumBytes );
  708. }
  709. }
  710. if ( ShouldExitThreads() )
  711. break;
  712. }
  713. return 0;
  714. }
  715. static DWORD WINAPI StaticThreadFn( LPVOID pParameter )
  716. {
  717. return ((CTCPSocket*)pParameter)->ThreadFn();
  718. }
  719. private:
  720. SOCKET m_Socket;
  721. bool m_bConnected;
  722. // m_RecvOverlapped is setup to first wait for the size, then the data.
  723. // Then it is not posted until the app grabs the data.
  724. HANDLE m_hRecvSignal; // Tells Recv() when we have data.
  725. COverlappedPlus m_RecvOverlapped;
  726. int m_RecvStage; // -1 = not initialized
  727. // 0 = waiting for size
  728. // 1 = waiting for data
  729. // 2 = waiting for app to pickup the data
  730. CUtlLinkedList<CRecvData*,int> m_RecvDatas; // The head element is the next one to be picked up.
  731. CRecvData *m_pIncomingData; // The packet we're currently receiving.
  732. CCriticalSection m_RecvDataCS; // This protects adds and removes in the list.
  733. // These reference the element at the tail of m_RecvData. It is the current one getting
  734. volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet?
  735. int m_RecvDataSize; // this is received over the network
  736. int m_AmountReceived; // How much we've received so far.
  737. // Last time we received anything from this connection. Used to determine if the connection is
  738. // still active.
  739. double m_LastRecvTime;
  740. // Outgoing send buffers.
  741. CUtlLinkedList<SendBuf_t*,int> m_SendBufs;
  742. CCriticalSection m_SendCS;
  743. // All the threads waiting for IO.
  744. CUtlVector<HANDLE> m_Threads;
  745. HANDLE m_hIOCP;
  746. // Used during shutdown.
  747. volatile bool m_bShouldExitThreads;
  748. CCriticalSection m_ThreadsCS;
  749. // For debugging.
  750. DWORD m_MainThreadID;
  751. // Set by the main thread or IO threads to signal connection lost.
  752. bool m_bConnectionLost;
  753. CCriticalSection m_ConnectionLostCS;
  754. // This is set when we get disconnected.
  755. CUtlVector<char> m_ErrorString;
  756. };
  757. // ------------------------------------------------------------------------------------------ //
  758. // ITCPListenSocket implementation.
  759. // ------------------------------------------------------------------------------------------ //
  760. class CTCPListenSocket : public ITCPListenSocket
  761. {
  762. public:
  763. CTCPListenSocket()
  764. {
  765. m_Socket = INVALID_SOCKET;
  766. }
  767. virtual ~CTCPListenSocket()
  768. {
  769. if ( m_Socket != INVALID_SOCKET )
  770. {
  771. closesocket( m_Socket );
  772. }
  773. }
  774. // The main function to create one of these suckers.
  775. static ITCPListenSocket* Create( const unsigned short port, int nQueueLength )
  776. {
  777. CTCPListenSocket *pRet = new CTCPListenSocket;
  778. if ( !pRet )
  779. return NULL;
  780. // Bind it to a socket and start listening.
  781. CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
  782. pRet->m_Socket = TCPBind( &addr );
  783. if ( pRet->m_Socket == INVALID_SOCKET ||
  784. listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 )
  785. {
  786. pRet->Release();
  787. return false;
  788. }
  789. return pRet;
  790. }
  791. virtual void Release()
  792. {
  793. delete this;
  794. }
  795. virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
  796. {
  797. // We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
  798. fd_set readSet;
  799. readSet.fd_count = 1;
  800. readSet.fd_array[0] = m_Socket;
  801. TIMEVAL timeVal = SetupTimeVal( 0 );
  802. // Wait until it connects.
  803. int status = select( 0, &readSet, NULL, NULL, &timeVal );
  804. if ( status > 0 )
  805. {
  806. sockaddr_in addr;
  807. int addrSize = sizeof( addr );
  808. // Now accept the final connection.
  809. SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize );
  810. if ( newSock == INVALID_SOCKET )
  811. {
  812. Assert( false );
  813. }
  814. else
  815. {
  816. CTCPSocket *pRet = new CTCPSocket;
  817. if ( !pRet )
  818. {
  819. closesocket( newSock );
  820. return NULL;
  821. }
  822. pRet->m_Socket = newSock;
  823. pRet->SetInitialSocketOptions();
  824. pRet->SetupConnected();
  825. // Report the address..
  826. SockAddrToIPAddr( &addr, pAddr );
  827. return pRet;
  828. }
  829. }
  830. return NULL;
  831. }
  832. private:
  833. SOCKET m_Socket;
  834. };
  835. ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength )
  836. {
  837. return CTCPListenSocket::Create( port, nQueueLength );
  838. }
  839. ITCPSocket* CreateTCPSocket()
  840. {
  841. return new CTCPSocket;
  842. }
  843. void TCPSocket_EnableTimeout( bool bEnable )
  844. {
  845. g_bEnableTCPTimeout = bEnable;
  846. }
  847. // --------------------------------------------------------------------------------- //
  848. // This thread sends keepalives on all active TCP sockets.
  849. // --------------------------------------------------------------------------------- //
  850. HANDLE g_hKeepaliveThread;
  851. HANDLE g_hKeepaliveThreadSignal;
  852. HANDLE g_hKeepaliveThreadReply;
  853. CUtlLinkedList<CTCPSocket*,int> g_TCPSockets;
  854. CCriticalSection g_TCPSocketsCS;
  855. DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter )
  856. {
  857. while ( 1 )
  858. {
  859. if ( WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS ) == WAIT_OBJECT_0 )
  860. break;
  861. // Tell all TCP sockets to send a keepalive.
  862. CCriticalSectionLock csLock( &g_TCPSocketsCS );
  863. csLock.Lock();
  864. FOR_EACH_LL( g_TCPSockets, i )
  865. {
  866. g_TCPSockets[i]->SendKeepalive();
  867. }
  868. }
  869. SetEvent( g_hKeepaliveThreadReply );
  870. return 0;
  871. }
  872. void AddGlobalTCPSocket( CTCPSocket *pSocket )
  873. {
  874. CCriticalSectionLock csLock( &g_TCPSocketsCS );
  875. csLock.Lock();
  876. Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() );
  877. g_TCPSockets.AddToTail( pSocket );
  878. // If this is the first one, create the keepalive thread.
  879. if ( g_TCPSockets.Count() == 1 )
  880. {
  881. g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL );
  882. g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL );
  883. DWORD dwThreadID = 0;
  884. g_hKeepaliveThread = CreateThread(
  885. NULL,
  886. 0,
  887. TCPKeepaliveThread,
  888. NULL,
  889. 0,
  890. &dwThreadID
  891. );
  892. }
  893. }
  894. void RemoveGlobalTCPSocket( CTCPSocket *pSocket )
  895. {
  896. bool bThreadRunning = false;
  897. DWORD dwExitCode = 0;
  898. if ( GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) && dwExitCode == STILL_ACTIVE )
  899. {
  900. bThreadRunning = true;
  901. }
  902. CCriticalSectionLock csLock( &g_TCPSocketsCS );
  903. csLock.Lock();
  904. int index = g_TCPSockets.Find( pSocket );
  905. if ( index != g_TCPSockets.InvalidIndex() )
  906. {
  907. g_TCPSockets.Remove( index );
  908. // If this was the last one, delete the thread.
  909. if ( g_TCPSockets.Count() == 0 )
  910. {
  911. csLock.Unlock();
  912. if ( bThreadRunning )
  913. {
  914. SetEvent( g_hKeepaliveThreadSignal );
  915. WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE );
  916. }
  917. CloseHandle( g_hKeepaliveThreadSignal );
  918. CloseHandle( g_hKeepaliveThreadReply );
  919. CloseHandle( g_hKeepaliveThread );
  920. return;
  921. }
  922. }
  923. csLock.Unlock();
  924. }