Team Fortress 2 source code as of 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

546 lines
12 KiB

  1. // socketft.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. // TODO: http://github.com/weidai11/cryptopp/issues/19
  4. #define _WINSOCK_DEPRECATED_NO_WARNINGS
  5. #include "socketft.h"
  6. #ifdef SOCKETS_AVAILABLE
  7. #include "wait.h"
  8. #ifdef USE_BERKELEY_STYLE_SOCKETS
  9. #include <errno.h>
  10. #include <netdb.h>
  11. #include <unistd.h>
  12. #include <arpa/inet.h>
  13. #include <netinet/in.h>
  14. #include <sys/ioctl.h>
  15. #endif
  16. #ifdef PREFER_WINDOWS_STYLE_SOCKETS
  17. # pragma comment(lib, "ws2_32.lib")
  18. #endif
  19. NAMESPACE_BEGIN(CryptoPP)
  // Give the platform's "would block" and "invalid argument" error codes
  // uniform names so the rest of this file can compare against them.
  #if defined(USE_WINDOWS_STYLE_SOCKETS)
  typedef int socklen_t;  // no socklen_t on this platform; lengths are plain int
  const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
  const int SOCKET_EINVAL = WSAEINVAL;
  #else
  const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
  const int SOCKET_EINVAL = EINVAL;
  #endif

  // Some platforms (e.g. Solaris) lack INADDR_NONE; supply the all-ones IPv4
  // value that inet_addr() results are compared against below.
  #if !defined(INADDR_NONE)
  # define INADDR_NONE 0xffffffff
  #endif
  32. Socket::Err::Err(socket_t s, const std::string& operation, int error)
  33. : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
  34. , m_s(s)
  35. {
  36. }
  37. Socket::~Socket()
  38. {
  39. if (m_own)
  40. {
  41. try
  42. {
  43. CloseSocket();
  44. }
  45. catch (const Exception&)
  46. {
  47. assert(0);
  48. }
  49. }
  50. }
  51. void Socket::AttachSocket(socket_t s, bool own)
  52. {
  53. if (m_own)
  54. CloseSocket();
  55. m_s = s;
  56. m_own = own;
  57. SocketChanged();
  58. }
  59. socket_t Socket::DetachSocket()
  60. {
  61. socket_t s = m_s;
  62. m_s = INVALID_SOCKET;
  63. SocketChanged();
  64. return s;
  65. }
  66. void Socket::Create(int nType)
  67. {
  68. assert(m_s == INVALID_SOCKET);
  69. m_s = socket(AF_INET, nType, 0);
  70. CheckAndHandleError("socket", m_s);
  71. m_own = true;
  72. SocketChanged();
  73. }
  74. void Socket::CloseSocket()
  75. {
  76. if (m_s != INVALID_SOCKET)
  77. {
  78. #ifdef USE_WINDOWS_STYLE_SOCKETS
  79. CancelIo((HANDLE) m_s);
  80. CheckAndHandleError_int("closesocket", closesocket(m_s));
  81. #else
  82. CheckAndHandleError_int("close", close(m_s));
  83. #endif
  84. m_s = INVALID_SOCKET;
  85. SocketChanged();
  86. }
  87. }
  88. void Socket::Bind(unsigned int port, const char *addr)
  89. {
  90. sockaddr_in sa;
  91. memset(&sa, 0, sizeof(sa));
  92. sa.sin_family = AF_INET;
  93. if (addr == NULL)
  94. sa.sin_addr.s_addr = htonl(INADDR_ANY);
  95. else
  96. {
  97. unsigned long result = inet_addr(addr);
  98. if (result == INADDR_NONE)
  99. {
  100. SetLastError(SOCKET_EINVAL);
  101. CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
  102. }
  103. sa.sin_addr.s_addr = result;
  104. }
  105. sa.sin_port = htons((u_short)port);
  106. Bind((sockaddr *)&sa, sizeof(sa));
  107. }
  108. void Socket::Bind(const sockaddr *psa, socklen_t saLen)
  109. {
  110. assert(m_s != INVALID_SOCKET);
  111. // cygwin workaround: needs const_cast
  112. CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
  113. }
  114. void Socket::Listen(int backlog)
  115. {
  116. assert(m_s != INVALID_SOCKET);
  117. CheckAndHandleError_int("listen", listen(m_s, backlog));
  118. }
  119. bool Socket::Connect(const char *addr, unsigned int port)
  120. {
  121. assert(addr != NULL);
  122. sockaddr_in sa;
  123. memset(&sa, 0, sizeof(sa));
  124. sa.sin_family = AF_INET;
  125. sa.sin_addr.s_addr = inet_addr(addr);
  126. if (sa.sin_addr.s_addr == INADDR_NONE)
  127. {
  128. hostent *lphost = gethostbyname(addr);
  129. if (lphost == NULL)
  130. {
  131. SetLastError(SOCKET_EINVAL);
  132. CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
  133. }
  134. else
  135. {
  136. sa.sin_addr.s_addr = ((in_addr *)lphost->h_addr)->s_addr;
  137. }
  138. }
  139. sa.sin_port = htons((u_short)port);
  140. return Connect((const sockaddr *)&sa, sizeof(sa));
  141. }
  142. bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
  143. {
  144. assert(m_s != INVALID_SOCKET);
  145. int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
  146. if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
  147. return false;
  148. CheckAndHandleError_int("connect", result);
  149. return true;
  150. }
  151. bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
  152. {
  153. assert(m_s != INVALID_SOCKET);
  154. socket_t s = accept(m_s, psa, psaLen);
  155. if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
  156. return false;
  157. CheckAndHandleError("accept", s);
  158. target.AttachSocket(s, true);
  159. return true;
  160. }
  161. void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
  162. {
  163. assert(m_s != INVALID_SOCKET);
  164. CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
  165. }
  166. void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
  167. {
  168. assert(m_s != INVALID_SOCKET);
  169. CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
  170. }
  171. unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
  172. {
  173. assert(m_s != INVALID_SOCKET);
  174. int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
  175. CheckAndHandleError_int("send", result);
  176. return result;
  177. }
  178. unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
  179. {
  180. assert(m_s != INVALID_SOCKET);
  181. int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
  182. CheckAndHandleError_int("recv", result);
  183. return result;
  184. }
  185. void Socket::ShutDown(int how)
  186. {
  187. assert(m_s != INVALID_SOCKET);
  188. int result = shutdown(m_s, how);
  189. CheckAndHandleError_int("shutdown", result);
  190. }
  191. void Socket::IOCtl(long cmd, unsigned long *argp)
  192. {
  193. assert(m_s != INVALID_SOCKET);
  194. #ifdef USE_WINDOWS_STYLE_SOCKETS
  195. CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
  196. #else
  197. CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
  198. #endif
  199. }
  200. bool Socket::SendReady(const timeval *timeout)
  201. {
  202. fd_set fds;
  203. FD_ZERO(&fds);
  204. FD_SET(m_s, &fds);
  205. int ready;
  206. if (timeout == NULL)
  207. ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
  208. else
  209. {
  210. timeval timeoutCopy = *timeout; // select() modified timeout on Linux
  211. ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
  212. }
  213. CheckAndHandleError_int("select", ready);
  214. return ready > 0;
  215. }
  216. bool Socket::ReceiveReady(const timeval *timeout)
  217. {
  218. fd_set fds;
  219. FD_ZERO(&fds);
  220. FD_SET(m_s, &fds);
  221. int ready;
  222. if (timeout == NULL)
  223. ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
  224. else
  225. {
  226. timeval timeoutCopy = *timeout; // select() modified timeout on Linux
  227. ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
  228. }
  229. CheckAndHandleError_int("select", ready);
  230. return ready > 0;
  231. }
  232. unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
  233. {
  234. int port = atoi(name);
  235. if (IntToString(port) == name)
  236. return port;
  237. servent *se = getservbyname(name, protocol);
  238. if (!se)
  239. throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
  240. return ntohs(se->s_port);
  241. }
  242. void Socket::StartSockets()
  243. {
  244. #ifdef USE_WINDOWS_STYLE_SOCKETS
  245. WSADATA wsd;
  246. int result = WSAStartup(0x0202, &wsd);
  247. if (result != 0)
  248. throw Err(INVALID_SOCKET, "WSAStartup", result);
  249. #endif
  250. }
  251. void Socket::ShutdownSockets()
  252. {
  253. #ifdef USE_WINDOWS_STYLE_SOCKETS
  254. int result = WSACleanup();
  255. if (result != 0)
  256. throw Err(INVALID_SOCKET, "WSACleanup", result);
  257. #endif
  258. }
  259. int Socket::GetLastError()
  260. {
  261. #ifdef USE_WINDOWS_STYLE_SOCKETS
  262. return WSAGetLastError();
  263. #else
  264. return errno;
  265. #endif
  266. }
  267. void Socket::SetLastError(int errorCode)
  268. {
  269. #ifdef USE_WINDOWS_STYLE_SOCKETS
  270. WSASetLastError(errorCode);
  271. #else
  272. errno = errorCode;
  273. #endif
  274. }
  275. void Socket::HandleError(const char *operation) const
  276. {
  277. int err = GetLastError();
  278. throw Err(m_s, operation, err);
  279. }
  280. #ifdef USE_WINDOWS_STYLE_SOCKETS
  281. SocketReceiver::SocketReceiver(Socket &s)
  282. : m_s(s), m_eofReceived(false), m_resultPending(false)
  283. {
  284. m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
  285. m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
  286. memset(&m_overlapped, 0, sizeof(m_overlapped));
  287. m_overlapped.hEvent = m_event;
  288. }
  289. SocketReceiver::~SocketReceiver()
  290. {
  291. #ifdef USE_WINDOWS_STYLE_SOCKETS
  292. CancelIo((HANDLE) m_s.GetSocket());
  293. #endif
  294. }
  295. bool SocketReceiver::Receive(byte* buf, size_t bufLen)
  296. {
  297. assert(!m_resultPending && !m_eofReceived);
  298. DWORD flags = 0;
  299. // don't queue too much at once, or we might use up non-paged memory
  300. WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
  301. if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
  302. {
  303. if (m_lastResult == 0)
  304. m_eofReceived = true;
  305. }
  306. else
  307. {
  308. switch (WSAGetLastError())
  309. {
  310. default:
  311. m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
  312. case WSAEDISCON:
  313. m_lastResult = 0;
  314. m_eofReceived = true;
  315. break;
  316. case WSA_IO_PENDING:
  317. m_resultPending = true;
  318. }
  319. }
  320. return !m_resultPending;
  321. }
  322. void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  323. {
  324. if (m_resultPending)
  325. container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
  326. else if (!m_eofReceived)
  327. container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
  328. }
  329. unsigned int SocketReceiver::GetReceiveResult()
  330. {
  331. if (m_resultPending)
  332. {
  333. DWORD flags = 0;
  334. if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
  335. {
  336. if (m_lastResult == 0)
  337. m_eofReceived = true;
  338. }
  339. else
  340. {
  341. switch (WSAGetLastError())
  342. {
  343. default:
  344. m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
  345. case WSAEDISCON:
  346. m_lastResult = 0;
  347. m_eofReceived = true;
  348. }
  349. }
  350. m_resultPending = false;
  351. }
  352. return m_lastResult;
  353. }
  354. // *************************************************************
  355. SocketSender::SocketSender(Socket &s)
  356. : m_s(s), m_resultPending(false), m_lastResult(0)
  357. {
  358. m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
  359. m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
  360. memset(&m_overlapped, 0, sizeof(m_overlapped));
  361. m_overlapped.hEvent = m_event;
  362. }
  363. SocketSender::~SocketSender()
  364. {
  365. #ifdef USE_WINDOWS_STYLE_SOCKETS
  366. CancelIo((HANDLE) m_s.GetSocket());
  367. #endif
  368. }
  369. void SocketSender::Send(const byte* buf, size_t bufLen)
  370. {
  371. assert(!m_resultPending);
  372. DWORD written = 0;
  373. // don't queue too much at once, or we might use up non-paged memory
  374. WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
  375. if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
  376. {
  377. m_resultPending = false;
  378. m_lastResult = written;
  379. }
  380. else
  381. {
  382. if (WSAGetLastError() != WSA_IO_PENDING)
  383. m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
  384. m_resultPending = true;
  385. }
  386. }
  387. void SocketSender::SendEof()
  388. {
  389. assert(!m_resultPending);
  390. m_s.ShutDown(SD_SEND);
  391. m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
  392. m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
  393. m_resultPending = true;
  394. }
  395. bool SocketSender::EofSent()
  396. {
  397. if (m_resultPending)
  398. {
  399. WSANETWORKEVENTS events;
  400. m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
  401. if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
  402. throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
  403. if (events.iErrorCode[FD_CLOSE_BIT] != 0)
  404. throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
  405. m_resultPending = false;
  406. }
  407. return m_lastResult != 0;
  408. }
  409. void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  410. {
  411. if (m_resultPending)
  412. container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
  413. else
  414. container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
  415. }
  416. unsigned int SocketSender::GetSendResult()
  417. {
  418. if (m_resultPending)
  419. {
  420. DWORD flags = 0;
  421. BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
  422. m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
  423. m_resultPending = false;
  424. }
  425. return m_lastResult;
  426. }
  427. #endif
  428. #ifdef USE_BERKELEY_STYLE_SOCKETS
  429. SocketReceiver::SocketReceiver(Socket &s)
  430. : m_s(s), m_eofReceived(false), m_lastResult(0)
  431. {
  432. }
  433. void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  434. {
  435. if (!m_eofReceived)
  436. container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
  437. }
  438. bool SocketReceiver::Receive(byte* buf, size_t bufLen)
  439. {
  440. m_lastResult = m_s.Receive(buf, bufLen);
  441. if (bufLen > 0 && m_lastResult == 0)
  442. m_eofReceived = true;
  443. return true;
  444. }
  445. unsigned int SocketReceiver::GetReceiveResult()
  446. {
  447. return m_lastResult;
  448. }
  449. SocketSender::SocketSender(Socket &s)
  450. : m_s(s), m_lastResult(0)
  451. {
  452. }
  453. void SocketSender::Send(const byte* buf, size_t bufLen)
  454. {
  455. m_lastResult = m_s.Send(buf, bufLen);
  456. }
  457. void SocketSender::SendEof()
  458. {
  459. m_s.ShutDown(SD_SEND);
  460. }
  461. unsigned int SocketSender::GetSendResult()
  462. {
  463. return m_lastResult;
  464. }
  465. void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  466. {
  467. container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
  468. }
  469. #endif
  470. NAMESPACE_END
  471. #endif // #ifdef SOCKETS_AVAILABLE