Counter-Strike: Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

531 lines
12 KiB

  1. // socketft.cpp - written and placed in the public domain by Wei Dai
  2. #include "pch.h"
  3. #include "socketft.h"
  4. #ifdef SOCKETS_AVAILABLE
  5. #include "wait.h"
  6. #ifdef USE_BERKELEY_STYLE_SOCKETS
  7. #include <errno.h>
  8. #include <netdb.h>
  9. #include <unistd.h>
  10. #include <arpa/inet.h>
  11. #include <netinet/in.h>
  12. #include <sys/ioctl.h>
  13. #endif
  14. NAMESPACE_BEGIN(CryptoPP)
  // Portable names for the error codes this file sets (SetLastError / Err)
  // or compares against after a socket call.
  #ifndef USE_WINDOWS_STYLE_SOCKETS
  const int SOCKET_EINVAL = EINVAL;
  const int SOCKET_EWOULDBLOCK = EWOULDBLOCK;
  #else
  const int SOCKET_EINVAL = WSAEINVAL;
  const int SOCKET_EWOULDBLOCK = WSAEWOULDBLOCK;
  // The shared signatures below use socklen_t; on the Windows build it is
  // provided here as int (presumably missing from older Winsock headers).
  typedef int socklen_t;
  #endif
  23. Socket::Err::Err(socket_t s, const std::string& operation, int error)
  24. : OS_Error(IO_ERROR, "Socket: " + operation + " operation failed with error " + IntToString(error), operation, error)
  25. , m_s(s)
  26. {
  27. }
  28. Socket::~Socket()
  29. {
  30. if (m_own)
  31. {
  32. try
  33. {
  34. CloseSocket();
  35. }
  36. catch (...)
  37. {
  38. }
  39. }
  40. }
  41. void Socket::AttachSocket(socket_t s, bool own)
  42. {
  43. if (m_own)
  44. CloseSocket();
  45. m_s = s;
  46. m_own = own;
  47. SocketChanged();
  48. }
  49. socket_t Socket::DetachSocket()
  50. {
  51. socket_t s = m_s;
  52. m_s = INVALID_SOCKET;
  53. SocketChanged();
  54. return s;
  55. }
  56. void Socket::Create(int nType)
  57. {
  58. assert(m_s == INVALID_SOCKET);
  59. m_s = socket(AF_INET, nType, 0);
  60. CheckAndHandleError("socket", m_s);
  61. m_own = true;
  62. SocketChanged();
  63. }
  64. void Socket::CloseSocket()
  65. {
  66. if (m_s != INVALID_SOCKET)
  67. {
  68. #ifdef USE_WINDOWS_STYLE_SOCKETS
  69. CancelIo((HANDLE) m_s);
  70. CheckAndHandleError_int("closesocket", closesocket(m_s));
  71. #else
  72. CheckAndHandleError_int("close", close(m_s));
  73. #endif
  74. m_s = INVALID_SOCKET;
  75. SocketChanged();
  76. }
  77. }
  78. void Socket::Bind(unsigned int port, const char *addr)
  79. {
  80. sockaddr_in sa;
  81. memset(&sa, 0, sizeof(sa));
  82. sa.sin_family = AF_INET;
  83. if (addr == NULL)
  84. sa.sin_addr.s_addr = htonl(INADDR_ANY);
  85. else
  86. {
  87. unsigned long result = inet_addr(addr);
  88. if (result == -1) // Solaris doesn't have INADDR_NONE
  89. {
  90. SetLastError(SOCKET_EINVAL);
  91. CheckAndHandleError_int("inet_addr", SOCKET_ERROR);
  92. }
  93. sa.sin_addr.s_addr = result;
  94. }
  95. sa.sin_port = htons((u_short)port);
  96. Bind((sockaddr *)&sa, sizeof(sa));
  97. }
  98. void Socket::Bind(const sockaddr *psa, socklen_t saLen)
  99. {
  100. assert(m_s != INVALID_SOCKET);
  101. // cygwin workaround: needs const_cast
  102. CheckAndHandleError_int("bind", bind(m_s, const_cast<sockaddr *>(psa), saLen));
  103. }
  104. void Socket::Listen(int backlog)
  105. {
  106. assert(m_s != INVALID_SOCKET);
  107. CheckAndHandleError_int("listen", listen(m_s, backlog));
  108. }
  109. bool Socket::Connect(const char *addr, unsigned int port)
  110. {
  111. assert(addr != NULL);
  112. sockaddr_in sa;
  113. memset(&sa, 0, sizeof(sa));
  114. sa.sin_family = AF_INET;
  115. sa.sin_addr.s_addr = inet_addr(addr);
  116. if (sa.sin_addr.s_addr == -1) // Solaris doesn't have INADDR_NONE
  117. {
  118. hostent *lphost = gethostbyname(addr);
  119. if (lphost == NULL)
  120. {
  121. SetLastError(SOCKET_EINVAL);
  122. CheckAndHandleError_int("gethostbyname", SOCKET_ERROR);
  123. }
  124. sa.sin_addr.s_addr = ((in_addr *)lphost->h_addr)->s_addr;
  125. }
  126. sa.sin_port = htons((u_short)port);
  127. return Connect((const sockaddr *)&sa, sizeof(sa));
  128. }
  129. bool Socket::Connect(const sockaddr* psa, socklen_t saLen)
  130. {
  131. assert(m_s != INVALID_SOCKET);
  132. int result = connect(m_s, const_cast<sockaddr*>(psa), saLen);
  133. if (result == SOCKET_ERROR && GetLastError() == SOCKET_EWOULDBLOCK)
  134. return false;
  135. CheckAndHandleError_int("connect", result);
  136. return true;
  137. }
  138. bool Socket::Accept(Socket& target, sockaddr *psa, socklen_t *psaLen)
  139. {
  140. assert(m_s != INVALID_SOCKET);
  141. socket_t s = accept(m_s, psa, psaLen);
  142. if (s == INVALID_SOCKET && GetLastError() == SOCKET_EWOULDBLOCK)
  143. return false;
  144. CheckAndHandleError("accept", s);
  145. target.AttachSocket(s, true);
  146. return true;
  147. }
  148. void Socket::GetSockName(sockaddr *psa, socklen_t *psaLen)
  149. {
  150. assert(m_s != INVALID_SOCKET);
  151. CheckAndHandleError_int("getsockname", getsockname(m_s, psa, psaLen));
  152. }
  153. void Socket::GetPeerName(sockaddr *psa, socklen_t *psaLen)
  154. {
  155. assert(m_s != INVALID_SOCKET);
  156. CheckAndHandleError_int("getpeername", getpeername(m_s, psa, psaLen));
  157. }
  158. unsigned int Socket::Send(const byte* buf, size_t bufLen, int flags)
  159. {
  160. assert(m_s != INVALID_SOCKET);
  161. int result = send(m_s, (const char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
  162. CheckAndHandleError_int("send", result);
  163. return result;
  164. }
  165. unsigned int Socket::Receive(byte* buf, size_t bufLen, int flags)
  166. {
  167. assert(m_s != INVALID_SOCKET);
  168. int result = recv(m_s, (char *)buf, UnsignedMin(INT_MAX, bufLen), flags);
  169. CheckAndHandleError_int("recv", result);
  170. return result;
  171. }
  172. void Socket::ShutDown(int how)
  173. {
  174. assert(m_s != INVALID_SOCKET);
  175. int result = shutdown(m_s, how);
  176. CheckAndHandleError_int("shutdown", result);
  177. }
  178. void Socket::IOCtl(long cmd, unsigned long *argp)
  179. {
  180. assert(m_s != INVALID_SOCKET);
  181. #ifdef USE_WINDOWS_STYLE_SOCKETS
  182. CheckAndHandleError_int("ioctlsocket", ioctlsocket(m_s, cmd, argp));
  183. #else
  184. CheckAndHandleError_int("ioctl", ioctl(m_s, cmd, argp));
  185. #endif
  186. }
  187. bool Socket::SendReady(const timeval *timeout)
  188. {
  189. fd_set fds;
  190. FD_ZERO(&fds);
  191. FD_SET(m_s, &fds);
  192. int ready;
  193. if (timeout == NULL)
  194. ready = select((int)m_s+1, NULL, &fds, NULL, NULL);
  195. else
  196. {
  197. timeval timeoutCopy = *timeout; // select() modified timeout on Linux
  198. ready = select((int)m_s+1, NULL, &fds, NULL, &timeoutCopy);
  199. }
  200. CheckAndHandleError_int("select", ready);
  201. return ready > 0;
  202. }
  203. bool Socket::ReceiveReady(const timeval *timeout)
  204. {
  205. fd_set fds;
  206. FD_ZERO(&fds);
  207. FD_SET(m_s, &fds);
  208. int ready;
  209. if (timeout == NULL)
  210. ready = select((int)m_s+1, &fds, NULL, NULL, NULL);
  211. else
  212. {
  213. timeval timeoutCopy = *timeout; // select() modified timeout on Linux
  214. ready = select((int)m_s+1, &fds, NULL, NULL, &timeoutCopy);
  215. }
  216. CheckAndHandleError_int("select", ready);
  217. return ready > 0;
  218. }
  219. unsigned int Socket::PortNameToNumber(const char *name, const char *protocol)
  220. {
  221. int port = atoi(name);
  222. if (IntToString(port) == name)
  223. return port;
  224. servent *se = getservbyname(name, protocol);
  225. if (!se)
  226. throw Err(INVALID_SOCKET, "getservbyname", SOCKET_EINVAL);
  227. return ntohs(se->s_port);
  228. }
  229. void Socket::StartSockets()
  230. {
  231. #ifdef USE_WINDOWS_STYLE_SOCKETS
  232. WSADATA wsd;
  233. int result = WSAStartup(0x0202, &wsd);
  234. if (result != 0)
  235. throw Err(INVALID_SOCKET, "WSAStartup", result);
  236. #endif
  237. }
  238. void Socket::ShutdownSockets()
  239. {
  240. #ifdef USE_WINDOWS_STYLE_SOCKETS
  241. int result = WSACleanup();
  242. if (result != 0)
  243. throw Err(INVALID_SOCKET, "WSACleanup", result);
  244. #endif
  245. }
  246. int Socket::GetLastError()
  247. {
  248. #ifdef USE_WINDOWS_STYLE_SOCKETS
  249. return WSAGetLastError();
  250. #else
  251. return errno;
  252. #endif
  253. }
  254. void Socket::SetLastError(int errorCode)
  255. {
  256. #ifdef USE_WINDOWS_STYLE_SOCKETS
  257. WSASetLastError(errorCode);
  258. #else
  259. errno = errorCode;
  260. #endif
  261. }
  262. void Socket::HandleError(const char *operation) const
  263. {
  264. int err = GetLastError();
  265. throw Err(m_s, operation, err);
  266. }
  267. #ifdef USE_WINDOWS_STYLE_SOCKETS
  268. SocketReceiver::SocketReceiver(Socket &s)
  269. : m_s(s), m_resultPending(false), m_eofReceived(false)
  270. {
  271. m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
  272. m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
  273. memset(&m_overlapped, 0, sizeof(m_overlapped));
  274. m_overlapped.hEvent = m_event;
  275. }
  276. SocketReceiver::~SocketReceiver()
  277. {
  278. #ifdef USE_WINDOWS_STYLE_SOCKETS
  279. CancelIo((HANDLE) m_s.GetSocket());
  280. #endif
  281. }
  282. bool SocketReceiver::Receive(byte* buf, size_t bufLen)
  283. {
  284. assert(!m_resultPending && !m_eofReceived);
  285. DWORD flags = 0;
  286. // don't queue too much at once, or we might use up non-paged memory
  287. WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
  288. if (WSARecv(m_s, &wsabuf, 1, &m_lastResult, &flags, &m_overlapped, NULL) == 0)
  289. {
  290. if (m_lastResult == 0)
  291. m_eofReceived = true;
  292. }
  293. else
  294. {
  295. switch (WSAGetLastError())
  296. {
  297. default:
  298. m_s.CheckAndHandleError_int("WSARecv", SOCKET_ERROR);
  299. case WSAEDISCON:
  300. m_lastResult = 0;
  301. m_eofReceived = true;
  302. break;
  303. case WSA_IO_PENDING:
  304. m_resultPending = true;
  305. }
  306. }
  307. return !m_resultPending;
  308. }
  309. void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  310. {
  311. if (m_resultPending)
  312. container.AddHandle(m_event, CallStack("SocketReceiver::GetWaitObjects() - result pending", &callStack));
  313. else if (!m_eofReceived)
  314. container.SetNoWait(CallStack("SocketReceiver::GetWaitObjects() - result ready", &callStack));
  315. }
  316. unsigned int SocketReceiver::GetReceiveResult()
  317. {
  318. if (m_resultPending)
  319. {
  320. DWORD flags = 0;
  321. if (WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags))
  322. {
  323. if (m_lastResult == 0)
  324. m_eofReceived = true;
  325. }
  326. else
  327. {
  328. switch (WSAGetLastError())
  329. {
  330. default:
  331. m_s.CheckAndHandleError("WSAGetOverlappedResult", FALSE);
  332. case WSAEDISCON:
  333. m_lastResult = 0;
  334. m_eofReceived = true;
  335. }
  336. }
  337. m_resultPending = false;
  338. }
  339. return m_lastResult;
  340. }
  341. // *************************************************************
  342. SocketSender::SocketSender(Socket &s)
  343. : m_s(s), m_resultPending(false), m_lastResult(0)
  344. {
  345. m_event.AttachHandle(CreateEvent(NULL, true, false, NULL), true);
  346. m_s.CheckAndHandleError("CreateEvent", m_event.HandleValid());
  347. memset(&m_overlapped, 0, sizeof(m_overlapped));
  348. m_overlapped.hEvent = m_event;
  349. }
  350. SocketSender::~SocketSender()
  351. {
  352. #ifdef USE_WINDOWS_STYLE_SOCKETS
  353. CancelIo((HANDLE) m_s.GetSocket());
  354. #endif
  355. }
  356. void SocketSender::Send(const byte* buf, size_t bufLen)
  357. {
  358. assert(!m_resultPending);
  359. DWORD written = 0;
  360. // don't queue too much at once, or we might use up non-paged memory
  361. WSABUF wsabuf = {UnsignedMin((u_long)128*1024, bufLen), (char *)buf};
  362. if (WSASend(m_s, &wsabuf, 1, &written, 0, &m_overlapped, NULL) == 0)
  363. {
  364. m_resultPending = false;
  365. m_lastResult = written;
  366. }
  367. else
  368. {
  369. if (WSAGetLastError() != WSA_IO_PENDING)
  370. m_s.CheckAndHandleError_int("WSASend", SOCKET_ERROR);
  371. m_resultPending = true;
  372. }
  373. }
  374. void SocketSender::SendEof()
  375. {
  376. assert(!m_resultPending);
  377. m_s.ShutDown(SD_SEND);
  378. m_s.CheckAndHandleError("ResetEvent", ResetEvent(m_event));
  379. m_s.CheckAndHandleError_int("WSAEventSelect", WSAEventSelect(m_s, m_event, FD_CLOSE));
  380. m_resultPending = true;
  381. }
  382. bool SocketSender::EofSent()
  383. {
  384. if (m_resultPending)
  385. {
  386. WSANETWORKEVENTS events;
  387. m_s.CheckAndHandleError_int("WSAEnumNetworkEvents", WSAEnumNetworkEvents(m_s, m_event, &events));
  388. if ((events.lNetworkEvents & FD_CLOSE) != FD_CLOSE)
  389. throw Socket::Err(m_s, "WSAEnumNetworkEvents (FD_CLOSE not present)", E_FAIL);
  390. if (events.iErrorCode[FD_CLOSE_BIT] != 0)
  391. throw Socket::Err(m_s, "FD_CLOSE (via WSAEnumNetworkEvents)", events.iErrorCode[FD_CLOSE_BIT]);
  392. m_resultPending = false;
  393. }
  394. return m_lastResult != 0;
  395. }
  396. void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  397. {
  398. if (m_resultPending)
  399. container.AddHandle(m_event, CallStack("SocketSender::GetWaitObjects() - result pending", &callStack));
  400. else
  401. container.SetNoWait(CallStack("SocketSender::GetWaitObjects() - result ready", &callStack));
  402. }
  403. unsigned int SocketSender::GetSendResult()
  404. {
  405. if (m_resultPending)
  406. {
  407. DWORD flags = 0;
  408. BOOL result = WSAGetOverlappedResult(m_s, &m_overlapped, &m_lastResult, false, &flags);
  409. m_s.CheckAndHandleError("WSAGetOverlappedResult", result);
  410. m_resultPending = false;
  411. }
  412. return m_lastResult;
  413. }
  414. #endif
  415. #ifdef USE_BERKELEY_STYLE_SOCKETS
  416. SocketReceiver::SocketReceiver(Socket &s)
  417. : m_s(s), m_lastResult(0), m_eofReceived(false)
  418. {
  419. }
  420. void SocketReceiver::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  421. {
  422. if (!m_eofReceived)
  423. container.AddReadFd(m_s, CallStack("SocketReceiver::GetWaitObjects()", &callStack));
  424. }
  425. bool SocketReceiver::Receive(byte* buf, size_t bufLen)
  426. {
  427. m_lastResult = m_s.Receive(buf, bufLen);
  428. if (bufLen > 0 && m_lastResult == 0)
  429. m_eofReceived = true;
  430. return true;
  431. }
  432. unsigned int SocketReceiver::GetReceiveResult()
  433. {
  434. return m_lastResult;
  435. }
  436. SocketSender::SocketSender(Socket &s)
  437. : m_s(s), m_lastResult(0)
  438. {
  439. }
  440. void SocketSender::Send(const byte* buf, size_t bufLen)
  441. {
  442. m_lastResult = m_s.Send(buf, bufLen);
  443. }
  444. void SocketSender::SendEof()
  445. {
  446. m_s.ShutDown(SD_SEND);
  447. }
  448. unsigned int SocketSender::GetSendResult()
  449. {
  450. return m_lastResult;
  451. }
  452. void SocketSender::GetWaitObjects(WaitObjectContainer &container, CallStack const& callStack)
  453. {
  454. container.AddWriteFd(m_s, CallStack("SocketSender::GetWaitObjects()", &callStack));
  455. }
  456. #endif
  457. NAMESPACE_END
  458. #endif // #ifdef SOCKETS_AVAILABLE