|
|
/*++
Copyright (c) 2001 Microsoft Corporation
Module Name:
context.c
Abstract:
This module contains functions to manage contexts for TFTP sessions with remote clients.
Author:
Jeffrey C. Venable, Sr. (jeffv) 01-Jun-2001
Revision History:
--*/
#include "precomp.h"
/*
 * TftpdContextLeak
 *
 * Parks a context that could not be torn down (a Win32 wait/timer call
 * failed) on the reaper's leak list so a later TftpdContextAllocate() can
 * retry freeing it. The list holds one reference on each parked context.
 *
 * If the service is shutting down, the context is abandoned instead: it is
 * only subtracted from globals.io.numContexts, and service cleanup is
 * attempted when that count drops to -1.
 *
 * A context already on the list is left alone (no second reference).
 */
void TftpdContextLeak(PTFTPD_CONTEXT context) {
    PLIST_ENTRY head = &globals.reaper.leakedContexts;
    PLIST_ENTRY cursor;
    BOOL alreadyListed = FALSE;

    EnterCriticalSection(&globals.reaper.contextCS);

    if (globals.service.shutdown) {
        // Shutdown is already in progress; there is nothing useful to retry
        // later, so just drop it from the live-context count.
        LeaveCriticalSection(&globals.reaper.contextCS);
        if (InterlockedDecrement(&globals.io.numContexts) == -1)
            TftpdServiceAttemptCleanup();
        return;
    }

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextLeak(context = %p).\n", context));

    // Avoid linking the same context into the list twice.
    for (cursor = head->Flink; cursor != head; cursor = cursor->Flink) {
        if (CONTAINING_RECORD(cursor, TFTPD_CONTEXT, linkage) == context) {
            alreadyListed = TRUE;
            break;
        }
    }

    if (!alreadyListed) {
        InsertHeadList(head, &context->linkage);
        globals.reaper.numLeakedContexts++;
        TftpdContextAddReference(context);
    }

    LeaveCriticalSection(&globals.reaper.contextCS);
} // TftpdContextLeak()
BOOL TftpdContextFree(PTFTPD_CONTEXT context);

/*
 * TftpdContextTimerCleanup
 *
 * One-shot wait callback registered by TftpdContextFree() on context->hWait.
 * That same event is passed as the completion event to DeleteTimerQueueTimer(),
 * so when this runs the retransmit timer has been deleted. It clears the timer
 * handle, unregisters its own wait, and completes the free. If the wait cannot
 * be unregistered, the context is handed to TftpdContextLeak() instead.
 *
 * The `timeout` argument supplied by the wait thread pool is not used.
 */
void CALLBACK TftpdContextTimerCleanup(PTFTPD_CONTEXT context, BOOLEAN timeout) {
    DWORD error;

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextTimerCleanup(context = %p).\n", context));

    context->hTimer = NULL;

    // We are running inside this very wait's callback, so UnregisterWait()
    // reporting ERROR_IO_PENDING is expected and counts as success.
    if (!UnregisterWait(context->wWait) && ((error = GetLastError()) != ERROR_IO_PENDING)) {
        TFTPD_DEBUG((TFTPD_DBG_CONTEXT, "TftpdContextTimerCleanup(context = %p): " "UnregisterWait() failed, error 0x%08X.\n", context, error));
        TftpdContextLeak(context);
        return;
    }
    context->wWait = NULL;

    TftpdContextFree(context);
} // TftpdContextTimerCleanup()
/*
 * TftpdContextFree
 *
 * Tears down a context. This may complete immediately or in two phases:
 *
 *   1. Any wait still registered on context->hWait is unregistered.
 *   2. If a retransmit timer exists, a one-shot wait is registered on hWait
 *      (callback TftpdContextTimerCleanup) and the timer is deleted with hWait
 *      as its completion event. The function returns TRUE right away; the
 *      callback re-enters TftpdContextFree() once the timer is gone.
 *   3. Without a timer, the private socket (if any), file, event and filename
 *      are released, the context memory is freed, and
 *      globals.io.numContexts is decremented. Service cleanup is attempted
 *      when that count reaches -1.
 *
 * Returns FALSE when a Win32 wait/timer call failed. In that case the context
 * has been handed to TftpdContextLeak() for a later retry. Returns TRUE
 * otherwise.
 */
