Leaked source code of windows server 2003
You can not select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1042 lines
37 KiB

  1. /*++
  2. Copyright (c) 2001 Microsoft Corporation
  3. Module Name:
  4. io.c
  5. Abstract:
  6. This module contains functions to manage all socket I/O
  7. between the server and clients, including socket management
  8. and overlapped completion indication. It also contains
  9. buffer management.
  10. Author:
  11. Jeffrey C. Venable, Sr. (jeffv) 01-Jun-2001
  12. Revision History:
  13. --*/
  14. #include "precomp.h"
  15. void
  16. TftpdIoFreeBuffer(PTFTPD_BUFFER buffer) {
  17. PTFTPD_SOCKET socket = buffer->internal.socket;
  18. TFTPD_DEBUG((TFTPD_TRACE_IO,
  19. "TftpdIoFreeBuffer(buffer = %p).\n",
  20. buffer));
  21. HeapFree(globals.hServiceHeap, 0, buffer);
  22. if ((InterlockedDecrement((PLONG)&socket->numBuffers) == -1) &&
  23. (socket->context != NULL))
  24. HeapFree(globals.hServiceHeap, 0, socket);
  25. if (InterlockedDecrement(&globals.io.numBuffers) == -1)
  26. TftpdServiceAttemptCleanup();
  27. } // TftpdIoFreeBuffer()
  28. PTFTPD_BUFFER
  29. TftpdIoAllocateBuffer(PTFTPD_SOCKET socket) {
  30. PTFTPD_BUFFER buffer;
  31. TFTPD_DEBUG((TFTPD_TRACE_IO,
  32. "TftpdIoAllocateBuffer(socket = %s).\n",
  33. ((socket == &globals.io.master) ? "master" :
  34. ((socket == &globals.io.def) ? "def" :
  35. ((socket == &globals.io.mtu) ? "mtu" :
  36. ((socket == &globals.io.max) ? "max" :
  37. "private")))) ));
  38. buffer = (PTFTPD_BUFFER)HeapAlloc(globals.hServiceHeap, 0,
  39. socket->buffersize);
  40. if (buffer == NULL) {
  41. TFTPD_DEBUG((TFTPD_DBG_IO,
  42. "TftpdIoAllocateBuffer(socket = %s): "
  43. "HeapAlloc() failed, error 0x%08X.\n",
  44. ((socket == &globals.io.master) ? "master" :
  45. ((socket == &globals.io.def) ? "def" :
  46. ((socket == &globals.io.mtu) ? "mtu" :
  47. ((socket == &globals.io.max) ? "max" :
  48. "private")))), GetLastError()));
  49. return (NULL);
  50. }
  51. ZeroMemory(buffer, sizeof(buffer->internal));
  52. InterlockedIncrement(&globals.io.numBuffers);
  53. InterlockedIncrement((PLONG)&socket->numBuffers);
  54. buffer->internal.socket = socket;
  55. buffer->internal.datasize = socket->datasize;
  56. if (globals.service.shutdown) {
  57. TftpdIoFreeBuffer(buffer);
  58. buffer = NULL;
  59. }
  60. return (buffer);
  61. } // TftpdIoAllocateBuffer()
  62. PTFTPD_BUFFER
  63. TftpdIoSwapBuffer(PTFTPD_BUFFER buffer, PTFTPD_SOCKET socket) {
  64. PTFTPD_BUFFER tmp;
  65. ASSERT((buffer->message.opcode == TFTPD_RRQ) ||
  66. (buffer->message.opcode == TFTPD_WRQ));
  67. // Allocate a buffer for the new socket.
  68. tmp = TftpdIoAllocateBuffer(socket);
  69. // Copy information we need to retain.
  70. if (tmp != NULL) {
  71. tmp->internal.context = buffer->internal.context;
  72. tmp->internal.io.peerLen = buffer->internal.io.peerLen;
  73. CopyMemory(&tmp->internal.io.peer,
  74. &buffer->internal.io.peer,
  75. buffer->internal.io.peerLen);
  76. CopyMemory(&tmp->internal.io.msg,
  77. &buffer->internal.io.msg,
  78. sizeof(tmp->internal.io.msg));
  79. CopyMemory(&tmp->internal.io.control,
  80. &buffer->internal.io.control,
  81. sizeof(tmp->internal.io.control));
  82. } // if (tmp != NULL)
  83. TFTPD_DEBUG((TFTPD_TRACE_IO,
  84. "TftpdIoCompletionCallback(buffer = %p): "
  85. "new buffer = %p.\n",
  86. buffer, tmp));
  87. // Return the original buffer.
  88. TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);
  89. return (tmp);
  90. } // TftpdIoSwapBuffer()
  91. void
  92. TftpdIoCompletionCallback(DWORD dwErrorCode,
  93. DWORD dwBytes,
  94. LPOVERLAPPED overlapped) {
  95. PTFTPD_BUFFER buffer = CONTAINING_RECORD(overlapped, TFTPD_BUFFER,
  96. internal.io.overlapped);
  97. PTFTPD_CONTEXT context = buffer->internal.context;
  98. PTFTPD_SOCKET socket = buffer->internal.socket;
  99. TFTPD_DEBUG((TFTPD_TRACE_IO,
  100. "TftpdIoCompletionCallback(buffer = %p): bytes = %d.\n",
  101. buffer, dwBytes));
  102. if (context == NULL)
  103. InterlockedDecrement((PLONG)&socket->postedBuffers);
  104. switch (dwErrorCode) {
  105. case STATUS_SUCCESS :
  106. if (context == NULL) {
  107. if (dwBytes < TFTPD_MIN_RECEIVED_DATA)
  108. goto exit_completion_callback;
  109. buffer->internal.io.bytes = dwBytes;
  110. buffer = TftpdProcessReceivedBuffer(buffer);
  111. } // if (context == NULL)
  112. break;
  113. case STATUS_PORT_UNREACHABLE :
  114. TFTPD_DEBUG((TFTPD_TRACE_IO,
  115. "TftpdIoCompletionCallback(buffer = %p, context = %p): "
  116. "STATUS_PORT_UNREACHABLE.\n",
  117. buffer, context));
  118. // If this was a write operation, kill the context.
  119. if (context != NULL) {
  120. TftpdProcessError(buffer);
  121. context = NULL;
  122. }
  123. goto exit_completion_callback;
  124. case STATUS_CANCELLED :
  125. // If this was a write operation, kill the context.
  126. if (context != NULL) {
  127. TFTPD_DEBUG((TFTPD_TRACE_IO,
  128. "TftpdIoCompletionCallback(buffer = %p, context = %p): "
  129. "STATUS_CANCELLED.\n",
  130. buffer, context));
  131. TftpdProcessError(buffer);
  132. context = NULL;
  133. }
  134. TftpdIoFreeBuffer(buffer);
  135. buffer = NULL;
  136. goto exit_completion_callback;
  137. default :
  138. TFTPD_DEBUG((TFTPD_DBG_IO,
  139. "TftpdIoCompletionCallback(buffer = %p): "
  140. "dwErrorcode = 0x%08X.\n",
  141. buffer, dwErrorCode));
  142. goto exit_completion_callback;
  143. } // switch (dwErrorCode)
  144. exit_completion_callback :
  145. if (context != NULL) {
  146. // Do we bother reposting the buffer?
  147. if (context->state & TFTPD_STATE_DEAD) {
  148. TftpdIoFreeBuffer(buffer);
  149. buffer = NULL;
  150. }
  151. // Release the overlapped send reference.
  152. TftpdContextRelease(context);
  153. } // if (context != NULL)
  154. if (buffer != NULL)
  155. TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);
  156. } // TftpdIoCompletionCallback()
  157. void CALLBACK
  158. TftpdIoReadNotification(PTFTPD_SOCKET socket, BOOLEAN timeout) {
  159. TFTPD_DEBUG((TFTPD_TRACE_IO,
  160. "TftpdIoReadNotification(socket = %s).\n",
  161. ((socket == &globals.io.master) ? "master" :
  162. ((socket == &globals.io.def) ? "def" :
  163. ((socket == &globals.io.mtu) ? "mtu" :
  164. ((socket == &globals.io.max) ? "max" :
  165. "private")))) ));
  166. // If this fails, the event triggering this callback will stop signalling
  167. // due to a lack of a successful WSARecvFrom() ... this will likely occur
  168. // during low-memory/stress conditions. When the system returns to normal,
  169. // the low water-mark buffers will be reposted, thus receiving data and
  170. // re-enabling the event which triggers this callback.
  171. while (!globals.service.shutdown)
  172. if (TftpdIoPostReceiveBuffer(socket, NULL) >= socket->lowWaterMark)
  173. break;
  174. } // TftpdIoReadNotification()
  175. DWORD
  176. TftpdIoPostReceiveBuffer(PTFTPD_SOCKET socket, PTFTPD_BUFFER buffer) {
  177. DWORD postedBuffers = 0, successfulPosts = 0;
  178. int error;
  179. TFTPD_DEBUG((TFTPD_TRACE_IO,
  180. "TftpdIoPostReceiveBuffer(buffer = %p, socket = %s).\n",
  181. buffer,
  182. ((socket == &globals.io.master) ? "master" :
  183. ((socket == &globals.io.def) ? "def" :
  184. ((socket == &globals.io.mtu) ? "mtu" :
  185. ((socket == &globals.io.max) ? "max" :
  186. "private")))) ));
  187. postedBuffers = InterlockedIncrement((PLONG)&socket->postedBuffers);
  188. //
  189. // Attempt to post a buffer:
  190. //
  191. while (TRUE) {
  192. WSABUF buf;
  193. if (globals.service.shutdown ||
  194. (postedBuffers > globals.parameters.highWaterMark))
  195. goto exit_post_buffer;
  196. // Allocate the buffer if we're not reusing one.
  197. if (buffer == NULL) {
  198. buffer = TftpdIoAllocateBuffer(socket);
  199. if (buffer == NULL) {
  200. TFTPD_DEBUG((TFTPD_DBG_IO,
  201. "TftpdIoPostReceiveBuffer(buffer = %p): "
  202. "TftpdIoAllocateBuffer() failed.\n",
  203. buffer));
  204. goto exit_post_buffer;
  205. }
  206. TFTPD_DEBUG((TFTPD_TRACE_IO,
  207. "TftpdIoPostReceiveBuffer(buffer = %p).\n",
  208. buffer));
  209. } else {
  210. if (socket->s == INVALID_SOCKET)
  211. goto exit_post_buffer;
  212. ASSERT(buffer->internal.socket == socket);
  213. ZeroMemory(buffer, sizeof(buffer->internal));
  214. buffer->internal.socket = socket;
  215. buffer->internal.datasize = socket->datasize;
  216. } // if (buffer == NULL)
  217. buf.buf = ((char *)buffer + FIELD_OFFSET(TFTPD_BUFFER, message.opcode));
  218. buf.len = (FIELD_OFFSET(TFTPD_BUFFER, message.data.data) -
  219. FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
  220. socket->datasize);
  221. error = NO_ERROR;
  222. if (socket == &globals.io.master) {
  223. DWORD bytes = 0;
  224. buffer->internal.io.msg.lpBuffers = &buf;
  225. buffer->internal.io.msg.dwBufferCount = 1;
  226. buffer->internal.io.msg.name = (LPSOCKADDR)&buffer->internal.io.peer;
  227. buffer->internal.io.msg.namelen = sizeof(buffer->internal.io.peer);
  228. buffer->internal.io.peerLen = sizeof(buffer->internal.io.peer);
  229. buffer->internal.io.msg.Control.buf = (char *)&buffer->internal.io.control;
  230. buffer->internal.io.msg.Control.len = sizeof(buffer->internal.io.control);
  231. buffer->internal.io.msg.dwFlags = 0;
  232. if (globals.fp.WSARecvMsg(socket->s, &buffer->internal.io.msg, &bytes,
  233. &buffer->internal.io.overlapped, NULL) == SOCKET_ERROR)
  234. error = WSAGetLastError();
  235. } else {
  236. DWORD bytes = 0;
  237. buffer->internal.io.peerLen = sizeof(buffer->internal.io.peer);
  238. if (WSARecvFrom(socket->s, &buf, 1, &bytes, &buffer->internal.io.flags,
  239. (PSOCKADDR)&buffer->internal.io.peer, &buffer->internal.io.peerLen,
  240. &buffer->internal.io.overlapped, NULL) == SOCKET_ERROR)
  241. error = WSAGetLastError();
  242. } // if (socket == &globals.io.master)
  243. switch (error) {
  244. case NO_ERROR :
  245. if (successfulPosts < 10) {
  246. successfulPosts++;
  247. postedBuffers = InterlockedIncrement((PLONG)&socket->postedBuffers);
  248. buffer = NULL;
  249. continue;
  250. } else {
  251. return (postedBuffers);
  252. }
  253. case WSA_IO_PENDING :
  254. return (postedBuffers);
  255. case WSAECONNRESET :
  256. TFTPD_DEBUG((TFTPD_DBG_IO,
  257. "TftpdIoPostReceiveBuffer(buffer = %p): "
  258. "%s() failed for TID = <%s:%d>, WSAECONNRESET.\n",
  259. buffer,
  260. (socket == &globals.io.master) ? "WSARecvMsg" : "WSARecvFrom",
  261. inet_ntoa(buffer->internal.io.peer.sin_addr),
  262. ntohs(buffer->internal.io.peer.sin_port)));
  263. TftpdProcessError(buffer);
  264. continue;
  265. default :
  266. TFTPD_DEBUG((TFTPD_DBG_IO,
  267. "TftpdIoPostReceiveBuffer(buffer = %p): "
  268. "WSARecvMsg/From() failed, error 0x%08X.\n",
  269. buffer, error));
  270. goto exit_post_buffer;
  271. } // switch (error)
  272. } // while (true)
  273. exit_post_buffer :
  274. postedBuffers = InterlockedDecrement((PLONG)&socket->postedBuffers);
  275. if (buffer != NULL)
  276. TftpdIoFreeBuffer(buffer);
  277. return (postedBuffers);
  278. } // TftpdIoPostReceiveBuffer()
  279. void
  280. TftpdIoSendErrorPacket(PTFTPD_BUFFER buffer, TFTPD_ERROR_CODE error, char *reason) {
  281. DWORD bytes = 0;
  282. WSABUF buf;
  283. TFTPD_DEBUG((TFTPD_TRACE_IO,
  284. "TftpdIoSendErrorPacket(buffer = %p): %s\n",
  285. buffer, reason));
  286. // Build the error message.
  287. buffer->message.opcode = htons(TFTPD_ERROR);
  288. buffer->message.error.code = htons(error);
  289. strncpy(buffer->message.error.error, reason, buffer->internal.datasize);
  290. buffer->message.error.error[buffer->internal.datasize - 1] = '\0';
  291. // Send it non-blocking only. If it fails, who cares, let the client deal with it.
  292. buf.buf = (char *)&buffer->message.opcode;
  293. buf.len = (FIELD_OFFSET(TFTPD_BUFFER, message.error.error) -
  294. FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
  295. (strlen(buffer->message.error.error) + 1));
  296. if (WSASendTo(buffer->internal.socket->s, &buf, 1, &bytes, 0,
  297. (PSOCKADDR)&buffer->internal.io.peer, sizeof(SOCKADDR_IN),
  298. NULL, NULL) == SOCKET_ERROR) {
  299. TFTPD_DEBUG((TFTPD_DBG_IO,
  300. "TftpdIoSendErrorPacket(buffer = %p): WSASendTo() failed, error = %d.\n",
  301. buffer, WSAGetLastError()));
  302. }
  303. } // TftpdIoSendErrorPacket()
  304. PTFTPD_BUFFER
  305. TftpdIoSendPacket(PTFTPD_BUFFER buffer) {
  306. PTFTPD_CONTEXT context = buffer->internal.context;
  307. DWORD bytes = 0;
  308. WSABUF buf;
  309. // NOTE: 'context' must be referenced before this call!
  310. ASSERT(context != NULL);
  311. ASSERT(context->reference >= 1);
  312. ASSERT(buffer->internal.socket != NULL);
  313. TFTPD_DEBUG((TFTPD_TRACE_IO,
  314. "TftpdIoSendPacket(buffer = %p, context = %p): bytes = %d.\n",
  315. buffer, context, buffer->internal.io.bytes));
  316. // First try sending it non-blocking.
  317. buf.buf = (char *)&buffer->message.opcode;
  318. buf.len = buffer->internal.io.bytes;
  319. if (WSASendTo(context->socket->s, &buf, 1, &bytes, 0,
  320. (PSOCKADDR)&context->peer, sizeof(SOCKADDR_IN),
  321. NULL, NULL) == SOCKET_ERROR) {
  322. if (WSAGetLastError() == WSAEWOULDBLOCK) {
  323. // Keep an overlapped-operation reference to the context.
  324. TftpdContextAddReference(context);
  325. // Send it overlapped. When completion occurs, we'll know it was a send
  326. // when buffer->internal.context is non-NULL.
  327. if (WSASendTo(context->socket->s, &buf, 1, &bytes, 0,
  328. (PSOCKADDR)&context->peer, sizeof(SOCKADDR_IN),
  329. &buffer->internal.io.overlapped, NULL) == SOCKET_ERROR) {
  330. if (WSAGetLastError() != WSA_IO_PENDING) {
  331. TFTPD_DEBUG((TFTPD_TRACE_IO,
  332. "TftpdIoSendPacket(buffer = %p, context = %p): "
  333. "overlapped send failed.\n",
  334. buffer, context));
  335. // Release the overlapped-operation reference to the context.
  336. TftpdContextRelease(context);
  337. goto exit_send_packet;
  338. }
  339. } // if (WSASendTo(...) == SOCKET_ERROR)
  340. buffer = NULL; // Tell the caller not to recycle a buffer.
  341. } // if (WSAGetLastError() == WSAEWOULDBLOCK)
  342. goto exit_send_packet;
  343. } // if (WSASendTo(...) == SOCKET_ERROR)
  344. //
  345. // Non-blocking send succeeded.
  346. //
  347. exit_send_packet :
  348. return (buffer);
  349. } // TftpdIoSendPacket()
  350. void
  351. TftpdIoLeakSocketContext(PTFTPD_SOCKET socket) {
  352. PLIST_ENTRY entry;
  353. EnterCriticalSection(&globals.reaper.socketCS); {
  354. // If shutdown is occuring, we're in trouble anyways.
  355. // Just let it go.
  356. if (globals.service.shutdown) {
  357. LeaveCriticalSection(&globals.reaper.socketCS);
  358. return;
  359. }
  360. TFTPD_DEBUG((TFTPD_TRACE_CONTEXT,
  361. "TftpdIoLeakSocketContext(context = %p).\n",
  362. socket));
  363. // Is the socket already in the list?
  364. for (entry = globals.reaper.leakedSockets.Flink;
  365. entry != &globals.reaper.leakedSockets;
  366. entry = entry->Flink) {
  367. if (CONTAINING_RECORD(entry, TFTPD_SOCKET, linkage) == socket) {
  368. LeaveCriticalSection(&globals.reaper.socketCS);
  369. return;
  370. }
  371. }
  372. InsertHeadList(&globals.reaper.leakedSockets, &socket->linkage);
  373. globals.reaper.numLeakedSockets++;
  374. } LeaveCriticalSection(&globals.reaper.socketCS);
  375. } // TftpdIoLeakSocketContext()
  376. PTFTPD_SOCKET
  377. TftpdIoAllocateSocketContext() {
  378. PTFTPD_SOCKET socket = NULL;
  379. if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets) {
  380. BOOL failAllocate = FALSE;
  381. // Try to recover leaked contexts.
  382. EnterCriticalSection(&globals.reaper.socketCS); {
  383. PLIST_ENTRY entry;
  384. while ((entry = RemoveHeadList(&globals.reaper.leakedSockets)) !=
  385. &globals.reaper.leakedSockets) {
  386. PTFTPD_SOCKET s = CONTAINING_RECORD(entry, TFTPD_SOCKET, linkage);
  387. globals.reaper.numLeakedSockets--;
  388. if (!TftpdIoDestroySocketContext(s)) {
  389. TftpdIoLeakSocketContext(s);
  390. failAllocate = TRUE;
  391. break;
  392. }
  393. }
  394. } LeaveCriticalSection(&globals.reaper.socketCS);
  395. if (failAllocate)
  396. goto exit_allocate_context;
  397. } // if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets)
  398. socket = (PTFTPD_SOCKET)HeapAlloc(globals.hServiceHeap,
  399. HEAP_ZERO_MEMORY,
  400. sizeof(TFTPD_SOCKET));
  401. exit_allocate_context :
  402. return (socket);
  403. } // TftpdIoAllocateSocketContext()
  404. void
  405. TftpdIoInitializeSocketContext(PTFTPD_SOCKET socket, PSOCKADDR_IN addr, PTFTPD_CONTEXT context) {
  406. BOOL one = TRUE;
  407. TFTPD_DEBUG((TFTPD_TRACE_IO,
  408. "TftpdIoInitializeSocketContext(socket = %s): TID = <%s:%d>.\n",
  409. ((socket == &globals.io.master) ? "master" :
  410. ((socket == &globals.io.def) ? "def" :
  411. ((socket == &globals.io.mtu) ? "mtu" :
  412. ((socket == &globals.io.max) ? "max" : "private")))),
  413. inet_ntoa(addr->sin_addr), ntohs(addr->sin_port)));
  414. // NOTE: Do NOT zero-out 'socket', it has been initialized with some
  415. // values we need to work with.
  416. // Create the socket.
  417. socket->s = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
  418. if (socket->s == INVALID_SOCKET) {
  419. TFTPD_DEBUG((TFTPD_DBG_IO,
  420. "TftpdIoInitializeSocketContext: "
  421. "WSASocket() failed, error 0x%08X.\n",
  422. GetLastError()));
  423. SetLastError(WSAGetLastError());
  424. goto fail_create_context;
  425. }
  426. // Ensure that we will exclusively own our local port so nobody can hijack us.
  427. if (setsockopt(socket->s,
  428. SOL_SOCKET,
  429. SO_EXCLUSIVEADDRUSE,
  430. (const char *)&one,
  431. sizeof(one)) == SOCKET_ERROR) {
  432. TFTPD_DEBUG((TFTPD_DBG_IO,
  433. "TftpdIoInitializeSocketContext: "
  434. "setsockopt(SO_EXCLUSIVEADDRUSE) failed, error 0x%08X.\n",
  435. GetLastError()));
  436. SetLastError(WSAGetLastError());
  437. goto fail_create_context;
  438. }
  439. // Bind the socket on the correct port.
  440. if (bind(socket->s, (PSOCKADDR)addr, sizeof(SOCKADDR)) == SOCKET_ERROR) {
  441. TFTPD_DEBUG((TFTPD_DBG_IO,
  442. "TftpdIoInitializeSocketContext: "
  443. "bind() failed, error 0x%08X.\n",
  444. GetLastError()));
  445. SetLastError(WSAGetLastError());
  446. goto fail_create_context;
  447. }
  448. // Register for completion callbacks on the socket.
  449. if (!BindIoCompletionCallback((HANDLE)socket->s, TftpdIoCompletionCallback, 0)) {
  450. TFTPD_DEBUG((TFTPD_DBG_IO,
  451. "TftpdIoInitializeSocketContext: "
  452. "BindIoCompletionCallback() failed, error 0x%08X.\n",
  453. GetLastError()));
  454. goto fail_create_context;
  455. }
  456. // Indicate that we want WSARecvMsg() to fill-in packet information.
  457. // Note we only do this on the master-socket only where we can receive TFTPD_RECV and
  458. // TFTPD_WRITE requests, and we need to determine which socket to set the context to.
  459. if (socket == &globals.io.master) {
  460. // Obtain the WSARecvMsg() extension API pointer.
  461. GUID g = WSAID_WSARECVMSG;
  462. int opt = TRUE;
  463. DWORD len;
  464. if (WSAIoctl(socket->s, SIO_GET_EXTENSION_FUNCTION_POINTER, &g, sizeof(g),
  465. &globals.fp.WSARecvMsg, sizeof(globals.fp.WSARecvMsg),
  466. &len, NULL, NULL) == SOCKET_ERROR) {
  467. TFTPD_DEBUG((TFTPD_DBG_IO,
  468. "TftpdIoInitializeSocketContext: "
  469. "WSAIoctl() failed, error 0x%08X.\n",
  470. WSAGetLastError()));
  471. goto fail_create_context;
  472. }
  473. // Indicate that we want WSARecvMsg() to fill-in packet information.
  474. if (setsockopt(socket->s, IPPROTO_IP, IP_PKTINFO,
  475. (char *)&opt, sizeof(opt)) == SOCKET_ERROR) {
  476. TFTPD_DEBUG((TFTPD_DBG_IO,
  477. "TftpdIoInitializeSocketContext: "
  478. "setsockopt() failed, error 0x%08X.\n",
  479. WSAGetLastError()));
  480. goto fail_create_context;
  481. }
  482. } // if (socket == &globals.io.master)
  483. // Record the port used for this context.
  484. CopyMemory(&socket->addr, addr, sizeof(socket->addr));
  485. if (context == NULL) {
  486. // Select the socket for read and write notifications.
  487. // Read so when we know to get data, write so when we know
  488. // whether to do send operations non-blocking or overlapped.
  489. if ((socket->hSelect = CreateEvent(NULL, FALSE, FALSE, NULL)) == NULL) {
  490. TFTPD_DEBUG((TFTPD_DBG_IO,
  491. "TftpdIoInitializeSocketContext: "
  492. "CreateEvent() failed, error 0x%08X.\n",
  493. GetLastError()));
  494. goto fail_create_context;
  495. }
  496. if (WSAEventSelect(socket->s, socket->hSelect, FD_READ) == SOCKET_ERROR) {
  497. TFTPD_DEBUG((TFTPD_DBG_IO,
  498. "TftpdIoInitializeSocketContext: "
  499. "WSAEventSelect() failed, error 0x%08X.\n",
  500. GetLastError()));
  501. SetLastError(WSAGetLastError());
  502. goto fail_create_context;
  503. }
  504. // Register for FD_READ notification on the socket.
  505. if (!RegisterWaitForSingleObject(&socket->wSelectWait,
  506. socket->hSelect,
  507. (WAITORTIMERCALLBACK)TftpdIoReadNotification,
  508. socket,
  509. INFINITE,
  510. WT_EXECUTEINWAITTHREAD)) {
  511. TFTPD_DEBUG((TFTPD_DBG_IO,
  512. "TftpdIoInitializeSocketContext: "
  513. "RegisterWaitForSingleObject() failed, error 0x%08X.\n",
  514. GetLastError()));
  515. goto fail_create_context;
  516. }
  517. // Prepost the low water-mark number of receive buffers.
  518. // If the FD_READ event signals on the master socket before we're done, we'll
  519. // exceed the low water-mark here but that's harmless as the excess buffers
  520. // will be freed upon completion.
  521. if (!socket->lowWaterMark)
  522. socket->lowWaterMark = 1;
  523. if (!socket->highWaterMark)
  524. socket->highWaterMark = 1;
  525. SetEvent(socket->hSelect);
  526. } else {
  527. // Is this a private socket (ie, not master, def, mtu, or max).
  528. // If so, it will be destroyed when it's one and only one owning context is destroyed.
  529. socket->context = context;
  530. // Initialize read notification variables to NULL.
  531. socket->hSelect = NULL;
  532. socket->wSelectWait = NULL;
  533. socket->lowWaterMark = 1;
  534. TftpdIoPostReceiveBuffer(socket, NULL);
  535. } // if (context == NULL)
  536. return;
  537. fail_create_context :
  538. if (socket->s != INVALID_SOCKET)
  539. closesocket(socket->s), socket->s = INVALID_SOCKET;
  540. if (socket->hSelect != NULL)
  541. CloseHandle(socket->hSelect), socket->hSelect = NULL;
  542. } // TftpdIoInitializeSocketContext()
  543. BOOL
  544. TftpdIoAssignSocket(PTFTPD_CONTEXT context, PTFTPD_BUFFER buffer) {
  545. SOCKADDR_IN addr;
  546. DWORD len = 0;
  547. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  548. "TftpdIoAssignSocket(context = %p, buffer = %p).\n",
  549. context, buffer));
  550. if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) {
  551. PWSACMSGHDR header;
  552. IN_PKTINFO *packetInfo;
  553. // Determine if routing problems force us to use a private socket so we can corrrectly
  554. // send datagrams to the requesting client. First, get the best interface address for
  555. // responding to the requesting client.
  556. ZeroMemory(&addr, sizeof(addr));
  557. // Make the ioctl call.
  558. WSASetLastError(NO_ERROR);
  559. if ((WSAIoctl(globals.io.master.s, SIO_ROUTING_INTERFACE_QUERY,
  560. &buffer->internal.io.peer, buffer->internal.io.peerLen,
  561. &addr, sizeof(SOCKADDR_IN),
  562. &len, NULL, NULL) == SOCKET_ERROR) ||
  563. (len != sizeof(SOCKADDR_IN))) {
  564. TFTPD_DEBUG((TFTPD_DBG_PROCESS,
  565. "TftpdIoAssignSocket(): "
  566. "WSAIoctl(SIO_ROUTING_INTERFACE_QUERY) failed, error = %d.\n",
  567. WSAGetLastError()));
  568. TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
  569. "Failed to initialize network endpoint.");
  570. return (FALSE);
  571. }
  572. // Loop through the control (ancillary) data looking for our packet info.
  573. header = WSA_CMSG_FIRSTHDR(&buffer->internal.io.msg);
  574. packetInfo = NULL;
  575. while (header) {
  576. if ((header->cmsg_level == IPPROTO_IP) && (header->cmsg_type == IP_PKTINFO)) {
  577. packetInfo = (IN_PKTINFO *)WSA_CMSG_DATA(header);
  578. break;
  579. }
  580. header = WSA_CMSG_NXTHDR(&buffer->internal.io.msg, header);
  581. } // while (header)
  582. // Check to see if the best interface we obtained is not the one the client sent the message to.
  583. if ((packetInfo != NULL) &&
  584. (addr.sin_addr.s_addr != packetInfo->ipi_addr.s_addr)) {
  585. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  586. "TftpdIoAssignSocket(context = %p, buffer = %p):\n"
  587. "\tRemote client TID = <%s:%d>\n",
  588. context, buffer,
  589. inet_ntoa(buffer->internal.io.peer.sin_addr),
  590. ntohs(buffer->internal.io.peer.sin_port) ));
  591. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  592. "\tRequest issued to local IP = <%s>\n",
  593. inet_ntoa(packetInfo->ipi_addr) ));
  594. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  595. "\tDefault route is over IP = <%s>\n",
  596. inet_ntoa(addr.sin_addr) ));
  597. // We need to create a private socket for this client.
  598. context->socket = TftpdIoAllocateSocketContext();
  599. if (context->socket == NULL) {
  600. TFTPD_DEBUG((TFTPD_DBG_PROCESS,
  601. "TftpdIoAssignSocket(): "
  602. "TftpdIoAllocateSocketContext() failed.\n"));
  603. TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
  604. "Out of memory");
  605. return (FALSE);
  606. }
  607. context->socket->s = INVALID_SOCKET;
  608. context->socket->buffersize = (TFTPD_BUFFER_SIZE)
  609. (FIELD_OFFSET(TFTPD_BUFFER, message.data.data) +
  610. context->blksize);
  611. context->socket->datasize = (TFTPD_DATA_SIZE)context->blksize;
  612. if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) {
  613. ZeroMemory(&addr, sizeof(addr));
  614. addr.sin_family = AF_INET;
  615. addr.sin_addr.s_addr = packetInfo->ipi_addr.s_addr;
  616. }
  617. TftpdIoInitializeSocketContext(context->socket, &addr, context);
  618. if (context->socket->s == INVALID_SOCKET) {
  619. TFTPD_DEBUG((TFTPD_DBG_PROCESS,
  620. "TftpdIoAssignSocket(): "
  621. "TftpdIoInitializeSocketContext() failed.\n"));
  622. TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
  623. "Failed to initialize network endpoint.");
  624. HeapFree(globals.hServiceHeap, 0, context->socket);
  625. context->socket = NULL;
  626. return (FALSE);
  627. }
  628. #if defined(DBG)
  629. InterlockedIncrement((PLONG)&globals.performance.privateSockets);
  630. #endif // defined(DBG)
  631. return (TRUE);
  632. } // if ((packetInfo != NULL) && ...)
  633. } else {
  634. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  635. "TftpdIoAssignSocket(context = %p, buffer = %p):\n"
  636. "\tRemote client TID = <%s:%d> issued broadcast request.\n",
  637. context, buffer,
  638. inet_ntoa(buffer->internal.io.peer.sin_addr), ntohs(buffer->internal.io.peer.sin_port) ));
  639. } // if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST))
  640. ZeroMemory(&addr, sizeof(addr));
  641. addr.sin_family = AF_INET;
  642. addr.sin_addr.s_addr = INADDR_ANY;
  643. addr.sin_port = 0;
  644. // Figure out which socket to use for this request (based on blksize).
  645. if (context->blksize <= TFTPD_DEF_DATA) {
  646. if (globals.io.def.s == INVALID_SOCKET) {
  647. EnterCriticalSection(&globals.io.cs); {
  648. if (globals.service.shutdown) {
  649. TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
  650. LeaveCriticalSection(&globals.io.cs);
  651. return (FALSE);
  652. }
  653. TftpdIoInitializeSocketContext(&globals.io.def, &addr, NULL);
  654. if (globals.io.def.s != INVALID_SOCKET) {
  655. context->socket = &globals.io.def;
  656. } else {
  657. context->socket = &globals.io.master;
  658. if (context->options) {
  659. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  660. "TftpdIoAssignSocket(): Removing requested blksize = %d "
  661. "option since we failed to create the MTU-size socket.\n",
  662. context->blksize));
  663. context->options &= ~TFTPD_OPTION_BLKSIZE;
  664. }
  665. }
  666. } LeaveCriticalSection(&globals.io.cs);
  667. } else {
  668. context->socket = &globals.io.def;
  669. } // if (globals.io.def.s == INVALID_SOCKET)
  670. } else {
  671. if (context->blksize <= TFTPD_MTU_DATA) {
  672. if (globals.io.mtu.s == INVALID_SOCKET) {
  673. EnterCriticalSection(&globals.io.cs); {
  674. if (globals.service.shutdown) {
  675. TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
  676. LeaveCriticalSection(&globals.io.cs);
  677. return (FALSE);
  678. }
  679. TftpdIoInitializeSocketContext(&globals.io.mtu, &addr, NULL);
  680. if (globals.io.mtu.s != INVALID_SOCKET) {
  681. context->socket = &globals.io.mtu;
  682. } else {
  683. context->socket = &globals.io.master;
  684. if (context->options) {
  685. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  686. "TftpdIoAssignSocket(): Removing requested blksize = %d "
  687. "option since we failed to create the MTU-size socket.\n",
  688. context->blksize));
  689. context->options &= ~TFTPD_OPTION_BLKSIZE;
  690. }
  691. }
  692. } LeaveCriticalSection(&globals.io.cs);
  693. } else {
  694. context->socket = &globals.io.mtu;
  695. } // if (globals.io.mtu.s == INVALID_SOCKET)
  696. } else if (context->blksize <= TFTPD_MAX_DATA) {
  697. if (globals.io.max.s == INVALID_SOCKET) {
  698. EnterCriticalSection(&globals.io.cs); {
  699. if (globals.service.shutdown) {
  700. TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
  701. LeaveCriticalSection(&globals.io.cs);
  702. return (FALSE);
  703. }
  704. TftpdIoInitializeSocketContext(&globals.io.max, &addr, NULL);
  705. if (globals.io.max.s != INVALID_SOCKET) {
  706. context->socket = &globals.io.max;
  707. } else {
  708. context->socket = &globals.io.master;
  709. if (context->options) {
  710. TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
  711. "TftpdIoAssignSocket(): Removing requested blksize = %d "
  712. "option since we failed to create the MAX-size socket.\n",
  713. context->blksize));
  714. context->options &= ~TFTPD_OPTION_BLKSIZE;
  715. }
  716. }
  717. } LeaveCriticalSection(&globals.io.cs);
  718. } else {
  719. context->socket = &globals.io.max;
  720. } // if (globals.io.max.s == INVALID_SOCKET)
  721. }
  722. } // (context->blksize <= TFTPD_DEF_DATA)
  723. return (TRUE);
  724. } // TftpdIoAssignSocket()
  725. BOOL
  726. TftpdIoDestroySocketContext(PTFTPD_SOCKET socket) {
  727. NTSTATUS status;
  728. SOCKET s;
  729. if (socket->s == INVALID_SOCKET)
  730. return (TRUE);
  731. TFTPD_DEBUG((TFTPD_TRACE_IO,
  732. "TftpdIoDestroySocketContext(socket = %s).\n",
  733. ((socket == &globals.io.master) ? "master" :
  734. ((socket == &globals.io.def) ? "def" :
  735. ((socket == &globals.io.mtu) ? "mtu" :
  736. ((socket == &globals.io.max) ? "max" :
  737. "private")))) ));
  738. // Disable further buffer posting.
  739. socket->lowWaterMark = 0;
  740. if (socket->context == NULL) {
  741. if (!UnregisterWait(socket->wSelectWait)) {
  742. DWORD error;
  743. if ((error = GetLastError()) != ERROR_IO_PENDING) {
  744. TFTPD_DEBUG((TFTPD_DBG_IO,
  745. "TftpdIoDestroySocketContext: "
  746. "UnregisterWait() failed, error 0x%08X.\n",
  747. error));
  748. TftpdIoLeakSocketContext(socket);
  749. return (FALSE);
  750. }
  751. }
  752. socket->wSelectWait = NULL;
  753. CloseHandle(socket->hSelect);
  754. socket->hSelect = NULL;
  755. } // if (socket->context == NULL)
  756. // Kill the socket. This will disable the FD_READ and FD_WRITE
  757. // event select, as well as cancel all pending overlapped operations
  758. // on it. Add a buffer reference here so after we close the
  759. // socket we can test if there were never any buffers posted
  760. // which would cancel above in TftpdIoCompletionCallback so
  761. // we should deallocate socket here.
  762. // Kill it.
  763. InterlockedIncrement((PLONG)&socket->numBuffers);
  764. s = socket->s;
  765. socket->s = INVALID_SOCKET;
  766. if (closesocket(s) == SOCKET_ERROR) {
  767. TFTPD_DEBUG((TFTPD_DBG_IO,
  768. "TftpdIoDestroySocketContext: "
  769. "closesocket() failed, error 0x%08X.\n",
  770. GetLastError()));
  771. socket->s = s;
  772. InterlockedDecrement((PLONG)&socket->numBuffers);
  773. TftpdIoLeakSocketContext(socket);
  774. return (FALSE);
  775. }
  776. if (InterlockedDecrement((PLONG)&socket->numBuffers) == -1)
  777. HeapFree(globals.hServiceHeap, 0, socket);
  778. return (TRUE);
  779. } // TftpdIoDestroySocketContext()