Leaked source code of Windows Server 2003
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

1042 lines
37 KiB

/*++
Copyright (c) 2001 Microsoft Corporation
Module Name:
io.c
Abstract:
This module contains functions to manage all socket I/O
between the server and clients, including socket management
and overlapped completion indication. It also contains
buffer management.
Author:
Jeffrey C. Venable, Sr. (jeffv) 01-Jun-2001
Revision History:
--*/
#include "precomp.h"
void
TftpdIoFreeBuffer(PTFTPD_BUFFER buffer) {
    // Releases one I/O buffer and drops the two buffer counts it was holding:
    // the per-socket count and the service-wide count. The owning socket is
    // captured first because the buffer header is gone after HeapFree().
    PTFTPD_SOCKET owner = buffer->internal.socket;
    BOOL socketDrained;
    BOOL serviceDrained;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoFreeBuffer(buffer = %p).\n",
                 buffer));

    HeapFree(globals.hServiceHeap, 0, buffer);

    // A per-socket count of -1 means the socket has no buffers left and its
    // owner has already let go; only private sockets (those with an owning
    // context) are heap-allocated, so only they are freed here.
    socketDrained = (InterlockedDecrement((PLONG)&owner->numBuffers) == -1);
    if (socketDrained && (owner->context != NULL))
        HeapFree(globals.hServiceHeap, 0, owner);

    // Likewise, the service-wide count reaching -1 lets shutdown proceed.
    serviceDrained = (InterlockedDecrement(&globals.io.numBuffers) == -1);
    if (serviceDrained)
        TftpdServiceAttemptCleanup();
} // TftpdIoFreeBuffer()
PTFTPD_BUFFER
TftpdIoAllocateBuffer(PTFTPD_SOCKET socket) {
    // Allocates a buffer sized for 'socket' (socket->buffersize), clears its
    // internal bookkeeping header, ties it to the socket and bumps both the
    // service-wide and per-socket buffer counts.
    // Returns NULL on allocation failure or if the service is shutting down.
    PTFTPD_BUFFER fresh;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoAllocateBuffer(socket = %s).\n",
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" :
                     "private")))) ));

    fresh = (PTFTPD_BUFFER)HeapAlloc(globals.hServiceHeap, 0,
                                     socket->buffersize);
    if (fresh == NULL) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoAllocateBuffer(socket = %s): "
                     "HeapAlloc() failed, error 0x%08X.\n",
                     ((socket == &globals.io.master) ? "master" :
                      ((socket == &globals.io.def) ? "def" :
                       ((socket == &globals.io.mtu) ? "mtu" :
                        ((socket == &globals.io.max) ? "max" :
                         "private")))), GetLastError()));
        return (NULL);
    }

    // Only the leading 'internal' header is cleared (assumes it is the first
    // member of TFTPD_BUFFER); the message area is filled in by I/O.
    ZeroMemory(fresh, sizeof(fresh->internal));
    InterlockedIncrement(&globals.io.numBuffers);
    InterlockedIncrement((PLONG)&socket->numBuffers);
    fresh->internal.socket = socket;
    fresh->internal.datasize = socket->datasize;

    if (!globals.service.shutdown)
        return (fresh);

    // Shutdown raced with us: undo the accounting through the normal free path.
    TftpdIoFreeBuffer(fresh);
    return (NULL);
} // TftpdIoAllocateBuffer()
PTFTPD_BUFFER
TftpdIoSwapBuffer(PTFTPD_BUFFER buffer, PTFTPD_SOCKET socket) {
    // Moves a just-received RRQ/WRQ request out of 'buffer' (which belongs to
    // the socket it arrived on) into a freshly allocated buffer owned by
    // 'socket', then reposts the original buffer as a receive on its own socket.
    //
    // Returns the new buffer, or NULL if allocation failed. In either case the
    // original buffer has been handed back to TftpdIoPostReceiveBuffer() and
    // must not be used by the caller afterwards.
    PTFTPD_BUFFER tmp;

    ASSERT((buffer->message.opcode == TFTPD_RRQ) ||
           (buffer->message.opcode == TFTPD_WRQ));

    // Allocate a buffer for the new socket.
    tmp = TftpdIoAllocateBuffer(socket);

    // Copy information we need to retain.
    if (tmp != NULL) {
        tmp->internal.context = buffer->internal.context;
        tmp->internal.io.peerLen = buffer->internal.io.peerLen;
        CopyMemory(&tmp->internal.io.peer,
                   &buffer->internal.io.peer,
                   buffer->internal.io.peerLen);
        CopyMemory(&tmp->internal.io.msg,
                   &buffer->internal.io.msg,
                   sizeof(tmp->internal.io.msg));
        CopyMemory(&tmp->internal.io.control,
                   &buffer->internal.io.control,
                   sizeof(tmp->internal.io.control));

        // The copied WSAMSG still aims at the ORIGINAL buffer's peer address
        // and control (ancillary) data, and at a WSABUF that lived on
        // TftpdIoPostReceiveBuffer()'s stack. The original buffer is reposted
        // below, where its header is zeroed and then overwritten by the next
        // receive, so re-aim those pointers at tmp's own copies. Anything that
        // later walks the control data with WSA_CMSG_FIRSTHDR(&tmp->...msg)
        // then sees the bytes that belong to this request.
        if (tmp->internal.io.msg.name == (LPSOCKADDR)&buffer->internal.io.peer)
            tmp->internal.io.msg.name = (LPSOCKADDR)&tmp->internal.io.peer;
        if (tmp->internal.io.msg.Control.buf == (char *)&buffer->internal.io.control)
            tmp->internal.io.msg.Control.buf = (char *)&tmp->internal.io.control;
        tmp->internal.io.msg.lpBuffers = NULL; // Stale stack pointer; never valid here.
    } // if (tmp != NULL)

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoSwapBuffer(buffer = %p): "
                 "new buffer = %p.\n",
                 buffer, tmp));

    // Return the original buffer.
    TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);

    return (tmp);
} // TftpdIoSwapBuffer()
/*
 * TftpdIoCompletionCallback
 *
 * Completion routine for every socket bound with BindIoCompletionCallback()
 * in TftpdIoInitializeSocketContext(). It handles both overlapped operations
 * issued by this module:
 *   - receives (buffer->internal.context == NULL), posted by
 *     TftpdIoPostReceiveBuffer(); and
 *   - sends (buffer->internal.context != NULL), issued by TftpdIoSendPacket(),
 *     which took an extra context reference before posting the send.
 *
 * dwErrorCode - completion status, compared below against STATUS_* values.
 * dwBytes     - number of bytes transferred.
 * overlapped  - the internal.io.overlapped member embedded in a TFTPD_BUFFER.
 *
 * On exit the buffer has been freed, reposted as a receive on its socket, or
 * (when TftpdProcessReceivedBuffer() returns NULL) left alone.
 */