BOOL TftpdContextFree(PTFTPD_CONTEXT context) {
    DWORD numContexts;

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextFree(context = %p).\n", context));

    if (context->wWait != NULL) {
        if (!UnregisterWait(context->wWait)) {
            DWORD error;
            if ((error = GetLastError()) != ERROR_IO_PENDING) {
                TFTPD_DEBUG((TFTPD_DBG_CONTEXT, "TftpdContextFree(context = %p): " "UnregisterWait() failed, error 0x%08X.\n", context, error));
                TftpdContextLeak(context);
                return (FALSE);
            }
        }
        context->wWait = NULL;
    }

    if (context->hTimer != NULL) {
        BOOL reset;

        TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextFree(context = %p): " "Deleting timer.\n", context));

        // WriteFile() or ReadFile() may have signalled this event if they
        // last completed immediately.
        reset = ResetEvent(context->hWait);
        ASSERT(reset);
        ASSERT(context->wWait == NULL);

        // hWait doubles as the DeleteTimerQueueTimer() completion event below,
        // so this one-shot wait fires once the timer is fully deleted.
        if (!RegisterWaitForSingleObject(&context->wWait, context->hWait, (WAITORTIMERCALLBACKFUNC)TftpdContextTimerCleanup, context, INFINITE, (WT_EXECUTEINIOTHREAD | WT_EXECUTEONLYONCE))) {
            TFTPD_DEBUG((TFTPD_DBG_CONTEXT, "TftpdContextFree(context = %p): " "RegisterWaitForSingleObject() failed, error 0x%08X.\n", context, GetLastError()));
            TftpdContextLeak(context);
            return (FALSE);
        }

        if (!DeleteTimerQueueTimer(globals.io.hTimerQueue, context->hTimer, context->hWait)) {
            DWORD error;
            if ((error = GetLastError()) != ERROR_IO_PENDING) {
                TFTPD_DEBUG((TFTPD_DBG_CONTEXT, "TftpdContextFree(context = %p): " "DeleteTimerQueueTimer() failed, error 0x%08X.\n", context, error));
                // The next call to TftpdContextFree() to recover this context from the
                // leak list will deregister the wait for us.
                TftpdContextLeak(context);
                return (FALSE);
            }
        }

        // The rest of the teardown happens in TftpdContextTimerCleanup().
        return (TRUE);
    } // if (context->hTimer != NULL)

    ASSERT(context->wWait == NULL);

    // If a private socket was used, destroy it.
    if ((context->socket != NULL) && context->socket->context)
        TftpdIoDestroySocketContext(context->socket);

    // Cleanup everything else.
    if (context->hFile != NULL)
        CloseHandle(context->hFile);
    if (context->hWait != NULL)
        CloseHandle(context->hWait);
    if (context->filename != NULL)
        HeapFree(globals.hServiceHeap, 0, context->filename);

    numContexts = InterlockedDecrement(&globals.io.numContexts);

    HeapFree(globals.hServiceHeap, 0, context);

    // Only the pointer value is printed; the memory is not touched after free.
    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextFree(context = %p): ### numContexts = %d.\n", context, numContexts));

    // Make the unsigned comparison explicit (the counter reaching -1 wraps
    // to all-ones in a DWORD).
    if (numContexts == (DWORD)-1)
        TftpdServiceAttemptCleanup();

    return (TRUE);
} // TftpdContextFree()
/*
 * TftpdContextAddReference
 *
 * Atomically increments context->reference and returns the new count.
 */
