|
|
/*++
Copyright (c) 2001 Microsoft Corporation
Module Name:
io.c
Abstract:
This module contains functions to manage all socket I/O between the server and clients, including socket management and overlapped completion indication. It also contains buffer management.
Author:
Jeffrey C. Venable, Sr. (jeffv) 01-Jun-2001
Revision History:
--*/
#include "precomp.h"
//
// Returns a buffer to the service heap and drops the references it held
// on its owning socket and on the service-wide buffer count.
//
// buffer - the buffer to release. Its owning socket is taken from
//          buffer->internal.socket before the memory is freed.
//
// If the owning socket's buffer count reaches -1 and the socket belongs to
// a context (socket->context != NULL), the socket structure is freed too.
// If the service-wide buffer count reaches -1, service cleanup is attempted.
//
void TftpdIoFreeBuffer(PTFTPD_BUFFER buffer) {

    // Remember the owner now; 'buffer' must not be touched after HeapFree().
    PTFTPD_SOCKET owner = buffer->internal.socket;
    LONG socketCount;
    LONG serviceCount;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoFreeBuffer(buffer = %p).\n",
                 buffer));

    HeapFree(globals.hServiceHeap, 0, buffer);

    // Only look at owner->context once we know the count reached -1.
    socketCount = InterlockedDecrement((PLONG)&owner->numBuffers);
    if (socketCount == -1) {
        if (owner->context != NULL)
            HeapFree(globals.hServiceHeap, 0, owner);
    }

    serviceCount = InterlockedDecrement(&globals.io.numBuffers);
    if (serviceCount == -1)
        TftpdServiceAttemptCleanup();

} // TftpdIoFreeBuffer()
//
// Allocates a buffer sized for 'socket' (socket->buffersize bytes) from the
// service heap, zeroes its internal header, and accounts for it in both the
// service-wide and the per-socket buffer counts.
//
// socket - the socket the buffer will belong to.
//
// Returns the new buffer, or NULL when the heap allocation fails or when
// the service is shutting down (in which case the freshly counted buffer
// is released again through TftpdIoFreeBuffer()).
//
PTFTPD_BUFFER TftpdIoAllocateBuffer(PTFTPD_SOCKET socket) {

    PTFTPD_BUFFER fresh;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoAllocateBuffer(socket = %s).\n",
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" : "private"))))
                 ));

    fresh = (PTFTPD_BUFFER)HeapAlloc(globals.hServiceHeap, 0, socket->buffersize);
    if (fresh == NULL) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoAllocateBuffer(socket = %s): "
                     "HeapAlloc() failed, error 0x%08X.\n",
                     ((socket == &globals.io.master) ? "master" :
                      ((socket == &globals.io.def) ? "def" :
                       ((socket == &globals.io.mtu) ? "mtu" :
                        ((socket == &globals.io.max) ? "max" : "private")))),
                     GetLastError()));
        return (NULL);
    }

    // Only the bookkeeping header is cleared; the message area is left as-is.
    ZeroMemory(fresh, sizeof(fresh->internal));

    InterlockedIncrement(&globals.io.numBuffers);
    InterlockedIncrement((PLONG)&socket->numBuffers);

    fresh->internal.socket = socket;
    fresh->internal.datasize = socket->datasize;

    if (!globals.service.shutdown)
        return (fresh);

    // Shutting down: undo the allocation (and its counts) right away.
    TftpdIoFreeBuffer(fresh);
    return (NULL);

} // TftpdIoAllocateBuffer()
//
// Exchanges a request buffer for a new buffer belonging to 'socket',
// carrying over the context pointer, peer address, and the received
// message/control information.
//
// buffer - the buffer holding the original request (asserted to be a
//          TFTPD_RRQ or TFTPD_WRQ). It is always handed back to
//          TftpdIoPostReceiveBuffer() on its own socket before returning,
//          so the caller must not use it afterwards.
// socket - the socket the replacement buffer is allocated for.
//
// Returns the replacement buffer, or NULL if it could not be allocated.
//
PTFTPD_BUFFER TftpdIoSwapBuffer(PTFTPD_BUFFER buffer, PTFTPD_SOCKET socket) {

    PTFTPD_BUFFER tmp;

    ASSERT((buffer->message.opcode == TFTPD_RRQ) ||
           (buffer->message.opcode == TFTPD_WRQ));

    // Allocate a buffer for the new socket.
    tmp = TftpdIoAllocateBuffer(socket);

    // Copy information we need to retain.
    if (tmp != NULL) {

        tmp->internal.context = buffer->internal.context;
        tmp->internal.io.peerLen = buffer->internal.io.peerLen;
        CopyMemory(&tmp->internal.io.peer,
                   &buffer->internal.io.peer,
                   buffer->internal.io.peerLen);
        CopyMemory(&tmp->internal.io.msg,
                   &buffer->internal.io.msg,
                   sizeof(tmp->internal.io.msg));
        CopyMemory(&tmp->internal.io.control,
                   &buffer->internal.io.control,
                   sizeof(tmp->internal.io.control));

        // The message header was copied bit-for-bit, so its embedded
        // pointers still refer to storage inside the ORIGINAL buffer, which
        // is reposted for receive below and may be overwritten or freed.
        // Re-aim them at the copies that live in the new buffer.
        if (buffer->internal.io.msg.name != NULL)
            tmp->internal.io.msg.name = (LPSOCKADDR)&tmp->internal.io.peer;
        if (buffer->internal.io.msg.Control.buf != NULL)
            tmp->internal.io.msg.Control.buf = (char *)&tmp->internal.io.control;

        // The data-buffer array was a local of the function that posted the
        // receive; that pointer is stale, so do not carry it forward.
        tmp->internal.io.msg.lpBuffers = NULL;
        tmp->internal.io.msg.dwBufferCount = 0;

    } // if (tmp != NULL)

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoSwapBuffer(buffer = %p): "
                 "new buffer = %p.\n",
                 buffer, tmp));

    // Return the original buffer.
    TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);

    return (tmp);

} // TftpdIoSwapBuffer()
//
// I/O completion routine registered on every socket through
// BindIoCompletionCallback().
//
// dwErrorCode - completion status of the overlapped operation.
// dwBytes     - number of bytes transferred.
// overlapped  - the OVERLAPPED embedded in the TFTPD_BUFFER that was used
//               for the operation; the buffer is recovered from it.
//
// A completion whose buffer has no context (buffer->internal.context is
// NULL) is treated as a receive; one with a context is treated as an
// overlapped send. Unless the buffer was consumed or freed along the way,
// it is reposted as a receive buffer on its socket before returning.
//
void TftpdIoCompletionCallback(DWORD dwErrorCode, DWORD dwBytes, LPOVERLAPPED overlapped) {

    PTFTPD_BUFFER buffer = CONTAINING_RECORD(overlapped, TFTPD_BUFFER, internal.io.overlapped);
    PTFTPD_CONTEXT context = buffer->internal.context;
    PTFTPD_SOCKET socket = buffer->internal.socket;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoCompletionCallback(buffer = %p): bytes = %d.\n",
                 buffer, dwBytes));

    // A receive just completed, so one fewer buffer is posted on the socket.
    if (context == NULL)
        InterlockedDecrement((PLONG)&socket->postedBuffers);

    switch (dwErrorCode) {

        case STATUS_SUCCESS :

            // Received datagrams shorter than the minimum are ignored and
            // the buffer is simply recycled below.
            if ((context == NULL) && (dwBytes >= TFTPD_MIN_RECEIVED_DATA)) {
                buffer->internal.io.bytes = dwBytes;
                buffer = TftpdProcessReceivedBuffer(buffer);
            }
            break;

        case STATUS_PORT_UNREACHABLE :

            TFTPD_DEBUG((TFTPD_TRACE_IO,
                         "TftpdIoCompletionCallback(buffer = %p, context = %p): "
                         "STATUS_PORT_UNREACHABLE.\n",
                         buffer, context));

            // If this was a write operation, kill the context.
            if (context != NULL) {
                TftpdProcessError(buffer);
                context = NULL;
            }
            break;

        case STATUS_CANCELLED :

            // If this was a write operation, kill the context.
            if (context != NULL) {
                TFTPD_DEBUG((TFTPD_TRACE_IO,
                             "TftpdIoCompletionCallback(buffer = %p, context = %p): "
                             "STATUS_CANCELLED.\n",
                             buffer, context));
                TftpdProcessError(buffer);
                context = NULL;
            }

            // A cancelled buffer is never reposted.
            TftpdIoFreeBuffer(buffer);
            buffer = NULL;
            break;

        default :

            TFTPD_DEBUG((TFTPD_DBG_IO,
                         "TftpdIoCompletionCallback(buffer = %p): "
                         "dwErrorcode = 0x%08X.\n",
                         buffer, dwErrorCode));
            break;

    } // switch (dwErrorCode)

    if (context != NULL) {

        // Do we bother reposting the buffer?
        if (context->state & TFTPD_STATE_DEAD) {
            TftpdIoFreeBuffer(buffer);
            buffer = NULL;
        }

        // Release the overlapped send reference.
        TftpdContextRelease(context);

    } // if (context != NULL)

    if (buffer != NULL)
        TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);

} // TftpdIoCompletionCallback()
//
// Wait callback registered against a socket's FD_READ select event.
// Keeps posting fresh receive buffers on the socket until the number of
// posted buffers reported by TftpdIoPostReceiveBuffer() reaches the
// socket's low water-mark, or until the service starts shutting down.
//
// socket  - the socket whose select event fired.
// timeout - not used by this routine.
//
void CALLBACK TftpdIoReadNotification(PTFTPD_SOCKET socket, BOOLEAN timeout) {

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoReadNotification(socket = %s).\n",
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" : "private"))))
                 ));

    // Per the original author's note: if posting fails, the event that
    // triggers this callback stops signalling because no WSARecvFrom() is
    // outstanding; this is expected under low-memory/stress conditions, and
    // reposting of the low water-mark buffers later re-enables the event.
    // NOTE(review): a failed post returns a count that may be below the
    // low water-mark, in which case this loop retries immediately - confirm
    // that is the intended behavior.
    for (;;) {

        if (globals.service.shutdown)
            break;

        if (TftpdIoPostReceiveBuffer(socket, NULL) >= socket->lowWaterMark)
            break;

    }

} // TftpdIoReadNotification()
//
// Posts an overlapped receive on 'socket'.
//
// socket - the socket to receive on. The master socket uses WSARecvMsg()
//          so control (ancillary) data is returned; every other socket
//          uses WSARecvFrom().
// buffer - a buffer to recycle for the receive, or NULL to allocate a new
//          one. A recycled buffer must already belong to 'socket'.
//
// Returns the socket's posted-buffer count as last observed by this call.
// On any failure path the count taken for this attempt is given back and
// the buffer in hand (if any) is freed.
//
DWORD TftpdIoPostReceiveBuffer(PTFTPD_SOCKET socket, PTFTPD_BUFFER buffer) {

    DWORD posted = 0;
    DWORD immediate = 0;
    int status;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoPostReceiveBuffer(buffer = %p, socket = %s).\n",
                 buffer,
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" : "private"))))
                 ));

    // Claim a posted-buffer slot up front; it is returned on failure.
    posted = InterlockedIncrement((PLONG)&socket->postedBuffers);

    //
    // Attempt to post a buffer. Leaving this loop with 'break' means the
    // attempt failed and falls into the cleanup code below.
    //
    for (;;) {

        WSABUF buf;
        DWORD bytes = 0;

        if (globals.service.shutdown || (posted > globals.parameters.highWaterMark))
            break;

        if (buffer != NULL) {

            // Recycling a buffer: pointless if the socket is already gone.
            if (socket->s == INVALID_SOCKET)
                break;

            ASSERT(buffer->internal.socket == socket);
            ZeroMemory(buffer, sizeof(buffer->internal));
            buffer->internal.socket = socket;
            buffer->internal.datasize = socket->datasize;

        } else {

            // Allocate the buffer if we're not reusing one.
            buffer = TftpdIoAllocateBuffer(socket);
            if (buffer == NULL) {
                TFTPD_DEBUG((TFTPD_DBG_IO,
                             "TftpdIoPostReceiveBuffer(buffer = %p): "
                             "TftpdIoAllocateBuffer() failed.\n",
                             buffer));
                break;
            }
            TFTPD_DEBUG((TFTPD_TRACE_IO,
                         "TftpdIoPostReceiveBuffer(buffer = %p).\n",
                         buffer));

        } // if (buffer != NULL)

        // The receive area starts at the opcode and spans the message
        // header plus the socket's data size.
        buf.buf = ((char *)buffer + FIELD_OFFSET(TFTPD_BUFFER, message.opcode));
        buf.len = (FIELD_OFFSET(TFTPD_BUFFER, message.data.data) -
                   FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
                   socket->datasize);

        buffer->internal.io.peerLen = sizeof(buffer->internal.io.peer);
        status = NO_ERROR;

        if (socket == &globals.io.master) {

            buffer->internal.io.msg.lpBuffers = &buf;
            buffer->internal.io.msg.dwBufferCount = 1;
            buffer->internal.io.msg.name = (LPSOCKADDR)&buffer->internal.io.peer;
            buffer->internal.io.msg.namelen = sizeof(buffer->internal.io.peer);
            buffer->internal.io.msg.Control.buf = (char *)&buffer->internal.io.control;
            buffer->internal.io.msg.Control.len = sizeof(buffer->internal.io.control);
            buffer->internal.io.msg.dwFlags = 0;

            if (globals.fp.WSARecvMsg(socket->s,
                                      &buffer->internal.io.msg,
                                      &bytes,
                                      &buffer->internal.io.overlapped,
                                      NULL) == SOCKET_ERROR)
                status = WSAGetLastError();

        } else {

            if (WSARecvFrom(socket->s,
                            &buf,
                            1,
                            &bytes,
                            &buffer->internal.io.flags,
                            (PSOCKADDR)&buffer->internal.io.peer,
                            &buffer->internal.io.peerLen,
                            &buffer->internal.io.overlapped,
                            NULL) == SOCKET_ERROR)
                status = WSAGetLastError();

        } // if (socket == &globals.io.master)

        if (status == NO_ERROR) {

            // The receive finished without pending. Up to ten times per
            // call, claim another slot and post an additional new buffer.
            if (immediate >= 10)
                return (posted);

            immediate++;
            posted = InterlockedIncrement((PLONG)&socket->postedBuffers);
            buffer = NULL;
            continue;

        }

        if (status == WSA_IO_PENDING)
            return (posted);

        if (status == WSAECONNRESET) {

            TFTPD_DEBUG((TFTPD_DBG_IO,
                         "TftpdIoPostReceiveBuffer(buffer = %p): "
                         "%s() failed for TID = <%s:%d>, WSAECONNRESET.\n",
                         buffer,
                         (socket == &globals.io.master) ? "WSARecvMsg" : "WSARecvFrom",
                         inet_ntoa(buffer->internal.io.peer.sin_addr),
                         ntohs(buffer->internal.io.peer.sin_port)));

            // Report the reset for the peer recorded in the buffer, then
            // retry the post with the same buffer.
            TftpdProcessError(buffer);
            continue;

        }

        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoPostReceiveBuffer(buffer = %p): "
                     "WSARecvMsg/From() failed, error 0x%08X.\n",
                     buffer, status));
        break;

    } // for (;;)

    // Failure: give back the slot claimed above and drop the buffer.
    posted = InterlockedDecrement((PLONG)&socket->postedBuffers);
    if (buffer != NULL)
        TftpdIoFreeBuffer(buffer);

    return (posted);

} // TftpdIoPostReceiveBuffer()
//
// Builds a TFTP ERROR packet in 'buffer' and sends it, non-blocking, to the
// peer recorded in buffer->internal.io.peer using the buffer's own socket.
// Delivery is best-effort: a failed send is only logged.
//
// buffer - buffer whose message area is overwritten with the error packet.
// error  - TFTP error code to report.
// reason - NUL-terminated error text; NULL is treated as an empty string.
//          Text longer than the buffer's data size is truncated.
//
void TftpdIoSendErrorPacket(PTFTPD_BUFFER buffer, TFTPD_ERROR_CODE error, char *reason) {

    DWORD bytes = 0;
    WSABUF buf;
    size_t capacity = buffer->internal.datasize;
    size_t length = 0;

    if (reason == NULL)
        reason = "";

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoSendErrorPacket(buffer = %p): %s\n",
                 buffer, reason));

    // Without room for even the terminating NUL there is nothing valid to
    // build; indexing error[capacity - 1] would write before the text field.
    if (capacity == 0) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoSendErrorPacket(buffer = %p): no room for error text.\n",
                     buffer));
        return;
    }

    // Build the error message.
    buffer->message.opcode = htons(TFTPD_ERROR);
    buffer->message.error.code = htons(error);

    // Copy only the text itself (bounded by the data size, leaving room for
    // the terminator) instead of zero-filling the whole data area.
    while ((length < (capacity - 1)) && (reason[length] != '\0'))
        length++;
    CopyMemory(buffer->message.error.error, reason, length);
    buffer->message.error.error[length] = '\0';

    // Send it non-blocking only.  If it fails, who cares, let the client deal with it.
    buf.buf = (char *)&buffer->message.opcode;
    buf.len = (ULONG)(FIELD_OFFSET(TFTPD_BUFFER, message.error.error) -
                      FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
                      (length + 1));

    if (WSASendTo(buffer->internal.socket->s,
                  &buf,
                  1,
                  &bytes,
                  0,
                  (PSOCKADDR)&buffer->internal.io.peer,
                  sizeof(SOCKADDR_IN),
                  NULL,
                  NULL) == SOCKET_ERROR) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoSendErrorPacket(buffer = %p): WSASendTo() failed, error = %d.\n",
                     buffer, WSAGetLastError()));
    }

} // TftpdIoSendErrorPacket()
//
// Sends the packet held in 'buffer' (buffer->internal.io.bytes bytes,
// starting at the opcode) to the peer of the buffer's context.
//
// buffer - the packet to send. buffer->internal.context must be non-NULL
//          and already referenced by the caller.
//
// Returns 'buffer' when the caller keeps ownership: the non-blocking send
// succeeded, it failed with something other than WSAEWOULDBLOCK, or the
// overlapped retry failed outright. Returns NULL when an overlapped send
// has taken over the buffer, in which case the caller must not recycle it.
//
PTFTPD_BUFFER TftpdIoSendPacket(PTFTPD_BUFFER buffer) {

    PTFTPD_CONTEXT context = buffer->internal.context;
    DWORD sent = 0;
    WSABUF packet;

    // NOTE: 'context' must be referenced before this call!
    ASSERT(context != NULL);
    ASSERT(context->reference >= 1);
    ASSERT(buffer->internal.socket != NULL);

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoSendPacket(buffer = %p, context = %p): bytes = %d.\n",
                 buffer, context, buffer->internal.io.bytes));

    packet.buf = (char *)&buffer->message.opcode;
    packet.len = buffer->internal.io.bytes;

    // First try sending it non-blocking.
    if (WSASendTo(context->socket->s,
                  &packet,
                  1,
                  &sent,
                  0,
                  (PSOCKADDR)&context->peer,
                  sizeof(SOCKADDR_IN),
                  NULL,
                  NULL) != SOCKET_ERROR)
        return (buffer);    // Non-blocking send succeeded.

    // Any failure other than "would block" is not retried.
    if (WSAGetLastError() != WSAEWOULDBLOCK)
        return (buffer);

    // Keep an overlapped-operation reference to the context.
    TftpdContextAddReference(context);

    // Send it overlapped.  When completion occurs, we'll know it was a send
    // when buffer->internal.context is non-NULL.
    if ((WSASendTo(context->socket->s,
                   &packet,
                   1,
                   &sent,
                   0,
                   (PSOCKADDR)&context->peer,
                   sizeof(SOCKADDR_IN),
                   &buffer->internal.io.overlapped,
                   NULL) == SOCKET_ERROR) &&
        (WSAGetLastError() != WSA_IO_PENDING)) {

        TFTPD_DEBUG((TFTPD_TRACE_IO,
                     "TftpdIoSendPacket(buffer = %p, context = %p): "
                     "overlapped send failed.\n",
                     buffer, context));

        // Release the overlapped-operation reference to the context.
        TftpdContextRelease(context);
        return (buffer);

    }

    // Tell the caller not to recycle a buffer.
    return (NULL);

} // TftpdIoSendPacket()
//
// Records a socket context that could not be destroyed on the reaper's
// leaked-socket list so that a later TftpdIoAllocateSocketContext() call
// can retry destroying it.
//
// socket - the socket context to remember. Nothing is done if it is
//          already on the list, or if the service is shutting down.
//
void TftpdIoLeakSocketContext(PTFTPD_SOCKET socket) {

    PLIST_ENTRY cursor;
    BOOL alreadyListed = FALSE;

    EnterCriticalSection(&globals.reaper.socketCS);

    // If shutdown is occuring, we're in trouble anyways.
    // Just let it go.
    if (!globals.service.shutdown) {

        TFTPD_DEBUG((TFTPD_TRACE_CONTEXT,
                     "TftpdIoLeakSocketContext(context = %p).\n",
                     socket));

        // Is the socket already in the list?
        cursor = globals.reaper.leakedSockets.Flink;
        while (cursor != &globals.reaper.leakedSockets) {
            if (CONTAINING_RECORD(cursor, TFTPD_SOCKET, linkage) == socket) {
                alreadyListed = TRUE;
                break;
            }
            cursor = cursor->Flink;
        }

        if (!alreadyListed) {
            InsertHeadList(&globals.reaper.leakedSockets, &socket->linkage);
            globals.reaper.numLeakedSockets++;
        }

    } // if (!globals.service.shutdown)

    LeaveCriticalSection(&globals.reaper.socketCS);

} // TftpdIoLeakSocketContext()
PTFTPD_SOCKET TftpdIoAllocateSocketContext() {
PTFTPD_SOCKET socket = NULL;
if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets) {
BOOL failAllocate = FALSE;
// Try to recover leaked contexts.
EnterCriticalSection(&globals.reaper.socketCS); {
PLIST_ENTRY entry; while ((entry = RemoveHeadList(&globals.reaper.leakedSockets)) != &globals.reaper.leakedSockets) {
PTFTPD_SOCKET s = CONTAINING_RECORD(entry, TFTPD_SOCKET, linkage);
globals.reaper.numLeakedSockets--; if (!TftpdIoDestroySocketContext(s)) { TftpdIoLeakSocketContext(s); failAllocate = TRUE; break; }
}
} LeaveCriticalSection(&globals.reaper.socketCS);
if (failAllocate) goto exit_allocate_context;
} // if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets)
socket = (PTFTPD_SOCKET)HeapAlloc(globals.hServiceHeap, HEAP_ZERO_MEMORY, sizeof(TFTPD_SOCKET));
exit_allocate_context :
return (socket);
} // TftpdIoAllocateSocketContext()
//
// Creates and configures the UDP socket for a socket context.
//
// socket  - the socket context to initialize. It must NOT be zeroed by
//           this routine; the caller has already stored values in it.
// addr    - local address to bind to.
// context - NULL for a shared socket (master, def, mtu, max), which gets
//           an FD_READ select event and a registered wait callback;
//           non-NULL for a private socket owned by that one context,
//           which gets a single receive buffer posted immediately.
//
// There is no return value. On failure socket->s is left as INVALID_SOCKET
// (and any select event created here is closed), which is how callers
// detect the failure.
//
void TftpdIoInitializeSocketContext(PTFTPD_SOCKET socket, PSOCKADDR_IN addr, PTFTPD_CONTEXT context) {

    BOOL exclusive = TRUE;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoInitializeSocketContext(socket = %s): TID = <%s:%d>.\n",
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" : "private")))),
                 inet_ntoa(addr->sin_addr), ntohs(addr->sin_port)));

    // NOTE: Do NOT zero-out 'socket', it has been initialized with some
    //       values we need to work with.

    // Create the socket.
    socket->s = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
    if (socket->s == INVALID_SOCKET) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoInitializeSocketContext: "
                     "WSASocket() failed, error 0x%08X.\n",
                     GetLastError()));
        SetLastError(WSAGetLastError());
        goto fail_create_context;
    }

    // Ensure that we will exclusively own our local port so nobody can hijack us.
    if (setsockopt(socket->s,
                   SOL_SOCKET,
                   SO_EXCLUSIVEADDRUSE,
                   (const char *)&exclusive,
                   sizeof(exclusive)) == SOCKET_ERROR) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoInitializeSocketContext: "
                     "setsockopt(SO_EXCLUSIVEADDRUSE) failed, error 0x%08X.\n",
                     GetLastError()));
        SetLastError(WSAGetLastError());
        goto fail_create_context;
    }

    // Bind the socket on the correct port.
    if (bind(socket->s, (PSOCKADDR)addr, sizeof(SOCKADDR)) == SOCKET_ERROR) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoInitializeSocketContext: "
                     "bind() failed, error 0x%08X.\n",
                     GetLastError()));
        SetLastError(WSAGetLastError());
        goto fail_create_context;
    }

    // Register for completion callbacks on the socket.
    if (!BindIoCompletionCallback((HANDLE)socket->s, TftpdIoCompletionCallback, 0)) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoInitializeSocketContext: "
                     "BindIoCompletionCallback() failed, error 0x%08X.\n",
                     GetLastError()));
        goto fail_create_context;
    }

    // Only the master socket receives the initial requests, and for those
    // we need the packet information to decide which socket to assign to
    // the context, so only the master socket is set up for WSARecvMsg().
    if (socket == &globals.io.master) {

        GUID g = WSAID_WSARECVMSG;
        int opt = TRUE;
        DWORD len;

        // Obtain the WSARecvMsg() extension API pointer.
        if (WSAIoctl(socket->s,
                     SIO_GET_EXTENSION_FUNCTION_POINTER,
                     &g,
                     sizeof(g),
                     &globals.fp.WSARecvMsg,
                     sizeof(globals.fp.WSARecvMsg),
                     &len,
                     NULL,
                     NULL) == SOCKET_ERROR) {
            TFTPD_DEBUG((TFTPD_DBG_IO,
                         "TftpdIoInitializeSocketContext: "
                         "WSAIoctl() failed, error 0x%08X.\n",
                         WSAGetLastError()));
            goto fail_create_context;
        }

        // Indicate that we want WSARecvMsg() to fill-in packet information.
        if (setsockopt(socket->s,
                       IPPROTO_IP,
                       IP_PKTINFO,
                       (char *)&opt,
                       sizeof(opt)) == SOCKET_ERROR) {
            TFTPD_DEBUG((TFTPD_DBG_IO,
                         "TftpdIoInitializeSocketContext: "
                         "setsockopt() failed, error 0x%08X.\n",
                         WSAGetLastError()));
            goto fail_create_context;
        }

    } // if (socket == &globals.io.master)

    // Record the port used for this context.
    CopyMemory(&socket->addr, addr, sizeof(socket->addr));

    if (context != NULL) {

        // This is a private socket (ie, not master, def, mtu, or max).
        // It will be destroyed when its one and only owning context is destroyed.
        socket->context = context;

        // Private sockets do not use read notification.
        socket->hSelect = NULL;
        socket->wSelectWait = NULL;
        socket->lowWaterMark = 1;

        TftpdIoPostReceiveBuffer(socket, NULL);
        return;

    }

    // Shared socket: select it for read notification so we know when to
    // post receive buffers.
    if ((socket->hSelect = CreateEvent(NULL, FALSE, FALSE, NULL)) == NULL) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoInitializeSocketContext: "
                     "CreateEvent() failed, error 0x%08X.\n",
                     GetLastError()));
        goto fail_create_context;
    }

    if (WSAEventSelect(socket->s, socket->hSelect, FD_READ) == SOCKET_ERROR) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoInitializeSocketContext: "
                     "WSAEventSelect() failed, error 0x%08X.\n",
                     GetLastError()));
        SetLastError(WSAGetLastError());
        goto fail_create_context;
    }

    // Register for FD_READ notification on the socket.
    if (!RegisterWaitForSingleObject(&socket->wSelectWait,
                                     socket->hSelect,
                                     (WAITORTIMERCALLBACK)TftpdIoReadNotification,
                                     socket,
                                     INFINITE,
                                     WT_EXECUTEINWAITTHREAD)) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoInitializeSocketContext: "
                     "RegisterWaitForSingleObject() failed, error 0x%08X.\n",
                     GetLastError()));
        goto fail_create_context;
    }

    // Prepost the low water-mark number of receive buffers.
    // If the FD_READ event signals on the master socket before we're done, we'll
    // exceed the low water-mark here but that's harmless as the excess buffers
    // will be freed upon completion.
    if (!socket->lowWaterMark)
        socket->lowWaterMark = 1;
    if (!socket->highWaterMark)
        socket->highWaterMark = 1;

    // Kick the read-notification callback by signalling the select event.
    SetEvent(socket->hSelect);
    return;