void
TftpdIoCompletionCallback(DWORD dwErrorCode,
DWORD dwBytes,
LPOVERLAPPED overlapped) {
PTFTPD_BUFFER buffer = CONTAINING_RECORD(overlapped, TFTPD_BUFFER,
internal.io.overlapped);
PTFTPD_CONTEXT context = buffer->internal.context;
PTFTPD_SOCKET socket = buffer->internal.socket;
TFTPD_DEBUG((TFTPD_TRACE_IO,
"TftpdIoCompletionCallback(buffer = %p): bytes = %d.\n",
buffer, dwBytes));
// A NULL context marks a receive completion: give back the posted-buffer
// slot that TftpdIoPostReceiveBuffer() claimed for it.
if (context == NULL)
InterlockedDecrement((PLONG)&socket->postedBuffers);
switch (dwErrorCode) {
case STATUS_SUCCESS :
if (context == NULL) {
// Datagrams shorter than TFTPD_MIN_RECEIVED_DATA are ignored; the
// buffer is simply reposted at exit.
if (dwBytes < TFTPD_MIN_RECEIVED_DATA)
goto exit_completion_callback;
buffer->internal.io.bytes = dwBytes;
// Whatever comes back is reposted at exit; a NULL return skips the
// repost (presumably the processor kept or freed the buffer — confirm).
buffer = TftpdProcessReceivedBuffer(buffer);
} // if (context == NULL)
// Send completions (context != NULL) only need the dead-check, the
// reference release and the repost performed at exit.
break;
case STATUS_PORT_UNREACHABLE :
TFTPD_DEBUG((TFTPD_TRACE_IO,
"TftpdIoCompletionCallback(buffer = %p, context = %p): "
"STATUS_PORT_UNREACHABLE.\n",
buffer, context));
// A non-NULL context means this was a send (write) operation; kill the
// context. Clearing 'context' skips the release at exit.
// NOTE(review): the overlapped-send reference taken in TftpdIoSendPacket()
// is therefore not released here — presumably TftpdProcessError()
// accounts for it; confirm.
if (context != NULL) {
TftpdProcessError(buffer);
context = NULL;
}
// The buffer itself is reposted as a receive at exit.
goto exit_completion_callback;
case STATUS_CANCELLED :
// A non-NULL context means this was a send (write) operation; kill the
// context (same reference caveat as STATUS_PORT_UNREACHABLE above).
if (context != NULL) {
TFTPD_DEBUG((TFTPD_TRACE_IO,
"TftpdIoCompletionCallback(buffer = %p, context = %p): "
"STATUS_CANCELLED.\n",
buffer, context));
TftpdProcessError(buffer);
context = NULL;
}
// Cancelled buffers (e.g. after the socket was closed) are never reposted.
TftpdIoFreeBuffer(buffer);
buffer = NULL;
goto exit_completion_callback;
default :
// Any other failure: log it and recycle the buffer at exit.
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoCompletionCallback(buffer = %p): "
"dwErrorcode = 0x%08X.\n",
buffer, dwErrorCode));
goto exit_completion_callback;
} // switch (dwErrorCode)
exit_completion_callback :
if (context != NULL) {
// Do we bother reposting the buffer?
if (context->state & TFTPD_STATE_DEAD) {
TftpdIoFreeBuffer(buffer);
buffer = NULL;
}
// Release the overlapped send reference.
TftpdContextRelease(context);
} // if (context != NULL)
// Surviving buffers (including finished send buffers) become receives on
// the socket they belong to.
if (buffer != NULL)
TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);
} // TftpdIoCompletionCallback()
void CALLBACK
TftpdIoReadNotification(PTFTPD_SOCKET socket, BOOLEAN timeout) {
    // Wait callback registered on a shared socket's FD_READ event (see
    // TftpdIoInitializeSocketContext). Tops the socket back up to its low
    // water-mark of posted receive buffers. 'timeout' is unused: the wait is
    // registered with INFINITE.
    DWORD before, after;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoReadNotification(socket = %s).\n",
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" :
                     "private")))) ));

    // If posting fails, the event triggering this callback will stop signalling
    // due to a lack of a successful WSARecvFrom() ... this will likely occur
    // during low-memory/stress conditions. When the system returns to normal,
    // the low water-mark buffers will be reposted, thus receiving data and
    // re-enabling the event which triggers this callback.
    //
    // A failed attempt (allocation failure, WSARecvMsg()/WSARecvFrom() error)
    // hands its slot back, so the count it returns is no higher than where it
    // started. Stop at that point rather than retrying immediately: this
    // callback runs in the wait thread (WT_EXECUTEINWAITTHREAD), and spinning
    // here for as long as memory stays low would burn a CPU and stall every
    // other wait serviced by that thread.
    while (!globals.service.shutdown) {
        before = (DWORD)socket->postedBuffers;
        after = TftpdIoPostReceiveBuffer(socket, NULL);
        if (after >= socket->lowWaterMark)
            break;
        if (after <= before)
            break; // No progress; let the next FD_READ or completion retry.
    }
} // TftpdIoReadNotification()
DWORD
TftpdIoPostReceiveBuffer(PTFTPD_SOCKET socket, PTFTPD_BUFFER buffer) {
    // Posts overlapped receive(s) on 'socket'. If 'buffer' is non-NULL it is
    // recycled (its header is reset) for the first post; otherwise a buffer is
    // allocated. When a receive completes immediately, up to 10 further
    // buffers are posted in the same call.
    //
    // The master socket uses WSARecvMsg() so IP_PKTINFO control data is
    // captured; every other socket uses WSARecvFrom().
    //
    // Returns the socket's posted-buffer count as last observed by this call.
    // Any buffer that could not be posted is freed before returning.
    DWORD postedBuffers = 0, successfulPosts = 0;
    int error;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoPostReceiveBuffer(buffer = %p, socket = %s).\n",
                 buffer,
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" :
                     "private")))) ));

    // Claim a posted-buffer slot up front; exit_post_buffer gives it back when
    // no receive ends up outstanding for it.
    postedBuffers = InterlockedIncrement((PLONG)&socket->postedBuffers);

    //
    // Attempt to post a buffer:
    //
    while (TRUE) {
        WSABUF buf;

        if (globals.service.shutdown ||
            (postedBuffers > globals.parameters.highWaterMark))
            goto exit_post_buffer;

        // Allocate the buffer if we're not reusing one.
        if (buffer == NULL) {
            buffer = TftpdIoAllocateBuffer(socket);
            if (buffer == NULL) {
                TFTPD_DEBUG((TFTPD_DBG_IO,
                             "TftpdIoPostReceiveBuffer(buffer = %p): "
                             "TftpdIoAllocateBuffer() failed.\n",
                             buffer));
                goto exit_post_buffer;
            }
            TFTPD_DEBUG((TFTPD_TRACE_IO,
                         "TftpdIoPostReceiveBuffer(buffer = %p).\n",
                         buffer));
        } else {
            if (socket->s == INVALID_SOCKET)
                goto exit_post_buffer;
            ASSERT(buffer->internal.socket == socket);
            ZeroMemory(buffer, sizeof(buffer->internal));
            buffer->internal.socket = socket;
            buffer->internal.datasize = socket->datasize;
        } // if (buffer == NULL)

        // Receive straight into the message area: header fields plus up to
        // 'datasize' bytes of payload.
        buf.buf = ((char *)buffer + FIELD_OFFSET(TFTPD_BUFFER, message.opcode));
        buf.len = (FIELD_OFFSET(TFTPD_BUFFER, message.data.data) -
                   FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
                   socket->datasize);

        // The byte-count out-parameters are passed as NULL: every receive here
        // is overlapped, the result is reported to TftpdIoCompletionCallback(),
        // and Winsock documents NULL as the correct value when lpOverlapped is
        // supplied.
        error = NO_ERROR;
        if (socket == &globals.io.master) {
            buffer->internal.io.msg.lpBuffers = &buf;
            buffer->internal.io.msg.dwBufferCount = 1;
            buffer->internal.io.msg.name = (LPSOCKADDR)&buffer->internal.io.peer;
            buffer->internal.io.msg.namelen = sizeof(buffer->internal.io.peer);
            buffer->internal.io.peerLen = sizeof(buffer->internal.io.peer);
            buffer->internal.io.msg.Control.buf = (char *)&buffer->internal.io.control;
            buffer->internal.io.msg.Control.len = sizeof(buffer->internal.io.control);
            buffer->internal.io.msg.dwFlags = 0;
            if (globals.fp.WSARecvMsg(socket->s, &buffer->internal.io.msg, NULL,
                                      &buffer->internal.io.overlapped, NULL) == SOCKET_ERROR)
                error = WSAGetLastError();
        } else {
            buffer->internal.io.peerLen = sizeof(buffer->internal.io.peer);
            if (WSARecvFrom(socket->s, &buf, 1, NULL, &buffer->internal.io.flags,
                            (PSOCKADDR)&buffer->internal.io.peer, &buffer->internal.io.peerLen,
                            &buffer->internal.io.overlapped, NULL) == SOCKET_ERROR)
                error = WSAGetLastError();
        } // if (socket == &globals.io.master)

        switch (error) {
        case NO_ERROR :
            // Completed immediately. The completion callback still owns the
            // buffer, so claim a slot for another buffer and keep going
            // (bounded to 10 extra posts per call).
            if (successfulPosts < 10) {
                successfulPosts++;
                postedBuffers = InterlockedIncrement((PLONG)&socket->postedBuffers);
                buffer = NULL;
                continue;
            } else {
                return (postedBuffers);
            }
        case WSA_IO_PENDING :
            return (postedBuffers);
        case WSAECONNRESET :
            // A previous send drew an ICMP port-unreachable; report it for the
            // peer named in the buffer and retry posting the same buffer.
            TFTPD_DEBUG((TFTPD_DBG_IO,
                         "TftpdIoPostReceiveBuffer(buffer = %p): "
                         "%s() failed for TID = <%s:%d>, WSAECONNRESET.\n",
                         buffer,
                         (socket == &globals.io.master) ? "WSARecvMsg" : "WSARecvFrom",
                         inet_ntoa(buffer->internal.io.peer.sin_addr),
                         ntohs(buffer->internal.io.peer.sin_port)));
            TftpdProcessError(buffer);
            continue;
        default :
            TFTPD_DEBUG((TFTPD_DBG_IO,
                         "TftpdIoPostReceiveBuffer(buffer = %p): "
                         "WSARecvMsg/From() failed, error 0x%08X.\n",
                         buffer, error));
            goto exit_post_buffer;
        } // switch (error)
    } // while (true)