DWORD TftpdContextAddReference(PTFTPD_CONTEXT context) {
    DWORD count = InterlockedIncrement(&context->reference);

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextAddReference(context = %p): reference = %d.\n", context, count));
    return (count);
} // TftpdContextAddReference()
PTFTPD_CONTEXT TftpdContextAllocate() {
PTFTPD_CONTEXT context = NULL; DWORD numContexts; TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextAllocate().\n"));
if (globals.reaper.leakedContexts.Flink != &globals.reaper.leakedContexts) {
BOOL failAllocate = FALSE;
// Try to recover leaked contexts.
EnterCriticalSection(&globals.reaper.contextCS); {
PLIST_ENTRY entry; while ((entry = RemoveHeadList(&globals.reaper.leakedContexts)) != &globals.reaper.leakedContexts) {
globals.reaper.numLeakedContexts--; if (!TftpdContextFree(CONTAINING_RECORD(entry, TFTPD_CONTEXT, linkage))) { // If the free failed, the context is readded to the leak list.
// Free the reference from it having already been on the leak list.
TftpdContextRelease(context); failAllocate = TRUE; break; }
}
} LeaveCriticalSection(&globals.reaper.contextCS);
if (failAllocate) goto exit_allocate_context;
} // if (globals.reaper.leakedContexts.Flink != &globals.reaper.leakedContexts)
context = (PTFTPD_CONTEXT)HeapAlloc(globals.hServiceHeap, HEAP_ZERO_MEMORY, sizeof(TFTPD_CONTEXT)); if (context == NULL) { TFTPD_DEBUG((TFTPD_DBG_CONTEXT, "TftpdContextAllocate(): HeapAlloc() failed, error = 0x%08X.\n", GetLastError())); return (NULL); }
InitializeListHead(&context->linkage); context->sorcerer = -1;
numContexts = InterlockedIncrement(&globals.io.numContexts); TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextAllocate(): ### numContexts = %d.\n", numContexts));
if (globals.service.shutdown) TftpdContextFree(context), context = NULL;
exit_allocate_context :
return (context);
} // TftpdContextAllocate()
/*
 * TftpdContextHash
 *
 * Maps a peer (IPv4 address + port, both as stored in the SOCKADDR_IN) to a
 * bucket index in [0, globals.parameters.hashEntries).
 * NOTE(review): assumes hashEntries is non-zero; validated elsewhere?
 */
DWORD TftpdContextHash(PSOCKADDR_IN addr) {
    DWORD key = addr->sin_addr.s_addr + addr->sin_port;

    return (key % globals.parameters.hashEntries);
} // TftpdContextHash()
/*
 * TftpdContextAdd
 *
 * Inserts a context into the hash table bucket selected by its peer address.
 * The table takes its own reference on the context; TftpdContextRemove()
 * releases it.
 *
 * Returns FALSE without inserting if the service is shutting down or a
 * context with the same peer address and port (TID) is already present.
 * Returns TRUE otherwise.
 */