fail_create_context :

    if (socket->s != INVALID_SOCKET) {
        closesocket(socket->s);
        socket->s = INVALID_SOCKET;
    }

    if (socket->hSelect != NULL) {
        CloseHandle(socket->hSelect);
        socket->hSelect = NULL;
    }

} // TftpdIoInitializeSocketContext()
BOOL TftpdIoAssignSocket(PTFTPD_CONTEXT context, PTFTPD_BUFFER buffer) {
SOCKADDR_IN addr; DWORD len = 0;
TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "TftpdIoAssignSocket(context = %p, buffer = %p).\n", context, buffer));
if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) {
PWSACMSGHDR header; IN_PKTINFO *packetInfo;
// Determine if routing problems force us to use a private socket so we can corrrectly
// send datagrams to the requesting client. First, get the best interface address for
// responding to the requesting client.
ZeroMemory(&addr, sizeof(addr));
// Make the ioctl call.
WSASetLastError(NO_ERROR); if ((WSAIoctl(globals.io.master.s, SIO_ROUTING_INTERFACE_QUERY, &buffer->internal.io.peer, buffer->internal.io.peerLen, &addr, sizeof(SOCKADDR_IN), &len, NULL, NULL) == SOCKET_ERROR) || (len != sizeof(SOCKADDR_IN))) { TFTPD_DEBUG((TFTPD_DBG_PROCESS, "TftpdIoAssignSocket(): " "WSAIoctl(SIO_ROUTING_INTERFACE_QUERY) failed, error = %d.\n", WSAGetLastError())); TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "Failed to initialize network endpoint."); return (FALSE); }
// Loop through the control (ancillary) data looking for our packet info.
header = WSA_CMSG_FIRSTHDR(&buffer->internal.io.msg); packetInfo = NULL; while (header) {
if ((header->cmsg_level == IPPROTO_IP) && (header->cmsg_type == IP_PKTINFO)) { packetInfo = (IN_PKTINFO *)WSA_CMSG_DATA(header); break; }
header = WSA_CMSG_NXTHDR(&buffer->internal.io.msg, header);
} // while (header)
// Check to see if the best interface we obtained is not the one the client sent the message to.
if ((packetInfo != NULL) && (addr.sin_addr.s_addr != packetInfo->ipi_addr.s_addr)) {
TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "TftpdIoAssignSocket(context = %p, buffer = %p):\n" "\tRemote client TID = <%s:%d>\n", context, buffer, inet_ntoa(buffer->internal.io.peer.sin_addr), ntohs(buffer->internal.io.peer.sin_port) )); TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "\tRequest issued to local IP = <%s>\n", inet_ntoa(packetInfo->ipi_addr) )); TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "\tDefault route is over IP = <%s>\n", inet_ntoa(addr.sin_addr) ));
// We need to create a private socket for this client.
context->socket = TftpdIoAllocateSocketContext(); if (context->socket == NULL) { TFTPD_DEBUG((TFTPD_DBG_PROCESS, "TftpdIoAssignSocket(): " "TftpdIoAllocateSocketContext() failed.\n")); TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "Out of memory"); return (FALSE); } context->socket->s = INVALID_SOCKET; context->socket->buffersize = (TFTPD_BUFFER_SIZE) (FIELD_OFFSET(TFTPD_BUFFER, message.data.data) + context->blksize); context->socket->datasize = (TFTPD_DATA_SIZE)context->blksize;
if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) { ZeroMemory(&addr, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_addr.s_addr = packetInfo->ipi_addr.s_addr; }
TftpdIoInitializeSocketContext(context->socket, &addr, context); if (context->socket->s == INVALID_SOCKET) { TFTPD_DEBUG((TFTPD_DBG_PROCESS, "TftpdIoAssignSocket(): " "TftpdIoInitializeSocketContext() failed.\n")); TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "Failed to initialize network endpoint."); HeapFree(globals.hServiceHeap, 0, context->socket); context->socket = NULL; return (FALSE); }
#if defined(DBG)
InterlockedIncrement((PLONG)&globals.performance.privateSockets); #endif // defined(DBG)
return (TRUE);
} // if ((packetInfo != NULL) && ...)
} else { TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "TftpdIoAssignSocket(context = %p, buffer = %p):\n" "\tRemote client TID = <%s:%d> issued broadcast request.\n", context, buffer, inet_ntoa(buffer->internal.io.peer.sin_addr), ntohs(buffer->internal.io.peer.sin_port) ));
} // if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST))
ZeroMemory(&addr, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_addr.s_addr = INADDR_ANY; addr.sin_port = 0;
// Figure out which socket to use for this request (based on blksize).
if (context->blksize <= TFTPD_DEF_DATA) {
if (globals.io.def.s == INVALID_SOCKET) {
EnterCriticalSection(&globals.io.cs); {
if (globals.service.shutdown) { TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping."); LeaveCriticalSection(&globals.io.cs); return (FALSE); }
TftpdIoInitializeSocketContext(&globals.io.def, &addr, NULL);
if (globals.io.def.s != INVALID_SOCKET) { context->socket = &globals.io.def; } else { context->socket = &globals.io.master; if (context->options) { TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "TftpdIoAssignSocket(): Removing requested blksize = %d " "option since we failed to create the MTU-size socket.\n", context->blksize)); context->options &= ~TFTPD_OPTION_BLKSIZE; } }
} LeaveCriticalSection(&globals.io.cs);
} else {
context->socket = &globals.io.def;
} // if (globals.io.def.s == INVALID_SOCKET)
} else {
if (context->blksize <= TFTPD_MTU_DATA) {
if (globals.io.mtu.s == INVALID_SOCKET) {
EnterCriticalSection(&globals.io.cs); {
if (globals.service.shutdown) { TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping."); LeaveCriticalSection(&globals.io.cs); return (FALSE); }
TftpdIoInitializeSocketContext(&globals.io.mtu, &addr, NULL);
if (globals.io.mtu.s != INVALID_SOCKET) { context->socket = &globals.io.mtu; } else { context->socket = &globals.io.master; if (context->options) { TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "TftpdIoAssignSocket(): Removing requested blksize = %d " "option since we failed to create the MTU-size socket.\n", context->blksize)); context->options &= ~TFTPD_OPTION_BLKSIZE; } }
} LeaveCriticalSection(&globals.io.cs);
} else {
context->socket = &globals.io.mtu;
} // if (globals.io.mtu.s == INVALID_SOCKET)
} else if (context->blksize <= TFTPD_MAX_DATA) {
if (globals.io.max.s == INVALID_SOCKET) {
EnterCriticalSection(&globals.io.cs); {
if (globals.service.shutdown) { TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping."); LeaveCriticalSection(&globals.io.cs); return (FALSE); }
TftpdIoInitializeSocketContext(&globals.io.max, &addr, NULL);
if (globals.io.max.s != INVALID_SOCKET) { context->socket = &globals.io.max; } else { context->socket = &globals.io.master; if (context->options) { TFTPD_DEBUG((TFTPD_TRACE_PROCESS, "TftpdIoAssignSocket(): Removing requested blksize = %d " "option since we failed to create the MAX-size socket.\n", context->blksize)); context->options &= ~TFTPD_OPTION_BLKSIZE; } }
} LeaveCriticalSection(&globals.io.cs);
} else {
context->socket = &globals.io.max;
} // if (globals.io.max.s == INVALID_SOCKET)
}
} // (context->blksize <= TFTPD_DEF_DATA)
return (TRUE);
} // TftpdIoAssignSocket()
//
// Tears down a socket context: stops read notification (shared sockets
// only) and closes the socket handle.
//
// socket - the socket context to destroy. A context whose handle is
//          already INVALID_SOCKET is treated as destroyed.
//
// Returns TRUE on success. Returns FALSE if the wait could not be
// unregistered or the socket could not be closed; in that case the socket
// is handed to TftpdIoLeakSocketContext() so destruction can be retried.
//
BOOL TftpdIoDestroySocketContext(PTFTPD_SOCKET socket) {

    SOCKET handle;

    if (socket->s == INVALID_SOCKET)
        return (TRUE);

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoDestroySocketContext(socket = %s).\n",
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" : "private"))))
                 ));

    // Disable further buffer posting.
    socket->lowWaterMark = 0;

    // Sockets without an owning context have a select event and a
    // registered wait that must be torn down first.
    if (socket->context == NULL) {

        if (!UnregisterWait(socket->wSelectWait)) {
            DWORD error = GetLastError();
            // ERROR_IO_PENDING is tolerated; anything else is a failure.
            if (error != ERROR_IO_PENDING) {
                TFTPD_DEBUG((TFTPD_DBG_IO,
                             "TftpdIoDestroySocketContext: "
                             "UnregisterWait() failed, error 0x%08X.\n",
                             error));
                TftpdIoLeakSocketContext(socket);
                return (FALSE);
            }
        }
        socket->wSelectWait = NULL;

        CloseHandle(socket->hSelect);
        socket->hSelect = NULL;

    } // if (socket->context == NULL)

    // Kill the socket.  Per the original author's note, this disables the
    // event select and cancels all pending overlapped operations on it.
    // A temporary buffer reference is held across the close so that the
    // cancelled buffers completing in TftpdIoCompletionCallback() cannot
    // free the socket structure underneath us; once the socket is closed we
    // drop that reference and free the structure ourselves if the count
    // then reads -1.
    InterlockedIncrement((PLONG)&socket->numBuffers);

    handle = socket->s;
    socket->s = INVALID_SOCKET;

    if (closesocket(handle) == SOCKET_ERROR) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoDestroySocketContext: "
                     "closesocket() failed, error 0x%08X.\n",
                     GetLastError()));
        // Put everything back the way it was and remember the socket.
        socket->s = handle;
        InterlockedDecrement((PLONG)&socket->numBuffers);
        TftpdIoLeakSocketContext(socket);
        return (FALSE);
    }

    if (InterlockedDecrement((PLONG)&socket->numBuffers) == -1)
        HeapFree(globals.hServiceHeap, 0, socket);

    return (TRUE);

} // TftpdIoDestroySocketContext()
|