exit_post_buffer :
    postedBuffers = InterlockedDecrement((PLONG)&socket->postedBuffers);
    if (buffer != NULL)
        TftpdIoFreeBuffer(buffer);
    return (postedBuffers);
} // TftpdIoPostReceiveBuffer()
void
TftpdIoSendErrorPacket(PTFTPD_BUFFER buffer, TFTPD_ERROR_CODE error, char *reason) {
    // Builds a TFTP ERROR packet (opcode, error code, NUL-terminated message)
    // in 'buffer' and sends it once, non-blocking, to the buffer's peer.
    // Failures are only traced: the client is expected to time out or retry.
    //
    // buffer - supplies the socket and peer address; its message area is
    //          overwritten.
    // error  - TFTP error code, sent in network byte order.
    // reason - message text, truncated to fit the payload area; NULL is
    //          treated as an empty message.
    DWORD bytes = 0;
    WSABUF buf;
    const char *text = (reason != NULL) ? reason : "";
    size_t length;

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoSendErrorPacket(buffer = %p): %s\n",
                 buffer, text));

    // The message text occupies the payload area, which is assumed to hold
    // internal.datasize bytes (the same bound the copy has always used). With
    // no room even for the terminator there is nothing sensible to send, and
    // terminating at [datasize - 1] would write before the array.
    if (buffer->internal.datasize == 0)
        return;

    // Build the error message.
    buffer->message.opcode = htons(TFTPD_ERROR);
    buffer->message.error.code = htons(error);

    // Bounded copy of at most datasize - 1 characters that always terminates.
    // Unlike strncpy(), it does not zero-fill the rest of the payload area,
    // which may be large for big-blksize sockets and is never transmitted.
    length = 0;
    while (((length + 1) < buffer->internal.datasize) && (text[length] != '\0'))
        length++;
    CopyMemory(buffer->message.error.error, text, length);
    buffer->message.error.error[length] = '\0';

    // Send it non-blocking only. If it fails, who cares, let the client deal with it.
    buf.buf = (char *)&buffer->message.opcode;
    buf.len = (ULONG)(FIELD_OFFSET(TFTPD_BUFFER, message.error.error) -
                      FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
                      (length + 1));
    if (WSASendTo(buffer->internal.socket->s, &buf, 1, &bytes, 0,
                  (PSOCKADDR)&buffer->internal.io.peer, sizeof(SOCKADDR_IN),
                  NULL, NULL) == SOCKET_ERROR) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoSendErrorPacket(buffer = %p): WSASendTo() failed, error = %d.\n",
                     buffer, WSAGetLastError()));
    }
} // TftpdIoSendErrorPacket()
PTFTPD_BUFFER
TftpdIoSendPacket(PTFTPD_BUFFER buffer) {
    // Sends the first internal.io.bytes bytes of buffer->message to the
    // context's peer on the context's socket.
    //
    // A non-blocking send is tried first. Only if it would block is the send
    // re-issued as overlapped I/O; the completion callback recognizes it as a
    // send because buffer->internal.context is non-NULL.
    //
    // Returns 'buffer' when the caller still owns it (the send finished
    // synchronously or failed outright), or NULL when ownership passed to the
    // pending overlapped send.
    PTFTPD_CONTEXT context = buffer->internal.context;
    DWORD bytes = 0;
    WSABUF buf;

    // NOTE: 'context' must be referenced before this call!
    ASSERT(context != NULL);
    ASSERT(context->reference >= 1);
    ASSERT(buffer->internal.socket != NULL);

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoSendPacket(buffer = %p, context = %p): bytes = %d.\n",
                 buffer, context, buffer->internal.io.bytes));

    buf.buf = (char *)&buffer->message.opcode;
    buf.len = buffer->internal.io.bytes;

    // Fast path: non-blocking send. On success the caller recycles the buffer.
    if (WSASendTo(context->socket->s, &buf, 1, &bytes, 0,
                  (PSOCKADDR)&context->peer, sizeof(SOCKADDR_IN),
                  NULL, NULL) != SOCKET_ERROR)
        return (buffer);

    // Errors other than "would block" are not retried; the caller keeps the buffer.
    if (WSAGetLastError() != WSAEWOULDBLOCK)
        return (buffer);

    // Keep an overlapped-operation reference to the context; the completion
    // callback releases it.
    TftpdContextAddReference(context);

    // Overlapped send. lpNumberOfBytesSent is NULL because lpOverlapped is
    // supplied (as Winsock recommends); the result arrives via completion.
    if ((WSASendTo(context->socket->s, &buf, 1, NULL, 0,
                   (PSOCKADDR)&context->peer, sizeof(SOCKADDR_IN),
                   &buffer->internal.io.overlapped, NULL) == SOCKET_ERROR) &&
        (WSAGetLastError() != WSA_IO_PENDING)) {
        TFTPD_DEBUG((TFTPD_TRACE_IO,
                     "TftpdIoSendPacket(buffer = %p, context = %p): "
                     "overlapped send failed.\n",
                     buffer, context));
        // No completion will arrive: drop the reference taken above.
        TftpdContextRelease(context);
        return (buffer);
    }

    // Pending (or immediately completed) overlapped send: the buffer now
    // belongs to the completion callback, so tell the caller not to recycle it.
    return (NULL);
} // TftpdIoSendPacket()
void
TftpdIoLeakSocketContext(PTFTPD_SOCKET socket) {
    // Queues a socket whose teardown failed on globals.reaper.leakedSockets so
    // TftpdIoAllocateSocketContext() can retry destroying it later. A socket
    // already on the list is not queued twice. Nothing is queued during
    // shutdown.
    PLIST_ENTRY cursor;
    BOOL alreadyListed = FALSE;

    EnterCriticalSection(&globals.reaper.socketCS);

    // If shutdown is occurring we're in trouble anyway; just let it go.
    if (!globals.service.shutdown) {
        TFTPD_DEBUG((TFTPD_TRACE_CONTEXT,
                     "TftpdIoLeakSocketContext(context = %p).\n",
                     socket));

        cursor = globals.reaper.leakedSockets.Flink;
        while (!alreadyListed && (cursor != &globals.reaper.leakedSockets)) {
            if (CONTAINING_RECORD(cursor, TFTPD_SOCKET, linkage) == socket)
                alreadyListed = TRUE;
            else
                cursor = cursor->Flink;
        }

        if (!alreadyListed) {
            InsertHeadList(&globals.reaper.leakedSockets, &socket->linkage);
            globals.reaper.numLeakedSockets++;
        }
    }

    LeaveCriticalSection(&globals.reaper.socketCS);
} // TftpdIoLeakSocketContext()
PTFTPD_SOCKET
TftpdIoAllocateSocketContext() {
PTFTPD_SOCKET socket = NULL;
if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets) {
BOOL failAllocate = FALSE;
// Try to recover leaked contexts.
EnterCriticalSection(&globals.reaper.socketCS); {
PLIST_ENTRY entry;
while ((entry = RemoveHeadList(&globals.reaper.leakedSockets)) !=
&globals.reaper.leakedSockets) {
PTFTPD_SOCKET s = CONTAINING_RECORD(entry, TFTPD_SOCKET, linkage);
globals.reaper.numLeakedSockets--;
if (!TftpdIoDestroySocketContext(s)) {
TftpdIoLeakSocketContext(s);
failAllocate = TRUE;
break;
}
}
} LeaveCriticalSection(&globals.reaper.socketCS);
if (failAllocate)
goto exit_allocate_context;
} // if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets)
socket = (PTFTPD_SOCKET)HeapAlloc(globals.hServiceHeap,
HEAP_ZERO_MEMORY,
sizeof(TFTPD_SOCKET));
exit_allocate_context :
return (socket);
} // TftpdIoAllocateSocketContext()
/*
 * TftpdIoInitializeSocketContext
 *
 * Creates an overlapped UDP socket for 'socket', configures it, binds it to
 * 'addr' and registers it for completion callbacks with
 * BindIoCompletionCallback(TftpdIoCompletionCallback).
 *
 * socket  - descriptor prepared by the caller (e.g. buffersize/datasize are
 *           already filled in); it is deliberately NOT zeroed here.
 * addr    - local address/port to bind; copied into socket->addr on success.
 * context - NULL for the shared sockets (master/def/mtu/max): an auto-reset
 *           event is selected for FD_READ and a wait is registered on it, so
 *           TftpdIoReadNotification() keeps receives posted. Non-NULL for a
 *           private per-transfer socket: the socket records that owning
 *           context and one receive buffer is posted immediately.
 *
 * No return value. On failure socket->s is left as INVALID_SOCKET (callers
 * test for this) and any event handle created here is closed. Some failure
 * paths also store the Winsock error as the thread's last-error value.
 */