BOOL TftpdContextAdd(PTFTPD_CONTEXT context) {
    PLIST_ENTRY bucket;
    PLIST_ENTRY entry;
    DWORD index;

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextAdd(context = %p).\n", context));

    index = TftpdContextHash(&context->peer);
    bucket = &globals.hash.table[index].bucket;

    EnterCriticalSection(&globals.hash.table[index].cs);
    {
        if (globals.service.shutdown) {
            LeaveCriticalSection(&globals.hash.table[index].cs);
            return (FALSE);
        }

        // Is the context already in the table?
        for (entry = bucket->Flink; entry != bucket; entry = entry->Flink) {
            PTFTPD_CONTEXT c = CONTAINING_RECORD(entry, TFTPD_CONTEXT, linkage);
            if ((c->peer.sin_addr.s_addr == context->peer.sin_addr.s_addr) &&
                (c->peer.sin_port == context->peer.sin_port)) {
                TFTPD_DEBUG((TFTPD_DBG_CONTEXT, "TftpdContextAdd(context = %p): TID already exists.\n", context));
                LeaveCriticalSection(&globals.hash.table[index].cs);
                return (FALSE);
            }
        }

        TftpdContextAddReference(context);
        InsertHeadList(bucket, &context->linkage);

#if defined(DBG)
        {
            DWORD numEntries, maxClients;

            numEntries = InterlockedIncrement((PLONG)&globals.hash.numEntries);
            InterlockedIncrement((PLONG)&globals.hash.table[index].numEntries);

            // Raise the maxClients high-water mark with a compare-exchange
            // retry loop until it is at least numEntries.
            while (numEntries > (maxClients = globals.performance.maxClients))
                InterlockedCompareExchange((PLONG)&globals.performance.maxClients, numEntries, maxClients);
        }
#endif // defined(DBG)
    }
    LeaveCriticalSection(&globals.hash.table[index].cs);

    return (TRUE);
} // TftpdContextAdd()
/*
 * TftpdContextRemove
 *
 * Unlinks a context from its hash-table bucket (if it is still there) and
 * drops the reference the table took in TftpdContextAdd(). Calling it for a
 * context that another thread already removed is a no-op.
 */
void TftpdContextRemove(PTFTPD_CONTEXT context) {
    PLIST_ENTRY bucket;
    PLIST_ENTRY entry;
    DWORD index;

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextRemove(context = %p).\n", context));

    index = TftpdContextHash(&context->peer);
    bucket = &globals.hash.table[index].bucket;

    EnterCriticalSection(&globals.hash.table[index].cs);
    {
        // Validate that the context is still in the bucket and
        // wasn't already removed by another thread.
        for (entry = bucket->Flink; entry != bucket; entry = entry->Flink) {
            if (CONTAINING_RECORD(entry, TFTPD_CONTEXT, linkage) != context)
                continue;

            // Pull the context out of the hash-table.
            RemoveEntryList(&context->linkage);
            TftpdContextRelease(context);

#if defined(DBG)
            InterlockedDecrement((PLONG)&globals.hash.numEntries);
            InterlockedDecrement((PLONG)&globals.hash.table[index].numEntries);
#endif // defined(DBG)

            break;
        }
    }
    LeaveCriticalSection(&globals.hash.table[index].cs);
} // TftpdContextRemove()
/*
 * TftpdContextKill
 *
 * Marks a context dead and starts tearing it down: removes it from the hash
 * table and closes its file handle. Only the caller that sets
 * TFTPD_STATE_DEAD does this work; later calls return immediately.
 */
void TftpdContextKill(PTFTPD_CONTEXT context) {
    DWORD observed;

    // Atomically OR in TFTPD_STATE_DEAD, retrying if another thread changed
    // the state word between our read and the compare-exchange.
    do {
        observed = context->state;
        if (observed & TFTPD_STATE_DEAD)
            return;
    } while (InterlockedCompareExchange(&context->state, (observed | TFTPD_STATE_DEAD), observed) != observed);

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextKill(context = %p).\n", context));

    // Pin the context while we work on it; dropping the hash-table reference
    // below must not free it before we are done with context->hFile.
    TftpdContextAddReference(context);

    TftpdContextRemove(context);

    // Per the original design, closing the file makes any outstanding
    // overlapped read/write complete, and those completions deregister their
    // waits and release their own references.
    if (context->hFile != NULL) {
        CloseHandle(context->hFile);
        context->hFile = NULL;
    }

    // Drop the pin taken above.
    TftpdContextRelease(context);
} // TftpdContextKill()
/*
 * TftpdContextUpdateTimer
 *
 * Re-arms the context's retransmit timer on globals.io.hTimerQueue.
 *
 * Due time:
 *   - If context->timeout is non-zero, it is used as the due time.
 *   - Otherwise an exponential backoff is used: 1000 ms doubled once per
 *     retransmission, capped at 10000 ms.
 * The timer period is 720000 ms (12 minutes).
 *
 * The caller must have the context marked TFTPD_STATE_BUSY (asserted).
 * Returns the result of ChangeTimerQueueTimer().
 */
