You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1042 lines
37 KiB
1042 lines
37 KiB
/*++
|
|
|
|
Copyright (c) 2001 Microsoft Corporation
|
|
|
|
Module Name:
|
|
|
|
io.c
|
|
|
|
Abstract:
|
|
|
|
This module contains functions to manage all socket I/O
|
|
between the server and clients, including socket management
|
|
and overlapped completion indication. It also contains
|
|
buffer management.
|
|
|
|
Author:
|
|
|
|
Jeffrey C. Venable, Sr. (jeffv) 01-Jun-2001
|
|
|
|
Revision History:
|
|
|
|
--*/
|
|
|
|
#include "precomp.h"
|
|
|
|
|
|
void
|
|
TftpdIoFreeBuffer(PTFTPD_BUFFER buffer) {
|
|
|
|
PTFTPD_SOCKET socket = buffer->internal.socket;
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoFreeBuffer(buffer = %p).\n",
|
|
buffer));
|
|
|
|
HeapFree(globals.hServiceHeap, 0, buffer);
|
|
|
|
if ((InterlockedDecrement((PLONG)&socket->numBuffers) == -1) &&
|
|
(socket->context != NULL))
|
|
HeapFree(globals.hServiceHeap, 0, socket);
|
|
|
|
if (InterlockedDecrement(&globals.io.numBuffers) == -1)
|
|
TftpdServiceAttemptCleanup();
|
|
|
|
} // TftpdIoFreeBuffer()
|
|
|
|
|
|
PTFTPD_BUFFER
|
|
TftpdIoAllocateBuffer(PTFTPD_SOCKET socket) {
|
|
|
|
PTFTPD_BUFFER buffer;
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoAllocateBuffer(socket = %s).\n",
|
|
((socket == &globals.io.master) ? "master" :
|
|
((socket == &globals.io.def) ? "def" :
|
|
((socket == &globals.io.mtu) ? "mtu" :
|
|
((socket == &globals.io.max) ? "max" :
|
|
"private")))) ));
|
|
|
|
buffer = (PTFTPD_BUFFER)HeapAlloc(globals.hServiceHeap, 0,
|
|
socket->buffersize);
|
|
if (buffer == NULL) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoAllocateBuffer(socket = %s): "
|
|
"HeapAlloc() failed, error 0x%08X.\n",
|
|
((socket == &globals.io.master) ? "master" :
|
|
((socket == &globals.io.def) ? "def" :
|
|
((socket == &globals.io.mtu) ? "mtu" :
|
|
((socket == &globals.io.max) ? "max" :
|
|
"private")))), GetLastError()));
|
|
return (NULL);
|
|
}
|
|
ZeroMemory(buffer, sizeof(buffer->internal));
|
|
|
|
InterlockedIncrement(&globals.io.numBuffers);
|
|
InterlockedIncrement((PLONG)&socket->numBuffers);
|
|
|
|
buffer->internal.socket = socket;
|
|
buffer->internal.datasize = socket->datasize;
|
|
|
|
if (globals.service.shutdown) {
|
|
TftpdIoFreeBuffer(buffer);
|
|
buffer = NULL;
|
|
}
|
|
|
|
return (buffer);
|
|
|
|
} // TftpdIoAllocateBuffer()
|
|
|
|
|
|
PTFTPD_BUFFER
|
|
TftpdIoSwapBuffer(PTFTPD_BUFFER buffer, PTFTPD_SOCKET socket) {
|
|
|
|
PTFTPD_BUFFER tmp;
|
|
|
|
ASSERT((buffer->message.opcode == TFTPD_RRQ) ||
|
|
(buffer->message.opcode == TFTPD_WRQ));
|
|
|
|
// Allocate a buffer for the new socket.
|
|
tmp = TftpdIoAllocateBuffer(socket);
|
|
|
|
// Copy information we need to retain.
|
|
if (tmp != NULL) {
|
|
|
|
tmp->internal.context = buffer->internal.context;
|
|
tmp->internal.io.peerLen = buffer->internal.io.peerLen;
|
|
CopyMemory(&tmp->internal.io.peer,
|
|
&buffer->internal.io.peer,
|
|
buffer->internal.io.peerLen);
|
|
CopyMemory(&tmp->internal.io.msg,
|
|
&buffer->internal.io.msg,
|
|
sizeof(tmp->internal.io.msg));
|
|
CopyMemory(&tmp->internal.io.control,
|
|
&buffer->internal.io.control,
|
|
sizeof(tmp->internal.io.control));
|
|
|
|
} // if (tmp != NULL)
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoCompletionCallback(buffer = %p): "
|
|
"new buffer = %p.\n",
|
|
buffer, tmp));
|
|
|
|
// Return the original buffer.
|
|
TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);
|
|
|
|
return (tmp);
|
|
|
|
} // TftpdIoSwapBuffer()
|
|
|
|
|
|
void
|
|
TftpdIoCompletionCallback(DWORD dwErrorCode,
|
|
DWORD dwBytes,
|
|
LPOVERLAPPED overlapped) {
|
|
|
|
PTFTPD_BUFFER buffer = CONTAINING_RECORD(overlapped, TFTPD_BUFFER,
|
|
internal.io.overlapped);
|
|
PTFTPD_CONTEXT context = buffer->internal.context;
|
|
PTFTPD_SOCKET socket = buffer->internal.socket;
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoCompletionCallback(buffer = %p): bytes = %d.\n",
|
|
buffer, dwBytes));
|
|
|
|
if (context == NULL)
|
|
InterlockedDecrement((PLONG)&socket->postedBuffers);
|
|
|
|
switch (dwErrorCode) {
|
|
|
|
case STATUS_SUCCESS :
|
|
|
|
if (context == NULL) {
|
|
|
|
if (dwBytes < TFTPD_MIN_RECEIVED_DATA)
|
|
goto exit_completion_callback;
|
|
|
|
buffer->internal.io.bytes = dwBytes;
|
|
buffer = TftpdProcessReceivedBuffer(buffer);
|
|
|
|
} // if (context == NULL)
|
|
break;
|
|
|
|
case STATUS_PORT_UNREACHABLE :
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoCompletionCallback(buffer = %p, context = %p): "
|
|
"STATUS_PORT_UNREACHABLE.\n",
|
|
buffer, context));
|
|
// If this was a write operation, kill the context.
|
|
if (context != NULL) {
|
|
TftpdProcessError(buffer);
|
|
context = NULL;
|
|
}
|
|
goto exit_completion_callback;
|
|
|
|
case STATUS_CANCELLED :
|
|
|
|
// If this was a write operation, kill the context.
|
|
if (context != NULL) {
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoCompletionCallback(buffer = %p, context = %p): "
|
|
"STATUS_CANCELLED.\n",
|
|
buffer, context));
|
|
TftpdProcessError(buffer);
|
|
context = NULL;
|
|
}
|
|
|
|
TftpdIoFreeBuffer(buffer);
|
|
buffer = NULL;
|
|
|
|
goto exit_completion_callback;
|
|
|
|
default :
|
|
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoCompletionCallback(buffer = %p): "
|
|
"dwErrorcode = 0x%08X.\n",
|
|
buffer, dwErrorCode));
|
|
goto exit_completion_callback;
|
|
|
|
} // switch (dwErrorCode)
|
|
|
|
exit_completion_callback :
|
|
|
|
if (context != NULL) {
|
|
|
|
// Do we bother reposting the buffer?
|
|
if (context->state & TFTPD_STATE_DEAD) {
|
|
TftpdIoFreeBuffer(buffer);
|
|
buffer = NULL;
|
|
}
|
|
|
|
// Release the overlapped send reference.
|
|
TftpdContextRelease(context);
|
|
|
|
} // if (context != NULL)
|
|
|
|
if (buffer != NULL)
|
|
TftpdIoPostReceiveBuffer(buffer->internal.socket, buffer);
|
|
|
|
} // TftpdIoCompletionCallback()
|
|
|
|
|
|
void CALLBACK
|
|
TftpdIoReadNotification(PTFTPD_SOCKET socket, BOOLEAN timeout) {
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoReadNotification(socket = %s).\n",
|
|
((socket == &globals.io.master) ? "master" :
|
|
((socket == &globals.io.def) ? "def" :
|
|
((socket == &globals.io.mtu) ? "mtu" :
|
|
((socket == &globals.io.max) ? "max" :
|
|
"private")))) ));
|
|
|
|
// If this fails, the event triggering this callback will stop signalling
|
|
// due to a lack of a successful WSARecvFrom() ... this will likely occur
|
|
// during low-memory/stress conditions. When the system returns to normal,
|
|
// the low water-mark buffers will be reposted, thus receiving data and
|
|
// re-enabling the event which triggers this callback.
|
|
while (!globals.service.shutdown)
|
|
if (TftpdIoPostReceiveBuffer(socket, NULL) >= socket->lowWaterMark)
|
|
break;
|
|
|
|
} // TftpdIoReadNotification()
|
|
|
|
|
|
DWORD
|
|
TftpdIoPostReceiveBuffer(PTFTPD_SOCKET socket, PTFTPD_BUFFER buffer) {
|
|
|
|
DWORD postedBuffers = 0, successfulPosts = 0;
|
|
int error;
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoPostReceiveBuffer(buffer = %p, socket = %s).\n",
|
|
buffer,
|
|
((socket == &globals.io.master) ? "master" :
|
|
((socket == &globals.io.def) ? "def" :
|
|
((socket == &globals.io.mtu) ? "mtu" :
|
|
((socket == &globals.io.max) ? "max" :
|
|
"private")))) ));
|
|
|
|
postedBuffers = InterlockedIncrement((PLONG)&socket->postedBuffers);
|
|
|
|
//
|
|
// Attempt to post a buffer:
|
|
//
|
|
|
|
while (TRUE) {
|
|
|
|
WSABUF buf;
|
|
|
|
if (globals.service.shutdown ||
|
|
(postedBuffers > globals.parameters.highWaterMark))
|
|
goto exit_post_buffer;
|
|
|
|
// Allocate the buffer if we're not reusing one.
|
|
if (buffer == NULL) {
|
|
|
|
buffer = TftpdIoAllocateBuffer(socket);
|
|
if (buffer == NULL) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoPostReceiveBuffer(buffer = %p): "
|
|
"TftpdIoAllocateBuffer() failed.\n",
|
|
buffer));
|
|
goto exit_post_buffer;
|
|
}
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoPostReceiveBuffer(buffer = %p).\n",
|
|
buffer));
|
|
|
|
} else {
|
|
|
|
if (socket->s == INVALID_SOCKET)
|
|
goto exit_post_buffer;
|
|
|
|
ASSERT(buffer->internal.socket == socket);
|
|
ZeroMemory(buffer, sizeof(buffer->internal));
|
|
buffer->internal.socket = socket;
|
|
buffer->internal.datasize = socket->datasize;
|
|
|
|
} // if (buffer == NULL)
|
|
|
|
buf.buf = ((char *)buffer + FIELD_OFFSET(TFTPD_BUFFER, message.opcode));
|
|
buf.len = (FIELD_OFFSET(TFTPD_BUFFER, message.data.data) -
|
|
FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
|
|
socket->datasize);
|
|
|
|
error = NO_ERROR;
|
|
|
|
if (socket == &globals.io.master) {
|
|
|
|
DWORD bytes = 0;
|
|
buffer->internal.io.msg.lpBuffers = &buf;
|
|
buffer->internal.io.msg.dwBufferCount = 1;
|
|
buffer->internal.io.msg.name = (LPSOCKADDR)&buffer->internal.io.peer;
|
|
buffer->internal.io.msg.namelen = sizeof(buffer->internal.io.peer);
|
|
buffer->internal.io.peerLen = sizeof(buffer->internal.io.peer);
|
|
buffer->internal.io.msg.Control.buf = (char *)&buffer->internal.io.control;
|
|
buffer->internal.io.msg.Control.len = sizeof(buffer->internal.io.control);
|
|
buffer->internal.io.msg.dwFlags = 0;
|
|
if (globals.fp.WSARecvMsg(socket->s, &buffer->internal.io.msg, &bytes,
|
|
&buffer->internal.io.overlapped, NULL) == SOCKET_ERROR)
|
|
error = WSAGetLastError();
|
|
|
|
} else {
|
|
|
|
DWORD bytes = 0;
|
|
buffer->internal.io.peerLen = sizeof(buffer->internal.io.peer);
|
|
if (WSARecvFrom(socket->s, &buf, 1, &bytes, &buffer->internal.io.flags,
|
|
(PSOCKADDR)&buffer->internal.io.peer, &buffer->internal.io.peerLen,
|
|
&buffer->internal.io.overlapped, NULL) == SOCKET_ERROR)
|
|
error = WSAGetLastError();
|
|
|
|
} // if (socket == &globals.io.master)
|
|
|
|
switch (error) {
|
|
|
|
case NO_ERROR :
|
|
if (successfulPosts < 10) {
|
|
successfulPosts++;
|
|
postedBuffers = InterlockedIncrement((PLONG)&socket->postedBuffers);
|
|
buffer = NULL;
|
|
continue;
|
|
} else {
|
|
return (postedBuffers);
|
|
}
|
|
|
|
case WSA_IO_PENDING :
|
|
return (postedBuffers);
|
|
|
|
case WSAECONNRESET :
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoPostReceiveBuffer(buffer = %p): "
|
|
"%s() failed for TID = <%s:%d>, WSAECONNRESET.\n",
|
|
buffer,
|
|
(socket == &globals.io.master) ? "WSARecvMsg" : "WSARecvFrom",
|
|
inet_ntoa(buffer->internal.io.peer.sin_addr),
|
|
ntohs(buffer->internal.io.peer.sin_port)));
|
|
TftpdProcessError(buffer);
|
|
continue;
|
|
|
|
default :
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoPostReceiveBuffer(buffer = %p): "
|
|
"WSARecvMsg/From() failed, error 0x%08X.\n",
|
|
buffer, error));
|
|
goto exit_post_buffer;
|
|
|
|
} // switch (error)
|
|
|
|
} // while (true)
|
|
|
|
exit_post_buffer :
|
|
|
|
postedBuffers = InterlockedDecrement((PLONG)&socket->postedBuffers);
|
|
if (buffer != NULL)
|
|
TftpdIoFreeBuffer(buffer);
|
|
|
|
return (postedBuffers);
|
|
|
|
} // TftpdIoPostReceiveBuffer()
|
|
|
|
|
|
void
|
|
TftpdIoSendErrorPacket(PTFTPD_BUFFER buffer, TFTPD_ERROR_CODE error, char *reason) {
|
|
|
|
DWORD bytes = 0;
|
|
WSABUF buf;
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoSendErrorPacket(buffer = %p): %s\n",
|
|
buffer, reason));
|
|
|
|
// Build the error message.
|
|
buffer->message.opcode = htons(TFTPD_ERROR);
|
|
buffer->message.error.code = htons(error);
|
|
strncpy(buffer->message.error.error, reason, buffer->internal.datasize);
|
|
buffer->message.error.error[buffer->internal.datasize - 1] = '\0';
|
|
|
|
// Send it non-blocking only. If it fails, who cares, let the client deal with it.
|
|
buf.buf = (char *)&buffer->message.opcode;
|
|
buf.len = (FIELD_OFFSET(TFTPD_BUFFER, message.error.error) -
|
|
FIELD_OFFSET(TFTPD_BUFFER, message.opcode) +
|
|
(strlen(buffer->message.error.error) + 1));
|
|
if (WSASendTo(buffer->internal.socket->s, &buf, 1, &bytes, 0,
|
|
(PSOCKADDR)&buffer->internal.io.peer, sizeof(SOCKADDR_IN),
|
|
NULL, NULL) == SOCKET_ERROR) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoSendErrorPacket(buffer = %p): WSASendTo() failed, error = %d.\n",
|
|
buffer, WSAGetLastError()));
|
|
}
|
|
|
|
} // TftpdIoSendErrorPacket()
|
|
|
|
|
|
PTFTPD_BUFFER
|
|
TftpdIoSendPacket(PTFTPD_BUFFER buffer) {
|
|
|
|
PTFTPD_CONTEXT context = buffer->internal.context;
|
|
DWORD bytes = 0;
|
|
WSABUF buf;
|
|
|
|
// NOTE: 'context' must be referenced before this call!
|
|
ASSERT(context != NULL);
|
|
ASSERT(context->reference >= 1);
|
|
ASSERT(buffer->internal.socket != NULL);
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoSendPacket(buffer = %p, context = %p): bytes = %d.\n",
|
|
buffer, context, buffer->internal.io.bytes));
|
|
|
|
// First try sending it non-blocking.
|
|
buf.buf = (char *)&buffer->message.opcode;
|
|
buf.len = buffer->internal.io.bytes;
|
|
if (WSASendTo(context->socket->s, &buf, 1, &bytes, 0,
|
|
(PSOCKADDR)&context->peer, sizeof(SOCKADDR_IN),
|
|
NULL, NULL) == SOCKET_ERROR) {
|
|
|
|
if (WSAGetLastError() == WSAEWOULDBLOCK) {
|
|
|
|
// Keep an overlapped-operation reference to the context.
|
|
TftpdContextAddReference(context);
|
|
|
|
// Send it overlapped. When completion occurs, we'll know it was a send
|
|
// when buffer->internal.context is non-NULL.
|
|
if (WSASendTo(context->socket->s, &buf, 1, &bytes, 0,
|
|
(PSOCKADDR)&context->peer, sizeof(SOCKADDR_IN),
|
|
&buffer->internal.io.overlapped, NULL) == SOCKET_ERROR) {
|
|
|
|
if (WSAGetLastError() != WSA_IO_PENDING) {
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoSendPacket(buffer = %p, context = %p): "
|
|
"overlapped send failed.\n",
|
|
buffer, context));
|
|
// Release the overlapped-operation reference to the context.
|
|
TftpdContextRelease(context);
|
|
goto exit_send_packet;
|
|
}
|
|
|
|
} // if (WSASendTo(...) == SOCKET_ERROR)
|
|
|
|
buffer = NULL; // Tell the caller not to recycle a buffer.
|
|
|
|
} // if (WSAGetLastError() == WSAEWOULDBLOCK)
|
|
|
|
goto exit_send_packet;
|
|
|
|
} // if (WSASendTo(...) == SOCKET_ERROR)
|
|
|
|
//
|
|
// Non-blocking send succeeded.
|
|
//
|
|
|
|
exit_send_packet :
|
|
|
|
return (buffer);
|
|
|
|
} // TftpdIoSendPacket()
|
|
|
|
|
|
void
|
|
TftpdIoLeakSocketContext(PTFTPD_SOCKET socket) {
|
|
|
|
PLIST_ENTRY entry;
|
|
|
|
EnterCriticalSection(&globals.reaper.socketCS); {
|
|
|
|
// If shutdown is occuring, we're in trouble anyways.
|
|
// Just let it go.
|
|
if (globals.service.shutdown) {
|
|
LeaveCriticalSection(&globals.reaper.socketCS);
|
|
return;
|
|
}
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_CONTEXT,
|
|
"TftpdIoLeakSocketContext(context = %p).\n",
|
|
socket));
|
|
|
|
// Is the socket already in the list?
|
|
for (entry = globals.reaper.leakedSockets.Flink;
|
|
entry != &globals.reaper.leakedSockets;
|
|
entry = entry->Flink) {
|
|
|
|
if (CONTAINING_RECORD(entry, TFTPD_SOCKET, linkage) == socket) {
|
|
LeaveCriticalSection(&globals.reaper.socketCS);
|
|
return;
|
|
}
|
|
|
|
}
|
|
|
|
InsertHeadList(&globals.reaper.leakedSockets, &socket->linkage);
|
|
globals.reaper.numLeakedSockets++;
|
|
|
|
} LeaveCriticalSection(&globals.reaper.socketCS);
|
|
|
|
} // TftpdIoLeakSocketContext()
|
|
|
|
|
|
PTFTPD_SOCKET
|
|
TftpdIoAllocateSocketContext() {
|
|
|
|
PTFTPD_SOCKET socket = NULL;
|
|
|
|
if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets) {
|
|
|
|
BOOL failAllocate = FALSE;
|
|
|
|
// Try to recover leaked contexts.
|
|
EnterCriticalSection(&globals.reaper.socketCS); {
|
|
|
|
PLIST_ENTRY entry;
|
|
while ((entry = RemoveHeadList(&globals.reaper.leakedSockets)) !=
|
|
&globals.reaper.leakedSockets) {
|
|
|
|
PTFTPD_SOCKET s = CONTAINING_RECORD(entry, TFTPD_SOCKET, linkage);
|
|
|
|
globals.reaper.numLeakedSockets--;
|
|
if (!TftpdIoDestroySocketContext(s)) {
|
|
TftpdIoLeakSocketContext(s);
|
|
failAllocate = TRUE;
|
|
break;
|
|
}
|
|
|
|
}
|
|
|
|
} LeaveCriticalSection(&globals.reaper.socketCS);
|
|
|
|
if (failAllocate)
|
|
goto exit_allocate_context;
|
|
|
|
} // if (globals.reaper.leakedSockets.Flink != &globals.reaper.leakedSockets)
|
|
|
|
socket = (PTFTPD_SOCKET)HeapAlloc(globals.hServiceHeap,
|
|
HEAP_ZERO_MEMORY,
|
|
sizeof(TFTPD_SOCKET));
|
|
|
|
exit_allocate_context :
|
|
|
|
return (socket);
|
|
|
|
} // TftpdIoAllocateSocketContext()
|
|
|
|
|
|
void
|
|
TftpdIoInitializeSocketContext(PTFTPD_SOCKET socket, PSOCKADDR_IN addr, PTFTPD_CONTEXT context) {
|
|
|
|
BOOL one = TRUE;
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoInitializeSocketContext(socket = %s): TID = <%s:%d>.\n",
|
|
((socket == &globals.io.master) ? "master" :
|
|
((socket == &globals.io.def) ? "def" :
|
|
((socket == &globals.io.mtu) ? "mtu" :
|
|
((socket == &globals.io.max) ? "max" : "private")))),
|
|
inet_ntoa(addr->sin_addr), ntohs(addr->sin_port)));
|
|
|
|
// NOTE: Do NOT zero-out 'socket', it has been initialized with some
|
|
// values we need to work with.
|
|
|
|
// Create the socket.
|
|
socket->s = WSASocket(AF_INET, SOCK_DGRAM, 0, NULL, 0, WSA_FLAG_OVERLAPPED);
|
|
if (socket->s == INVALID_SOCKET) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"WSASocket() failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
SetLastError(WSAGetLastError());
|
|
goto fail_create_context;
|
|
}
|
|
|
|
// Ensure that we will exclusively own our local port so nobody can hijack us.
|
|
if (setsockopt(socket->s,
|
|
SOL_SOCKET,
|
|
SO_EXCLUSIVEADDRUSE,
|
|
(const char *)&one,
|
|
sizeof(one)) == SOCKET_ERROR) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"setsockopt(SO_EXCLUSIVEADDRUSE) failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
SetLastError(WSAGetLastError());
|
|
goto fail_create_context;
|
|
}
|
|
|
|
// Bind the socket on the correct port.
|
|
if (bind(socket->s, (PSOCKADDR)addr, sizeof(SOCKADDR)) == SOCKET_ERROR) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"bind() failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
SetLastError(WSAGetLastError());
|
|
goto fail_create_context;
|
|
}
|
|
|
|
// Register for completion callbacks on the socket.
|
|
if (!BindIoCompletionCallback((HANDLE)socket->s, TftpdIoCompletionCallback, 0)) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"BindIoCompletionCallback() failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
goto fail_create_context;
|
|
}
|
|
|
|
// Indicate that we want WSARecvMsg() to fill-in packet information.
|
|
// Note we only do this on the master-socket only where we can receive TFTPD_RECV and
|
|
// TFTPD_WRITE requests, and we need to determine which socket to set the context to.
|
|
if (socket == &globals.io.master) {
|
|
|
|
// Obtain the WSARecvMsg() extension API pointer.
|
|
GUID g = WSAID_WSARECVMSG;
|
|
int opt = TRUE;
|
|
DWORD len;
|
|
|
|
if (WSAIoctl(socket->s, SIO_GET_EXTENSION_FUNCTION_POINTER, &g, sizeof(g),
|
|
&globals.fp.WSARecvMsg, sizeof(globals.fp.WSARecvMsg),
|
|
&len, NULL, NULL) == SOCKET_ERROR) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"WSAIoctl() failed, error 0x%08X.\n",
|
|
WSAGetLastError()));
|
|
goto fail_create_context;
|
|
}
|
|
|
|
// Indicate that we want WSARecvMsg() to fill-in packet information.
|
|
if (setsockopt(socket->s, IPPROTO_IP, IP_PKTINFO,
|
|
(char *)&opt, sizeof(opt)) == SOCKET_ERROR) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"setsockopt() failed, error 0x%08X.\n",
|
|
WSAGetLastError()));
|
|
goto fail_create_context;
|
|
}
|
|
|
|
} // if (socket == &globals.io.master)
|
|
|
|
// Record the port used for this context.
|
|
CopyMemory(&socket->addr, addr, sizeof(socket->addr));
|
|
|
|
if (context == NULL) {
|
|
|
|
// Select the socket for read and write notifications.
|
|
// Read so when we know to get data, write so when we know
|
|
// whether to do send operations non-blocking or overlapped.
|
|
if ((socket->hSelect = CreateEvent(NULL, FALSE, FALSE, NULL)) == NULL) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"CreateEvent() failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
goto fail_create_context;
|
|
}
|
|
|
|
if (WSAEventSelect(socket->s, socket->hSelect, FD_READ) == SOCKET_ERROR) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"WSAEventSelect() failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
SetLastError(WSAGetLastError());
|
|
goto fail_create_context;
|
|
}
|
|
|
|
// Register for FD_READ notification on the socket.
|
|
if (!RegisterWaitForSingleObject(&socket->wSelectWait,
|
|
socket->hSelect,
|
|
(WAITORTIMERCALLBACK)TftpdIoReadNotification,
|
|
socket,
|
|
INFINITE,
|
|
WT_EXECUTEINWAITTHREAD)) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoInitializeSocketContext: "
|
|
"RegisterWaitForSingleObject() failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
goto fail_create_context;
|
|
}
|
|
|
|
// Prepost the low water-mark number of receive buffers.
|
|
// If the FD_READ event signals on the master socket before we're done, we'll
|
|
// exceed the low water-mark here but that's harmless as the excess buffers
|
|
// will be freed upon completion.
|
|
if (!socket->lowWaterMark)
|
|
socket->lowWaterMark = 1;
|
|
if (!socket->highWaterMark)
|
|
socket->highWaterMark = 1;
|
|
|
|
SetEvent(socket->hSelect);
|
|
|
|
} else {
|
|
|
|
// Is this a private socket (ie, not master, def, mtu, or max).
|
|
// If so, it will be destroyed when it's one and only one owning context is destroyed.
|
|
socket->context = context;
|
|
|
|
// Initialize read notification variables to NULL.
|
|
socket->hSelect = NULL;
|
|
socket->wSelectWait = NULL;
|
|
socket->lowWaterMark = 1;
|
|
|
|
TftpdIoPostReceiveBuffer(socket, NULL);
|
|
|
|
} // if (context == NULL)
|
|
|
|
return;
|
|
|
|
fail_create_context :
|
|
|
|
if (socket->s != INVALID_SOCKET)
|
|
closesocket(socket->s), socket->s = INVALID_SOCKET;
|
|
if (socket->hSelect != NULL)
|
|
CloseHandle(socket->hSelect), socket->hSelect = NULL;
|
|
|
|
} // TftpdIoInitializeSocketContext()
|
|
|
|
|
|
BOOL
|
|
TftpdIoAssignSocket(PTFTPD_CONTEXT context, PTFTPD_BUFFER buffer) {
|
|
|
|
SOCKADDR_IN addr;
|
|
DWORD len = 0;
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"TftpdIoAssignSocket(context = %p, buffer = %p).\n",
|
|
context, buffer));
|
|
|
|
if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) {
|
|
|
|
PWSACMSGHDR header;
|
|
IN_PKTINFO *packetInfo;
|
|
|
|
// Determine if routing problems force us to use a private socket so we can corrrectly
|
|
// send datagrams to the requesting client. First, get the best interface address for
|
|
// responding to the requesting client.
|
|
ZeroMemory(&addr, sizeof(addr));
|
|
|
|
// Make the ioctl call.
|
|
WSASetLastError(NO_ERROR);
|
|
if ((WSAIoctl(globals.io.master.s, SIO_ROUTING_INTERFACE_QUERY,
|
|
&buffer->internal.io.peer, buffer->internal.io.peerLen,
|
|
&addr, sizeof(SOCKADDR_IN),
|
|
&len, NULL, NULL) == SOCKET_ERROR) ||
|
|
(len != sizeof(SOCKADDR_IN))) {
|
|
TFTPD_DEBUG((TFTPD_DBG_PROCESS,
|
|
"TftpdIoAssignSocket(): "
|
|
"WSAIoctl(SIO_ROUTING_INTERFACE_QUERY) failed, error = %d.\n",
|
|
WSAGetLastError()));
|
|
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
|
|
"Failed to initialize network endpoint.");
|
|
return (FALSE);
|
|
}
|
|
|
|
// Loop through the control (ancillary) data looking for our packet info.
|
|
header = WSA_CMSG_FIRSTHDR(&buffer->internal.io.msg);
|
|
packetInfo = NULL;
|
|
while (header) {
|
|
|
|
if ((header->cmsg_level == IPPROTO_IP) && (header->cmsg_type == IP_PKTINFO)) {
|
|
packetInfo = (IN_PKTINFO *)WSA_CMSG_DATA(header);
|
|
break;
|
|
}
|
|
|
|
header = WSA_CMSG_NXTHDR(&buffer->internal.io.msg, header);
|
|
|
|
} // while (header)
|
|
|
|
// Check to see if the best interface we obtained is not the one the client sent the message to.
|
|
if ((packetInfo != NULL) &&
|
|
(addr.sin_addr.s_addr != packetInfo->ipi_addr.s_addr)) {
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"TftpdIoAssignSocket(context = %p, buffer = %p):\n"
|
|
"\tRemote client TID = <%s:%d>\n",
|
|
context, buffer,
|
|
inet_ntoa(buffer->internal.io.peer.sin_addr),
|
|
ntohs(buffer->internal.io.peer.sin_port) ));
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"\tRequest issued to local IP = <%s>\n",
|
|
inet_ntoa(packetInfo->ipi_addr) ));
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"\tDefault route is over IP = <%s>\n",
|
|
inet_ntoa(addr.sin_addr) ));
|
|
|
|
// We need to create a private socket for this client.
|
|
context->socket = TftpdIoAllocateSocketContext();
|
|
if (context->socket == NULL) {
|
|
TFTPD_DEBUG((TFTPD_DBG_PROCESS,
|
|
"TftpdIoAssignSocket(): "
|
|
"TftpdIoAllocateSocketContext() failed.\n"));
|
|
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
|
|
"Out of memory");
|
|
return (FALSE);
|
|
}
|
|
context->socket->s = INVALID_SOCKET;
|
|
context->socket->buffersize = (TFTPD_BUFFER_SIZE)
|
|
(FIELD_OFFSET(TFTPD_BUFFER, message.data.data) +
|
|
context->blksize);
|
|
context->socket->datasize = (TFTPD_DATA_SIZE)context->blksize;
|
|
|
|
if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST)) {
|
|
ZeroMemory(&addr, sizeof(addr));
|
|
addr.sin_family = AF_INET;
|
|
addr.sin_addr.s_addr = packetInfo->ipi_addr.s_addr;
|
|
}
|
|
|
|
TftpdIoInitializeSocketContext(context->socket, &addr, context);
|
|
if (context->socket->s == INVALID_SOCKET) {
|
|
TFTPD_DEBUG((TFTPD_DBG_PROCESS,
|
|
"TftpdIoAssignSocket(): "
|
|
"TftpdIoInitializeSocketContext() failed.\n"));
|
|
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED,
|
|
"Failed to initialize network endpoint.");
|
|
HeapFree(globals.hServiceHeap, 0, context->socket);
|
|
context->socket = NULL;
|
|
return (FALSE);
|
|
}
|
|
|
|
#if defined(DBG)
|
|
InterlockedIncrement((PLONG)&globals.performance.privateSockets);
|
|
#endif // defined(DBG)
|
|
|
|
return (TRUE);
|
|
|
|
} // if ((packetInfo != NULL) && ...)
|
|
|
|
} else {
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"TftpdIoAssignSocket(context = %p, buffer = %p):\n"
|
|
"\tRemote client TID = <%s:%d> issued broadcast request.\n",
|
|
context, buffer,
|
|
inet_ntoa(buffer->internal.io.peer.sin_addr), ntohs(buffer->internal.io.peer.sin_port) ));
|
|
|
|
} // if (!(buffer->internal.io.msg.dwFlags & MSG_BCAST))
|
|
|
|
ZeroMemory(&addr, sizeof(addr));
|
|
addr.sin_family = AF_INET;
|
|
addr.sin_addr.s_addr = INADDR_ANY;
|
|
addr.sin_port = 0;
|
|
|
|
// Figure out which socket to use for this request (based on blksize).
|
|
if (context->blksize <= TFTPD_DEF_DATA) {
|
|
|
|
if (globals.io.def.s == INVALID_SOCKET) {
|
|
|
|
EnterCriticalSection(&globals.io.cs); {
|
|
|
|
if (globals.service.shutdown) {
|
|
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
|
|
LeaveCriticalSection(&globals.io.cs);
|
|
return (FALSE);
|
|
}
|
|
|
|
TftpdIoInitializeSocketContext(&globals.io.def, &addr, NULL);
|
|
|
|
if (globals.io.def.s != INVALID_SOCKET) {
|
|
context->socket = &globals.io.def;
|
|
} else {
|
|
context->socket = &globals.io.master;
|
|
if (context->options) {
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"TftpdIoAssignSocket(): Removing requested blksize = %d "
|
|
"option since we failed to create the MTU-size socket.\n",
|
|
context->blksize));
|
|
context->options &= ~TFTPD_OPTION_BLKSIZE;
|
|
}
|
|
}
|
|
|
|
} LeaveCriticalSection(&globals.io.cs);
|
|
|
|
} else {
|
|
|
|
context->socket = &globals.io.def;
|
|
|
|
} // if (globals.io.def.s == INVALID_SOCKET)
|
|
|
|
} else {
|
|
|
|
if (context->blksize <= TFTPD_MTU_DATA) {
|
|
|
|
if (globals.io.mtu.s == INVALID_SOCKET) {
|
|
|
|
EnterCriticalSection(&globals.io.cs); {
|
|
|
|
if (globals.service.shutdown) {
|
|
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
|
|
LeaveCriticalSection(&globals.io.cs);
|
|
return (FALSE);
|
|
}
|
|
|
|
TftpdIoInitializeSocketContext(&globals.io.mtu, &addr, NULL);
|
|
|
|
if (globals.io.mtu.s != INVALID_SOCKET) {
|
|
context->socket = &globals.io.mtu;
|
|
} else {
|
|
context->socket = &globals.io.master;
|
|
if (context->options) {
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"TftpdIoAssignSocket(): Removing requested blksize = %d "
|
|
"option since we failed to create the MTU-size socket.\n",
|
|
context->blksize));
|
|
context->options &= ~TFTPD_OPTION_BLKSIZE;
|
|
}
|
|
}
|
|
|
|
} LeaveCriticalSection(&globals.io.cs);
|
|
|
|
} else {
|
|
|
|
context->socket = &globals.io.mtu;
|
|
|
|
} // if (globals.io.mtu.s == INVALID_SOCKET)
|
|
|
|
} else if (context->blksize <= TFTPD_MAX_DATA) {
|
|
|
|
if (globals.io.max.s == INVALID_SOCKET) {
|
|
|
|
EnterCriticalSection(&globals.io.cs); {
|
|
|
|
if (globals.service.shutdown) {
|
|
TftpdIoSendErrorPacket(buffer, TFTPD_ERROR_UNDEFINED, "TFTPD service is stopping.");
|
|
LeaveCriticalSection(&globals.io.cs);
|
|
return (FALSE);
|
|
}
|
|
|
|
TftpdIoInitializeSocketContext(&globals.io.max, &addr, NULL);
|
|
|
|
if (globals.io.max.s != INVALID_SOCKET) {
|
|
context->socket = &globals.io.max;
|
|
} else {
|
|
context->socket = &globals.io.master;
|
|
if (context->options) {
|
|
TFTPD_DEBUG((TFTPD_TRACE_PROCESS,
|
|
"TftpdIoAssignSocket(): Removing requested blksize = %d "
|
|
"option since we failed to create the MAX-size socket.\n",
|
|
context->blksize));
|
|
context->options &= ~TFTPD_OPTION_BLKSIZE;
|
|
}
|
|
}
|
|
|
|
} LeaveCriticalSection(&globals.io.cs);
|
|
|
|
} else {
|
|
|
|
context->socket = &globals.io.max;
|
|
|
|
} // if (globals.io.max.s == INVALID_SOCKET)
|
|
|
|
}
|
|
|
|
} // (context->blksize <= TFTPD_DEF_DATA)
|
|
|
|
return (TRUE);
|
|
|
|
} // TftpdIoAssignSocket()
|
|
|
|
|
|
BOOL
|
|
TftpdIoDestroySocketContext(PTFTPD_SOCKET socket) {
|
|
|
|
NTSTATUS status;
|
|
SOCKET s;
|
|
|
|
if (socket->s == INVALID_SOCKET)
|
|
return (TRUE);
|
|
|
|
TFTPD_DEBUG((TFTPD_TRACE_IO,
|
|
"TftpdIoDestroySocketContext(socket = %s).\n",
|
|
((socket == &globals.io.master) ? "master" :
|
|
((socket == &globals.io.def) ? "def" :
|
|
((socket == &globals.io.mtu) ? "mtu" :
|
|
((socket == &globals.io.max) ? "max" :
|
|
"private")))) ));
|
|
|
|
// Disable further buffer posting.
|
|
socket->lowWaterMark = 0;
|
|
|
|
if (socket->context == NULL) {
|
|
|
|
if (!UnregisterWait(socket->wSelectWait)) {
|
|
DWORD error;
|
|
if ((error = GetLastError()) != ERROR_IO_PENDING) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoDestroySocketContext: "
|
|
"UnregisterWait() failed, error 0x%08X.\n",
|
|
error));
|
|
TftpdIoLeakSocketContext(socket);
|
|
return (FALSE);
|
|
}
|
|
}
|
|
socket->wSelectWait = NULL;
|
|
|
|
CloseHandle(socket->hSelect);
|
|
socket->hSelect = NULL;
|
|
|
|
} // if (socket->context == NULL)
|
|
|
|
// Kill the socket. This will disable the FD_READ and FD_WRITE
|
|
// event select, as well as cancel all pending overlapped operations
|
|
// on it. Add a buffer reference here so after we close the
|
|
// socket we can test if there were never any buffers posted
|
|
// which would cancel above in TftpdIoCompletionCallback so
|
|
// we should deallocate socket here.
|
|
|
|
// Kill it.
|
|
InterlockedIncrement((PLONG)&socket->numBuffers);
|
|
s = socket->s;
|
|
socket->s = INVALID_SOCKET;
|
|
if (closesocket(s) == SOCKET_ERROR) {
|
|
TFTPD_DEBUG((TFTPD_DBG_IO,
|
|
"TftpdIoDestroySocketContext: "
|
|
"closesocket() failed, error 0x%08X.\n",
|
|
GetLastError()));
|
|
socket->s = s;
|
|
InterlockedDecrement((PLONG)&socket->numBuffers);
|
|
TftpdIoLeakSocketContext(socket);
|
|
return (FALSE);
|
|
}
|
|
if (InterlockedDecrement((PLONG)&socket->numBuffers) == -1)
|
|
HeapFree(globals.hServiceHeap, 0, socket);
|
|
|
|
return (TRUE);
|
|
|
|
} // TftpdIoDestroySocketContext()
|