Team Fortress 2 Source Code as on 22/4/2020
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

344 lines
6.8 KiB

  1. //========= Copyright Valve Corporation, All rights reserved. ============//
  2. //
  3. // Purpose:
  4. //
  5. // $NoKeywords: $
  6. //=============================================================================//
  7. #include <windows.h>
  8. #include "tcpsocket.h"
  9. #include "IThreadedTCPSocket.h"
  10. #include "ThreadedTCPSocketEmu.h"
  11. #include "ThreadHelpers.h"
  12. // ---------------------------------------------------------------------------------------- //
  13. // CThreadedTCPSocketEmu. This uses IThreadedTCPSocket to emulate the polling-type interface
  14. // in ITCPSocket.
  15. // ---------------------------------------------------------------------------------------- //
  16. // This class uses the IThreadedTCPSocket interface to emulate the old ITCPSocket.
  17. class CThreadedTCPSocketEmu : public ITCPSocket, public ITCPSocketHandler, public IHandlerCreator
  18. {
  19. public:
  20. CThreadedTCPSocketEmu()
  21. {
  22. m_pSocket = NULL;
  23. m_LocalPort = 0xFFFF;
  24. m_pConnectSocket = NULL;
  25. m_RecvPacketsEvent.Init( false, false );
  26. m_bError = false;
  27. }
  28. virtual ~CThreadedTCPSocketEmu()
  29. {
  30. Term();
  31. }
  32. void Init( IThreadedTCPSocket *pSocket )
  33. {
  34. m_pSocket = pSocket;
  35. }
  36. void Term()
  37. {
  38. if ( m_pSocket )
  39. {
  40. m_pSocket->Release();
  41. m_pSocket = NULL;
  42. }
  43. if ( m_pConnectSocket )
  44. {
  45. m_pConnectSocket->Release();
  46. m_pConnectSocket = NULL;
  47. }
  48. }
  49. // ITCPSocketHandler implementation.
  50. private:
  51. virtual void OnPacketReceived( CTCPPacket *pPacket )
  52. {
  53. CCriticalSectionLock csLock( &m_RecvPacketsCS );
  54. csLock.Lock();
  55. m_RecvPackets.AddToTail( pPacket );
  56. m_RecvPacketsEvent.SetEvent();
  57. }
  58. virtual void OnError( int errorCode, const char *pErrorString )
  59. {
  60. CCriticalSectionLock csLock( &m_ErrorStringCS );
  61. csLock.Lock();
  62. m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 );
  63. m_bError = true;
  64. }
  65. // IHandlerCreator implementation.
  66. public:
  67. // This is used for connecting.
  68. virtual ITCPSocketHandler* CreateNewHandler()
  69. {
  70. return this;
  71. }
  72. // ITCPSocket implementation.
  73. public:
  74. virtual void Release()
  75. {
  76. delete this;
  77. }
  78. virtual bool BindToAny( const unsigned short port )
  79. {
  80. m_LocalPort = port;
  81. return true;
  82. }
  83. virtual bool BeginConnect( const CIPAddr &addr )
  84. {
  85. // They should have "bound" to a port before trying to connect.
  86. Assert( m_LocalPort != 0xFFFF );
  87. if ( m_pConnectSocket )
  88. m_pConnectSocket->Release();
  89. m_pConnectSocket = ThreadedTCP_CreateConnector(
  90. addr,
  91. CIPAddr( 0, 0, 0, 0, m_LocalPort ),
  92. this );
  93. return m_pConnectSocket != 0;
  94. }
  95. virtual bool UpdateConnect()
  96. {
  97. Assert( !m_pSocket );
  98. if ( !m_pConnectSocket )
  99. return false;
  100. if ( m_pConnectSocket->Update( &m_pSocket ) )
  101. {
  102. if ( m_pSocket )
  103. {
  104. // Ok, we're connected now.
  105. m_pConnectSocket->Release();
  106. m_pConnectSocket = NULL;
  107. return true;
  108. }
  109. else
  110. {
  111. return false;
  112. }
  113. }
  114. else
  115. {
  116. Assert( false );
  117. m_pConnectSocket->Release();
  118. m_pConnectSocket = NULL;
  119. return false;
  120. }
  121. }
  122. virtual bool IsConnected()
  123. {
  124. if ( m_bError )
  125. {
  126. Term();
  127. return false;
  128. }
  129. else
  130. {
  131. return m_pSocket != NULL;
  132. }
  133. }
  134. virtual void GetDisconnectReason( CUtlVector<char> &reason )
  135. {
  136. CCriticalSectionLock csLock( &m_ErrorStringCS );
  137. csLock.Lock();
  138. reason = m_ErrorString;
  139. }
  140. virtual bool Send( const void *pData, int size )
  141. {
  142. Assert( m_pSocket );
  143. if ( !m_pSocket )
  144. return false;
  145. return m_pSocket->Send( pData, size );
  146. }
  147. virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks )
  148. {
  149. Assert( m_pSocket );
  150. if ( !m_pSocket || !m_pSocket->IsValid() )
  151. return false;
  152. return m_pSocket->SendChunks( pChunks, pChunkLengths, nChunks );
  153. }
  154. virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout )
  155. {
  156. // Use our m_RecvPacketsEvent event to determine if there is data to receive yet.
  157. DWORD nMilliseconds = (DWORD)( flTimeout * 1000.0f );
  158. DWORD ret = WaitForSingleObject( m_RecvPacketsEvent.GetEventHandle(), nMilliseconds );
  159. if ( ret == WAIT_OBJECT_0 )
  160. {
  161. // Ok, there's a packet.
  162. CCriticalSectionLock csLock( &m_RecvPacketsCS );
  163. csLock.Lock();
  164. Assert( m_RecvPackets.Count() > 0 );
  165. int iHead = m_RecvPackets.Head();
  166. CTCPPacket *pPacket = m_RecvPackets[ iHead ];
  167. data.CopyArray( (const unsigned char*)pPacket->GetData(), pPacket->GetLen() );
  168. pPacket->Release();
  169. m_RecvPackets.Remove( iHead );
  170. // Re-set the event if there are more packets left to receive.
  171. if ( m_RecvPackets.Count() > 0 )
  172. {
  173. m_RecvPacketsEvent.SetEvent();
  174. }
  175. return true;
  176. }
  177. else
  178. {
  179. return false;
  180. }
  181. }
  182. private:
  183. IThreadedTCPSocket *m_pSocket;
  184. unsigned short m_LocalPort; // The port we bind to when we want to connect.
  185. ITCPConnectSocket *m_pConnectSocket;
  186. // All the received data is stored in here.
  187. CEvent m_RecvPacketsEvent;
  188. CCriticalSection m_RecvPacketsCS;
  189. CUtlLinkedList<CTCPPacket*, int> m_RecvPackets;
  190. CCriticalSection m_ErrorStringCS;
  191. CUtlVector<char> m_ErrorString;
  192. bool m_bError; // Set to true when there's an error. Next chance we get in the main thread, we'll close the socket.
  193. };
  194. ITCPSocket* CreateTCPSocketEmu()
  195. {
  196. return new CThreadedTCPSocketEmu;
  197. }
  198. // ---------------------------------------------------------------------------------------- //
  199. // CThreadedTCPListenSocketEmu implementation.
  200. // ---------------------------------------------------------------------------------------- //
  201. class CThreadedTCPListenSocketEmu : public ITCPListenSocket, public IHandlerCreator
  202. {
  203. public:
  204. CThreadedTCPListenSocketEmu()
  205. {
  206. m_pListener = NULL;
  207. m_pLastCreatedSocket = NULL;
  208. }
  209. virtual ~CThreadedTCPListenSocketEmu()
  210. {
  211. if ( m_pListener )
  212. m_pListener->Release();
  213. }
  214. bool StartListening( const unsigned short port, int nQueueLength )
  215. {
  216. m_pListener = ThreadedTCP_CreateListener(
  217. this,
  218. port,
  219. nQueueLength );
  220. return m_pListener != 0;
  221. }
  222. // ITCPListenSocket implementation.
  223. private:
  224. virtual void Release()
  225. {
  226. delete this;
  227. }
  228. virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
  229. {
  230. if ( !m_pListener )
  231. return NULL;
  232. IThreadedTCPSocket *pSocket;
  233. if ( m_pListener->Update( &pSocket ) && pSocket )
  234. {
  235. *pAddr = pSocket->GetRemoteAddr();
  236. // This is pretty hacky, but this stuff is just around for test code.
  237. CThreadedTCPSocketEmu *pLast = m_pLastCreatedSocket;
  238. pLast->Init( pSocket );
  239. m_pLastCreatedSocket = NULL;
  240. return pLast;
  241. }
  242. else
  243. {
  244. return NULL;
  245. }
  246. }
  247. // IHandlerCreator implementation.
  248. private:
  249. virtual ITCPSocketHandler* CreateNewHandler()
  250. {
  251. m_pLastCreatedSocket = new CThreadedTCPSocketEmu;
  252. return m_pLastCreatedSocket;
  253. }
  254. private:
  255. ITCPConnectSocket *m_pListener;
  256. CThreadedTCPSocketEmu *m_pLastCreatedSocket;
  257. };
  258. ITCPListenSocket* CreateTCPListenSocketEmu( const unsigned short port, int nQueueLength )
  259. {
  260. CThreadedTCPListenSocketEmu *pSocket = new CThreadedTCPListenSocketEmu;
  261. if ( pSocket->StartListening( port, nQueueLength ) )
  262. {
  263. return pSocket;
  264. }
  265. else
  266. {
  267. delete pSocket;
  268. return NULL;
  269. }
  270. }