BOOL TftpdContextUpdateTimer(PTFTPD_CONTEXT context) {
    ULONG timeout = context->timeout;

    ASSERT(context->state & TFTPD_STATE_BUSY);

    if (!timeout) {
        unsigned int x;

        timeout = 1000;
        // Stop doubling once the cap is reached. Doubling unconditionally
        // overflows ULONG for large retransmission counts; 1000 * 2^29 wraps
        // to exactly 0, which would slip past the cap below and re-arm the
        // timer with a zero due time.
        for (x = 0; (x < context->retransmissions) && (timeout < 10000); x++)
            timeout *= 2;
        if (timeout > 10000)
            timeout = 10000;
    }

    // Update the retransmission timer.
    return (ChangeTimerQueueTimer(globals.io.hTimerQueue, context->hTimer, timeout, 720000));
} // TftpdContextUpdateTimer()
/*
 * TftpdContextAquire  (sic: public name kept for callers)
 *
 * Looks up the live context for a peer address/port (TID). On success it
 * returns the context with an extra reference that the caller must drop with
 * TftpdContextRelease().
 *
 * Returns NULL when:
 *   - the service is shutting down,
 *   - no context matches, or
 *   - the matching context has already been marked TFTPD_STATE_DEAD.
 */
PTFTPD_CONTEXT TftpdContextAquire(PSOCKADDR_IN addr) {
    PTFTPD_CONTEXT context = NULL;
    PLIST_ENTRY entry;
    DWORD index;

    // Cheap unlocked early-out; shutdown is re-checked under the bucket lock.
    if (globals.service.shutdown)
        goto exit_acquire;

    index = TftpdContextHash(addr);

    EnterCriticalSection(&globals.hash.table[index].cs);
    {
        if (!globals.service.shutdown) {
            for (entry = globals.hash.table[index].bucket.Flink; entry != &globals.hash.table[index].bucket; entry = entry->Flink) {
                PTFTPD_CONTEXT c = CONTAINING_RECORD(entry, TFTPD_CONTEXT, linkage);
                if ((c->peer.sin_addr.s_addr == addr->sin_addr.s_addr) && (c->peer.sin_port == addr->sin_port)) {
                    context = c;
                    TftpdContextAddReference(context);
                    break;
                }
            }
        } // if (!globals.service.shutdown)
    }
    LeaveCriticalSection(&globals.hash.table[index].cs);

    // TftpdContextKill() sets the dead flag before removing the context from
    // the table, so a dying context can still be found here. Don't hand it out.
    if ((context != NULL) && (context->state & TFTPD_STATE_DEAD)) {
        TftpdContextRelease(context);
        context = NULL;
    }

exit_acquire:
    // sin_port is stored in network byte order; convert it for display.
    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextAquire(TID = %s:%d): context = %p.\n", inet_ntoa(addr->sin_addr), ntohs(addr->sin_port), context));
    return (context);
} // TftpdContextAquire()
/*
 * TftpdContextRelease
 *
 * Atomically drops one reference from the context and returns the remaining
 * count. When the count reaches zero the context is handed to
 * TftpdContextFree(). The caller must not touch the context afterwards.
 */
DWORD TftpdContextRelease(PTFTPD_CONTEXT context) {
    DWORD remaining;

    TFTPD_DEBUG((TFTPD_TRACE_CONTEXT, "TftpdContextRelease(context = %p).\n", context));

    // Per the original design, a context that is ready to die is referenced
    // only by its retransmit timer.
    remaining = InterlockedDecrement(&context->reference);
    if (remaining != 0)
        return (remaining);

    // That was the last reference.
    TftpdContextFree(context);
    return (remaining);
} // TftpdContextRelease()
|