void
TftpdIoInitializeSocketContext(PTFTPD_SOCKET socket, PSOCKADDR_IN addr, PTFTPD_CONTEXT context) {
BOOL one = TRUE;
TFTPD_DEBUG((TFTPD_TRACE_IO,
"TftpdIoInitializeSocketContext(socket = %s): TID = <%s:%d>.\n",
((socket == &globals.io.master) ? "master" :
((socket == &globals.io.def) ? "def" :
((socket == &globals.io.mtu) ? "mtu" :
((socket == &globals.io.max) ? "max" : "private")))),
inet_ntoa(addr->sin_addr), ntohs(addr->sin_port)));
// NOTE: Do NOT zero-out 'socket', it has been initialized with some
// values we need to work with.
// Create the socket.
socket->s = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
if (socket->s == INVALID_SOCKET) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"WSASocket() failed, error 0x%08X.\n",
GetLastError()));
SetLastError(WSAGetLastError());
goto fail_create_context;
}
// Ensure that we will exclusively own our local port so nobody can hijack us.
if (setsockopt(socket->s,
SOL_SOCKET,
SO_EXCLUSIVEADDRUSE,
(const char *)&one,
sizeof(one)) == SOCKET_ERROR) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"setsockopt(SO_EXCLUSIVEADDRUSE) failed, error 0x%08X.\n",
GetLastError()));
SetLastError(WSAGetLastError());
goto fail_create_context;
}
// Bind the socket on the correct port.
if (bind(socket->s, (PSOCKADDR)addr, sizeof(SOCKADDR)) == SOCKET_ERROR) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"bind() failed, error 0x%08X.\n",
GetLastError()));
SetLastError(WSAGetLastError());
goto fail_create_context;
}
// Register for completion callbacks on the socket.
if (!BindIoCompletionCallback((HANDLE)socket->s, TftpdIoCompletionCallback, 0)) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"BindIoCompletionCallback() failed, error 0x%08X.\n",
GetLastError()));
goto fail_create_context;
}
// Indicate that we want WSARecvMsg() to fill-in packet information.
// We do this only on the master socket, which is where read (RRQ) and
// write (WRQ) requests arrive, and where we need to know which local
// address a request was sent to in order to pick the socket for its context.
if (socket == &globals.io.master) {
// Obtain the WSARecvMsg() extension API pointer.
GUID g = WSAID_WSARECVMSG;
int opt = TRUE;
DWORD len;
// NOTE(review): unlike the earlier failure paths, the two failures
// below do not SetLastError(WSAGetLastError()).
if (WSAIoctl(socket->s, SIO_GET_EXTENSION_FUNCTION_POINTER, &g, sizeof(g),
&globals.fp.WSARecvMsg, sizeof(globals.fp.WSARecvMsg),
&len, NULL, NULL) == SOCKET_ERROR) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"WSAIoctl() failed, error 0x%08X.\n",
WSAGetLastError()));
goto fail_create_context;
}
// Indicate that we want WSARecvMsg() to fill-in packet information.
if (setsockopt(socket->s, IPPROTO_IP, IP_PKTINFO,
(char *)&opt, sizeof(opt)) == SOCKET_ERROR) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"setsockopt() failed, error 0x%08X.\n",
WSAGetLastError()));
goto fail_create_context;
}
} // if (socket == &globals.io.master)
// Record the port used for this context.
CopyMemory(&socket->addr, addr, sizeof(socket->addr));
if (context == NULL) {
// Select the socket for read notifications (FD_READ only below). Sends
// choose between non-blocking and overlapped from the send result
// itself (see TftpdIoSendPacket), not from FD_WRITE.
if ((socket->hSelect = CreateEvent(NULL, FALSE, FALSE, NULL)) == NULL) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"CreateEvent() failed, error 0x%08X.\n",
GetLastError()));
goto fail_create_context;
}
if (WSAEventSelect(socket->s, socket->hSelect, FD_READ) == SOCKET_ERROR) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"WSAEventSelect() failed, error 0x%08X.\n",
GetLastError()));
SetLastError(WSAGetLastError());
goto fail_create_context;
}
// Register for FD_READ notification on the socket.
if (!RegisterWaitForSingleObject(&socket->wSelectWait,
socket->hSelect,
(WAITORTIMERCALLBACK)TftpdIoReadNotification,
socket,
INFINITE,
WT_EXECUTEINWAITTHREAD)) {
TFTPD_DEBUG((TFTPD_DBG_IO,
"TftpdIoInitializeSocketContext: "
"RegisterWaitForSingleObject() failed, error 0x%08X.\n",
GetLastError()));
goto fail_create_context;
}
// Prepost the low water-mark number of receive buffers. This is done
// indirectly: signalling hSelect below runs TftpdIoReadNotification(),
// which posts until the low water-mark is reached.
// If the FD_READ event signals on the master socket before we're done, we'll
// exceed the low water-mark here but that's harmless as the excess buffers
// will be freed upon completion.
// NOTE(review): socket->highWaterMark is defaulted here, but
// TftpdIoPostReceiveBuffer() limits against globals.parameters.highWaterMark.
if (!socket->lowWaterMark)
socket->lowWaterMark = 1;
if (!socket->highWaterMark)
socket->highWaterMark = 1;
SetEvent(socket->hSelect);
} else {
// This is a private socket (ie, not master, def, mtu, or max).
// It will be destroyed when its one and only owning context is destroyed.
socket->context = context;
// Initialize read notification variables to NULL.
socket->hSelect = NULL;
socket->wSelectWait = NULL;
socket->lowWaterMark = 1;
TftpdIoPostReceiveBuffer(socket, NULL);
} // if (context == NULL)
return;
fail_create_context :
// Undo whatever was created; callers detect failure via INVALID_SOCKET.
if (socket->s != INVALID_SOCKET)
closesocket(socket->s), socket->s = INVALID_SOCKET;
if (socket->hSelect != NULL)
CloseHandle(socket->hSelect), socket->hSelect = NULL;
} // TftpdIoInitializeSocketContext()
BOOL
TftpdIoAssignSocket(PTFTPD_CONTEXT context, PTFTPD_BUFFER buffer) {
SOCKADDR_IN addr;
DWORD len = 0;
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"TftpdIoAssignSocket(context = %p, buffer = %p).\n",
context, buffer));
if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) {
PWSACMSGHDR header;
IN_PKTINFO *packetInfo;
// Determine if routing problems force us to use a private socket so we can corrrectly
// send datagrams to the requesting client. First, get the best interface address for
// responding to the requesting client.
ZeroMemory(&addr, sizeof(addr));
// Make the ioctl call.
WSASetLastError(NO_ERROR);
if ((WSAIoctl(globals.io.master.s, SIO_ROUTING_INTERFACE_QUERY,
&buffer->internal.io.peer, buffer->internal.io.peerLen,
&addr, sizeof(SOCKADDR_IN),
&len, NULL, NULL) == SOCKET_ERROR) ||
(len != sizeof(SOCKADDR_IN))) {
TFTPD_DEBUG((TFTPD_DBG_PROCESS,
"TftpdIoAssignSocket(): "
"WSAIoctl(SIO_ROUTING_INTERFACE_QUERY) failed, error = %d.\n",
WSAGetLastError()));
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
"Failed to initialize network endpoint.");
return (FALSE);
}
// Loop through the control (ancillary) data looking for our packet info.
header = WSA_CMSG_FIRSTHDR(&buffer->internal.io.msg);
packetInfo = NULL;
while (header) {
if ((header->cmsg_level == IPPROTO_IP) && (header->cmsg_type == IP_PKTINFO)) {
packetInfo = (IN_PKTINFO *)WSA_CMSG_DATA(header);
break;
}
header = WSA_CMSG_NXTHDR(&buffer->internal.io.msg, header);
} // while (header)
// Check to see if the best interface we obtained is not the one the client sent the message to.
if ((packetInfo != NULL) &&
(addr.sin_addr.s_addr != packetInfo->ipi_addr.s_addr)) {
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"TftpdIoAssignSocket(context = %p, buffer = %p):\n"
"\tRemote client TID = <%s:%d>\n",
context, buffer,
inet_ntoa(buffer->internal.io.peer.sin_addr),
ntohs(buffer->internal.io.peer.sin_port) ));
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"\tRequest issued to local IP = <%s>\n",
inet_ntoa(packetInfo->ipi_addr) ));
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"\tDefault route is over IP = <%s>\n",
inet_ntoa(addr.sin_addr) ));
// We need to create a private socket for this client.
context->socket = TftpdIoAllocateSocketContext();
if (context->socket == NULL) {
TFTPD_DEBUG((TFTPD_DBG_PROCESS,
"TftpdIoAssignSocket(): "
"TftpdIoAllocateSocketContext() failed.\n"));
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
"Out of memory");
return (FALSE);
}
context->socket->s = INVALID_SOCKET;
context->socket->buffersize = (TFTPD_BUFFER_SIZE)
(FIELD_OFFSET(TFTPD_BUFFER, message.data.data) +
context->blksize);
context->socket->datasize = (TFTPD_DATA_SIZE)context->blksize;
if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) {
ZeroMemory(&addr, sizeof(addr));
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = packetInfo->ipi_addr.s_addr;
}
TftpdIoInitializeSocketContext(context->socket, &addr, context);
if (context->socket->s == INVALID_SOCKET) {
TFTPD_DEBUG((TFTPD_DBG_PROCESS,
"TftpdIoAssignSocket(): "
"TftpdIoInitializeSocketContext() failed.\n"));
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
"Failed to initialize network endpoint.");
HeapFree(globals.hServiceHeap, 0, context->socket);
context->socket = NULL;
return (FALSE);
}
#if defined(DBG)
InterlockedIncrement((PLONG)&globals.performance.privateSockets);
#endif // defined(DBG)
return (TRUE);
} // if ((packetInfo != NULL) && ...)
} else {
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"TftpdIoAssignSocket(context = %p, buffer = %p):\n"
"\tRemote client TID = <%s:%d> issued broadcast request.\n",
context, buffer,
inet_ntoa(buffer->internal.io.peer.sin_addr), ntohs(buffer->internal.io.peer.sin_port) ));
} // if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST))
ZeroMemory(&addr, sizeof(addr));
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = INADDR_ANY;
addr.sin_port = 0;
// Figure out which socket to use for this request (based on blksize).
if (context->blksize <= TFTPD_DEF_DATA) {
if (globals.io.def.s == INVALID_SOCKET) {
EnterCriticalSection(&globals.io.cs); {
if (globals.service.shutdown) {
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
LeaveCriticalSection(&globals.io.cs);
return (FALSE);
}
TftpdIoInitializeSocketContext(&globals.io.def, &addr, NULL);
if (globals.io.def.s != INVALID_SOCKET) {
context->socket = &globals.io.def;
} else {
context->socket = &globals.io.master;
if (context->options) {
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"TftpdIoAssignSocket(): Removing requested blksize = %d "
"option since we failed to create the MTU-size socket.\n",
context->blksize));
context->options &= ~TFTPD_OPTION_BLKSIZE;
}
}
} LeaveCriticalSection(&globals.io.cs);
} else {
context->socket = &globals.io.def;
} // if (globals.io.def.s == INVALID_SOCKET)
} else {
if (context->blksize <= TFTPD_MTU_DATA) {
if (globals.io.mtu.s == INVALID_SOCKET) {
EnterCriticalSection(&globals.io.cs); {
if (globals.service.shutdown) {
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
LeaveCriticalSection(&globals.io.cs);
return (FALSE);
}
TftpdIoInitializeSocketContext(&globals.io.mtu, &addr, NULL);
if (globals.io.mtu.s != INVALID_SOCKET) {
context->socket = &globals.io.mtu;
} else {
context->socket = &globals.io.master;
if (context->options) {
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"TftpdIoAssignSocket(): Removing requested blksize = %d "
"option since we failed to create the MTU-size socket.\n",
context->blksize));
context->options &= ~TFTPD_OPTION_BLKSIZE;
}
}
} LeaveCriticalSection(&globals.io.cs);
} else {
context->socket = &globals.io.mtu;
} // if (globals.io.mtu.s == INVALID_SOCKET)
} else if (context->blksize <= TFTPD_MAX_DATA) {
if (globals.io.max.s == INVALID_SOCKET) {
EnterCriticalSection(&globals.io.cs); {
if (globals.service.shutdown) {
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
LeaveCriticalSection(&globals.io.cs);
return (FALSE);
}
TftpdIoInitializeSocketContext(&globals.io.max, &addr, NULL);
if (globals.io.max.s != INVALID_SOCKET) {
context->socket = &globals.io.max;
} else {
context->socket = &globals.io.master;
if (context->options) {
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
"TftpdIoAssignSocket(): Removing requested blksize = %d "
"option since we failed to create the MAX-size socket.\n",
context->blksize));
context->options &= ~TFTPD_OPTION_BLKSIZE;
}
}
} LeaveCriticalSection(&globals.io.cs);
} else {
context->socket = &globals.io.max;
} // if (globals.io.max.s == INVALID_SOCKET)
}
} // (context->blksize <= TFTPD_DEF_DATA)
return (TRUE);
} // TftpdIoAssignSocket()
BOOL
TftpdIoDestroySocketContext(PTFTPD_SOCKET socket) {
    // Tears down 'socket': stops further receive posting, dismantles the
    // FD_READ wait/event of a shared socket, and closes the Winsock socket,
    // which cancels its outstanding overlapped I/O.
    //
    // Returns TRUE when the socket is gone (or never existed). Returns FALSE
    // when teardown could not complete; the socket is then queued through
    // TftpdIoLeakSocketContext() for a later retry.
    SOCKET s;

    if (socket->s == INVALID_SOCKET)
        return (TRUE);

    TFTPD_DEBUG((TFTPD_TRACE_IO,
                 "TftpdIoDestroySocketContext(socket = %s).\n",
                 ((socket == &globals.io.master) ? "master" :
                  ((socket == &globals.io.def) ? "def" :
                   ((socket == &globals.io.mtu) ? "mtu" :
                    ((socket == &globals.io.max) ? "max" :
                     "private")))) ));

    // Disable further buffer posting.
    socket->lowWaterMark = 0;

    // Sockets without an owning context are the shared ones, which also carry
    // a read-notification event and registered wait.
    if (socket->context == NULL) {
        if (!UnregisterWait(socket->wSelectWait)) {
            // ERROR_IO_PENDING: the wait was unregistered, but a
            // TftpdIoReadNotification() callback was still running.
            DWORD error = GetLastError();
            if (error != ERROR_IO_PENDING) {
                TFTPD_DEBUG((TFTPD_DBG_IO,
                             "TftpdIoDestroySocketContext: "
                             "UnregisterWait() failed, error 0x%08X.\n",
                             error));
                TftpdIoLeakSocketContext(socket);
                return (FALSE);
            }
        }
        socket->wSelectWait = NULL;
        CloseHandle(socket->hSelect);
        socket->hSelect = NULL;
    } // if (socket->context == NULL)

    // Kill the socket. This disables the FD_READ event select and cancels all
    // pending overlapped operations on it. A temporary buffer reference is
    // held across the close. Cancelled completions free their buffers in
    // TftpdIoCompletionCallback(), so dropping this reference afterwards tells
    // us whether we were the last holder and must free the socket here.
    InterlockedIncrement((PLONG)&socket->numBuffers);
    s = socket->s;
    socket->s = INVALID_SOCKET;
    if (closesocket(s) == SOCKET_ERROR) {
        TFTPD_DEBUG((TFTPD_DBG_IO,
                     "TftpdIoDestroySocketContext: "
                     "closesocket() failed, error 0x%08X.\n",
                     GetLastError()));
        socket->s = s;
        InterlockedDecrement((PLONG)&socket->numBuffers);
        TftpdIoLeakSocketContext(socket);
        return (FALSE);
    }

    // Only private sockets (non-NULL owning context) are heap-allocated; the
    // shared sockets live in globals.io and must never reach HeapFree(). This
    // is the same rule TftpdIoFreeBuffer() applies.
    if ((InterlockedDecrement((PLONG)&socket->numBuffers) == -1) &&
        (socket->context != NULL))
        HeapFree(globals.hServiceHeap, 0, socket);

    return (TRUE);
} // TftpdIoDestroySocketContext()