Windows NT 4.0 source code leak
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

2949 lines
70 KiB

/*++
Copyright (c) 1992-1996 Microsoft Corporation
Module Name:
sockproc.c
Abstract:
This module contains support routines for the WinSock DLL.
Author:
David Treadwell (davidtr) 20-Feb-1992
Revision History:
--*/
#include "winsockp.h"
#include <ctype.h>
#include <stdarg.h>
#include <wincon.h>
//
// NOTE(review): not referenced in the visible portion of this file;
// presumably an upper bound on blocking-hook invocations -- confirm at
// its point of use.
//
#define MAX_BLOCKING_HOOK_CALLS 1000
//
// The (PCHAR) casts in the following macro force the compiler to assume
// only BYTE alignment.
//
#define SockCopyMemory(d,s,l) RtlCopyMemory( (PCHAR)(d), (PCHAR)(s), (l) )
VOID
SockBuildSockaddr (
    OUT PSOCKADDR Sockaddr,
    OUT PINT SockaddrLength,
    IN PTRANSPORT_ADDRESS TdiAddress
    )
/*++
Routine Description:
    Converts the first TA_ADDRESS entry of a TDI TRANSPORT_ADDRESS into a
    sockaddr. The TDI AddressType field lines up with sa_family, and the
    address bytes follow it in both formats, so the whole thing is moved
    with a single copy.
Arguments:
    Sockaddr - receives the sockaddr. It must have room for
        AddressLength + sizeof(sa_family) bytes.
    SockaddrLength - receives the number of bytes written to Sockaddr.
    TdiAddress - the TDI address to convert.
Return Value:
    None.
--*/
{
    TA_ADDRESS *taAddress;
    INT bytesToCopy;

    //
    // The single-copy conversion below relies on these layout facts.
    //
    WS_ASSERT( sizeof(TdiAddress->Address[0].AddressType) ==
                   sizeof(Sockaddr->sa_family) );
    WS_ASSERT( FIELD_OFFSET( TA_ADDRESS, AddressLength ) == 0 );
    WS_ASSERT( FIELD_OFFSET( TA_ADDRESS, AddressType ) == sizeof(USHORT) );
    WS_ASSERT( FIELD_OFFSET( TRANSPORT_ADDRESS, Address[0] ) == sizeof(int) );
    WS_ASSERT( FIELD_OFFSET( SOCKADDR, sa_family ) == 0 );

    taAddress = &TdiAddress->Address[0];

    //
    // The sockaddr is the family field followed by the address bytes.
    //
    bytesToCopy = taAddress->AddressLength + sizeof(Sockaddr->sa_family);
    *SockaddrLength = bytesToCopy;

    RtlCopyMemory( Sockaddr, &taAddress->AddressType, bytesToCopy );
} // SockBuildSockaddr
VOID
SockBuildTdiAddress (
    OUT PTRANSPORT_ADDRESS TdiAddress,
    IN PSOCKADDR Sockaddr,
    IN INT SockaddrLength
    )
/*++
Routine Description:
    Converts a sockaddr into a TRANSPORT_ADDRESS holding exactly one
    TA_ADDRESS entry. Because sa_family lines up with the TDI AddressType
    field, the sockaddr is copied over AddressType and the bytes after it
    in one piece.
Arguments:
    TdiAddress - receives the TDI address. It must have room for the
        TAAddressCount field, the AddressLength field, and SockaddrLength
        further bytes.
    Sockaddr - the sockaddr to convert.
    SockaddrLength - length of Sockaddr in bytes.
Return Value:
    None.
--*/
{
    TA_ADDRESS *taAddress;

    //
    // The single-copy conversion below relies on these layout facts.
    //
    WS_ASSERT( sizeof(TdiAddress->Address[0].AddressType) ==
                   sizeof(Sockaddr->sa_family) );
    WS_ASSERT( FIELD_OFFSET( TA_ADDRESS, AddressLength ) == 0 );
    WS_ASSERT( FIELD_OFFSET( TA_ADDRESS, AddressType ) == sizeof(USHORT) );
    WS_ASSERT( FIELD_OFFSET( TRANSPORT_ADDRESS, Address[0] ) == sizeof(int) );
    WS_ASSERT( FIELD_OFFSET( SOCKADDR, sa_family ) == 0 );

    taAddress = &TdiAddress->Address[0];

    TdiAddress->TAAddressCount = 1;

    //
    // The TDI length counts only the address bytes, not the family field.
    //
    taAddress->AddressLength = SockaddrLength - sizeof(Sockaddr->sa_family);

    RtlCopyMemory( &taAddress->AddressType, Sockaddr, SockaddrLength );
} // SockBuildTdiAddress
VOID
SockDereferenceSocket (
    IN PSOCKET_INFORMATION Socket
    )
/*++
Routine Description:
    Drops one reference on a socket. When the last reference goes away the
    socket is unlinked from the process socket list and context table
    (under the global lock), and then its critical section is deleted and
    the structure is freed.
Arguments:
    Socket - the socket to dereference; its reference count must be
        nonzero.
Return Value:
    None.
--*/
{
    DWORD error;

    WS_ASSERT( Socket->ReferenceCount > 0 );

    //
    // The global lock protects the reference count and the socket list.
    //
    SockAcquireGlobalLockExclusive( );

    if ( --Socket->ReferenceCount != 0 ) {
        //
        // Somebody still holds a reference; nothing more to do.
        //
        SockReleaseGlobalLock( );
        return;
    }

    //
    // Last reference: unlink the socket from every process-wide lookup
    // structure while still holding the lock.
    //
    RemoveEntryList( &Socket->SocketListEntry );

    error = WahRemoveContextEx(
                SockContextTable,
                Socket->Handle,
                (LPVOID)Socket
                );

#if DBG
    if( error != NO_ERROR ) {
        WS_PRINT((
            "SockDereferenceSocket: WahRemoveContextEx failed, error %d\n",
            error
            ));
    }
#endif

    SockReleaseGlobalLock( );

    //
    // The socket is now unreachable, so it can be torn down outside the
    // global lock.
    //
    DeleteCriticalSection( &Socket->Lock );
    FREE_HEAP( Socket );
} // SockDereferenceSocket
PSOCKET_INFORMATION
SockFindAndReferenceSocket (
    IN SOCKET Handle,
    IN BOOLEAN AttemptImport
    )
/*++
Routine Description:
    Looks up a socket in the process's socket context table and, if it is
    found and not closing, references it. A socket that is not in the table
    may optionally be imported from AFD, since it may have been inherited
    or duplicated into this process.
Arguments:
    Handle - handle of the socket to locate.
    AttemptImport - if the socket is not in the table, specifies whether to
        try importing it with SockGetHandleContext(). This is ignored for a
        socket that is found in the closing state.
Return Value:
    PSOCKET_INFORMATION - a referenced pointer to the socket structure, or
        NULL if no usable socket matched the handle. The caller must drop
        the reference with SockDereferenceSocket() when done.
--*/
{
    PSOCKET_INFORMATION socket;
    BOOLEAN found = FALSE;
    DWORD error;

    //
    // Acquire the resource that protects sockets.
    //
    SockAcquireGlobalLockExclusive( );

    //
    // Attempt to find the socket in the context table.
    //
    error = WahGetContext(
                SockContextTable,
                Handle,
                (LPVOID *)&socket
                );

    if( error == NO_ERROR ) {
        WS_ASSERT( socket != NULL );
        WS_ASSERT( socket->Handle == Handle );

        if( socket->State == SocketStateClosing ) {
            //
            // This socket is in the process of closing, so there's no
            // point in trying to import it from AFD.
            //
            AttemptImport = FALSE;
        } else {
            found = TRUE;
        }
    }

    //
    // If the socket wasn't found, AFD may still know about the handle
    // (inherited or duplicated into this process). Import it if allowed.
    //
    if ( !found ) {
        //
        // NOTE(review): SockGetHandleContext() acquires the global lock
        // again while it is held here, which assumes the lock is
        // recursive.
        //
        socket = AttemptImport ? SockGetHandleContext( Handle ) : NULL;

        if ( socket == NULL ) {
            SockReleaseGlobalLock( );
            return NULL;
        }
    }

    //
    // Reference the socket so that the structure is not freed while the
    // caller is using it. The caller is responsible for dereferencing it.
    //
    WS_ASSERT( socket->ReferenceCount > 0 );
    socket->ReferenceCount++;

    SockReleaseGlobalLock( );

    return socket;
} // SockFindAndReferenceSocket
PSOCKET_INFORMATION
SockGetHandleContext (
    IN SOCKET Handle
    )
/*++
Routine Description:
    Imports a socket handle that this process does not yet know about (for
    example, one that was inherited or duplicated into it). The context
    that SockSetHandleContext() stored in AFD is retrieved and used to build
    a new SOCKET_INFORMATION structure. That structure is then registered
    in the process's socket context table and socket list.
    The context buffer returned by AFD is laid out as:
        SOCKET_INFORMATION structure
        Helper DLL Context Length (ULONG)
        Local Address
        Remote Address
        Helper DLL Context
Arguments:
    Handle - the socket handle to import.
Return Value:
    PSOCKET_INFORMATION - the imported socket, holding one reference, or
        NULL if the handle could not be imported.
--*/
{
    NTSTATUS status;
    IO_STATUS_BLOCK ioStatusBlock;
    ULONG contextLength;
    PVOID context;
    PSOCKET_INFORMATION newSocket;
    UNICODE_STRING transportDeviceName;
    BOOLEAN succeeded;
    PCHAR contextPtr;
    PCHAR contextEnd;
    ULONG copyLength;
    INT error;
    BOOLEAN resourceInitialized;
    PVOID helperDllContext;
    PWINSOCK_HELPER_DLL_INFO helperDll;
    DWORD helperDllNotificationEvents;
    ULONG newSocketLength;
    INT addressFamily;
    INT socketType;
    INT protocol;
    UCHAR contextBuffer[MAX_FAST_HANDLE_CONTEXT];

    //
    // Get the lock that protects access to socket lists, etc.
    //
    SockAcquireGlobalLockExclusive( );

    //
    // Initialize locals so that we know how to clean up on exit.
    // socketType is tested on the exit path, and that path can be reached
    // before socketType is read from the AFD context. It must therefore
    // start out as something other than SOCK_RAW.
    //
    context = NULL;
    newSocket = NULL;
    succeeded = FALSE;
    resourceInitialized = FALSE;
    socketType = 0;
    RtlInitUnicodeString( &transportDeviceName, NULL );

    //
    // Call AFD to determine the length of context info for the socket.
    // If this succeeds, then it is most likely true that the handle is
    // valid for this process.
    //
    status = NtDeviceIoControlFile(
                 (HANDLE)Handle,
                 SockThreadEvent,
                 NULL,                      // APC Routine
                 NULL,                      // APC Context
                 &ioStatusBlock,
                 IOCTL_AFD_GET_CONTEXT_LENGTH,
                 NULL,
                 0,
                 &contextLength,
                 sizeof(contextLength)
                 );

    if ( status == STATUS_PENDING ) {
        SockWaitForSingleObject(
            SockThreadEvent,
            Handle,
            SOCK_NEVER_CALL_BLOCKING_HOOK,
            SOCK_NO_TIMEOUT
            );
        status = ioStatusBlock.Status;
    }

    //
    // The context must at least hold the SOCKET_INFORMATION structure and
    // the helper DLL context length that follows it. Anything shorter
    // cannot be parsed below.
    //
    if ( !NT_SUCCESS(status) ||
             contextLength < sizeof(*newSocket) + sizeof(ULONG) ) {
        goto exit;
    }

    //
    // Now allocate memory to hold the socket context and get the actual
    // context for the socket. Small contexts use the stack buffer.
    //
    if( contextLength <= sizeof(contextBuffer) ) {
        context = contextBuffer;
    } else {
        context = ALLOCATE_HEAP( contextLength );
        if ( context == NULL ) {
            goto exit;
        }
    }

    status = NtDeviceIoControlFile(
                 (HANDLE)Handle,
                 SockThreadEvent,
                 NULL,                      // APC Routine
                 NULL,                      // APC Context
                 &ioStatusBlock,
                 IOCTL_AFD_GET_CONTEXT,
                 NULL,
                 0,
                 context,
                 contextLength
                 );

    if ( status == STATUS_PENDING ) {
        SockWaitForSingleObject(
            SockThreadEvent,
            Handle,
            SOCK_NEVER_CALL_BLOCKING_HOOK,
            SOCK_NO_TIMEOUT
            );
        status = ioStatusBlock.Status;
    }

    if ( !NT_SUCCESS(status) ) {
        goto exit;
    }

    //
    // Grab some parameters from the context structure.
    //
    addressFamily = ((PSOCKET_INFORMATION)context)->AddressFamily;
    socketType = ((PSOCKET_INFORMATION)context)->SocketType;
    protocol = ((PSOCKET_INFORMATION)context)->Protocol;

    //
    // Get the helper DLL for the socket loaded.
    //
    error = SockGetTdiName(
                &addressFamily,
                &socketType,
                &protocol,
                0,
                0,
                &transportDeviceName,
                &helperDllContext,
                &helperDll,
                &helperDllNotificationEvents
                );

    if ( error != NO_ERROR ) {
        goto exit;
    }

    //
    // Allocate a socket information structure for this socket, with room
    // for 8-byte-aligned local and remote address buffers after it.
    //
    newSocketLength = ALIGN_8(sizeof(*newSocket)) +
                          (ALIGN_8(helperDll->MaxSockaddrLength) * 2);

    newSocket = ALLOCATE_HEAP( newSocketLength );
    if ( newSocket == NULL ) {
        goto exit;
    }

    //
    // Copy the initial context into the new socket information structure.
    //
    RtlCopyMemory( newSocket, context, sizeof(*newSocket) );

    //
    // Initialize various fields in the socket information structure.
    //
    // Note that the reference count is initialized to 2 to account for
    // the SockDereferenceSocket() call below.
    //
    newSocket->Handle = Handle;
    newSocket->ReferenceCount = 2;
    newSocket->ConnectInProgress = FALSE;
    newSocket->TdiAddressHandle = NULL;
    newSocket->TdiConnectionHandle = NULL;
    newSocket->HelperDll = NULL;
    newSocket->SocketListEntry.Flink = NULL;
    newSocket->LocalAddress = NULL;
    newSocket->RemoteAddress = NULL;

    newSocket->HelperDllContext = helperDllContext;
    newSocket->HelperDll = helperDll;
    newSocket->HelperDllNotificationEvents = helperDllNotificationEvents;

    newSocket->LocalAddress = (PVOID)ALIGN_8(newSocket + 1);
    newSocket->LocalAddressLength = helperDll->MaxSockaddrLength;

    newSocket->RemoteAddress = (PVOID)ALIGN_8((PUCHAR)newSocket->LocalAddress +
                                   helperDll->MaxSockaddrLength);
    newSocket->RemoteAddressLength = helperDll->MaxSockaddrLength;

    try {
        InitializeCriticalSection( &newSocket->Lock );
        error = NO_ERROR;
    } except( SOCK_EXCEPTION_FILTER() ) {
        error = GetExceptionCode();
    }

    if( error != NO_ERROR ) {
        goto exit;
    }

    resourceInitialized = TRUE;

    //
    // Locate the address data in the context buffer. The helper DLL
    // context length that precedes it is not needed here, so skip it.
    // The length check above guarantees contextPtr <= contextEnd.
    //
    contextEnd = (PCHAR)context + contextLength;
    contextPtr = (PCHAR)context + sizeof(*newSocket) + sizeof(ULONG);

    //
    // Copy in the addresses from the context buffer retrieved from AFD.
    // Never read past the end of the data AFD actually returned. If the
    // stored context is shorter than expected, copy only what is there.
    //
    WS_ASSERT( newSocket->HelperDll != NULL );

    copyLength = (ULONG)newSocket->LocalAddressLength;
    if ( copyLength > (ULONG)(contextEnd - contextPtr) ) {
        copyLength = (ULONG)(contextEnd - contextPtr);
    }

    RtlCopyMemory(
        newSocket->LocalAddress,
        contextPtr,
        copyLength
        );

    contextPtr += copyLength;

    copyLength = (ULONG)newSocket->RemoteAddressLength;
    if ( copyLength > (ULONG)(contextEnd - contextPtr) ) {
        copyLength = (ULONG)(contextEnd - contextPtr);
    }

    RtlCopyMemory(
        newSocket->RemoteAddress,
        contextPtr,
        copyLength
        );

    //
    // Get TDI handles for this socket.
    //
    error = SockGetTdiHandles( newSocket );
    if ( error != NO_ERROR ) {
        goto exit;
    }

    //
    // Give the socket a unique serial number. This is used to identify
    // the socket for AsyncSelect requests.
    //
    newSocket->SocketSerialNumber = SockSocketSerialNumberCounter++;

    //
    // Place the socket information structure in the process's global
    // list of sockets.
    //
    error = WahSetContext(
                SockContextTable,
                newSocket->Handle,
                newSocket
                );

    if( error != NO_ERROR ) {
        goto exit;
    }

    InsertHeadList( &SocketListHead, &newSocket->SocketListEntry );

    //
    // If the socket has AsyncSelect events set up, set them up for this
    // process.
    //
    if ( newSocket->AsyncSelectlEvent ) {
        INT result;

        result = WSPAsyncSelect(
                     newSocket->Handle,
                     newSocket->AsyncSelecthWnd,
                     newSocket->AsyncSelectwMsg,
                     newSocket->AsyncSelectlEvent,
                     &error
                     );

        if( result == SOCKET_ERROR ) {
            goto exit;
        }
    }

    succeeded = TRUE;

exit:

    //
    // NOTE(review): the transport name buffer is freed here only for raw
    // sockets, which suggests SockGetTdiName allocates it only in that
    // case -- confirm.
    //
    if ( transportDeviceName.Buffer != NULL && socketType == SOCK_RAW ) {
        RtlFreeHeap( RtlProcessHeap(), 0, transportDeviceName.Buffer );
    }

    if ( !succeeded && newSocket != NULL ) {

        if ( resourceInitialized ) {
            DeleteCriticalSection( &newSocket->Lock );
        }

        if ( newSocket->TdiAddressHandle != NULL ) {
            status = NtClose( newSocket->TdiAddressHandle );
            WS_ASSERT( NT_SUCCESS(status) );
        }

        if ( newSocket->TdiConnectionHandle != NULL ) {
            status = NtClose( newSocket->TdiConnectionHandle );
            WS_ASSERT( NT_SUCCESS(status) );
        }

        if ( newSocket->HelperDll != NULL ) {
            SockNotifyHelperDll( newSocket, WSH_NOTIFY_CLOSE );
        }

        //
        // A non-NULL Flink means InsertHeadList ran, which happens only
        // after WahSetContext succeeded; undo both.
        //
        if ( newSocket->SocketListEntry.Flink != NULL ) {
            DWORD dummy;

            RemoveEntryList( &newSocket->SocketListEntry );

            dummy = WahRemoveContext(
                        SockContextTable,
                        newSocket->Handle
                        );

            WS_ASSERT( dummy == NO_ERROR );
        }

        FREE_HEAP( newSocket );
        newSocket = NULL;

        IF_DEBUG(SOCKET) {
            WS_PRINT(( "SockGetHandleContext: failed to import socket "
                       "handle %lx\n", Handle ));
        }

    } else if ( !succeeded ) {

        IF_DEBUG(SOCKET) {
            WS_PRINT(( "SockGetHandleContext: failed to import socket "
                       "handle %lx, unknown to AFD: %lX\n", Handle, status ));
        }

    } else {

        WS_ASSERT( succeeded );

        IF_DEBUG(SOCKET) {
            WS_PRINT(( "Imported socket %lx (%lx) of type %s\n",
                       newSocket->Handle, newSocket,
                       ( newSocket->SocketType == SOCK_DGRAM ?
                             "SOCK_DGRAM" :
                             (newSocket->SocketType == SOCK_STREAM ?
                                  "SOCK_STREAM" : "SOCK_RAW")) ));
        }

        //
        // Drop the extra reference taken when ReferenceCount was set to 2,
        // leaving the single reference owned by the socket table.
        //
        SockDereferenceSocket( newSocket );
    }

    if ( context != NULL && context != contextBuffer ) {
        FREE_HEAP( context );
    }

    SockReleaseGlobalLock( );

    return newSocket;

} // SockGetHandleContext
INT
SockSetHandleContext (
    IN PSOCKET_INFORMATION Socket
    )
/*++
Routine Description:
    Packages the socket's state into a context buffer and hands it to AFD
    with IOCTL_AFD_SET_CONTEXT. The buffer is laid out as:
        SOCKET_INFORMATION structure
        Helper DLL Context Length (ULONG)
        Local Address
        Remote Address
        Helper DLL Context
Arguments:
    Socket - the socket whose context is saved.
Return Value:
    INT - NO_ERROR on success, and also when the helper DLL fails to report
        its context size (no context is stored in that case). Otherwise
        WSAENOBUFS, the helper DLL's error, or the WinSock error mapped from
        the AFD status.
--*/
{
    UCHAR contextBuffer[MAX_FAST_HANDLE_CONTEXT];
    IO_STATUS_BLOCK ioStatusBlock;
    NTSTATUS status;
    PVOID context;
    PCHAR cursor;
    ULONG contextLength;
    ULONG helperDllContextLength;
    BOOLEAN contextOnHeap;
    INT error;

    //
    // Ask the helper DLL how many bytes of context it needs (NULL buffer).
    //
    error = Socket->HelperDll->WSHGetSocketInformation(
                Socket->HelperDllContext,
                Socket->Handle,
                Socket->TdiAddressHandle,
                Socket->TdiConnectionHandle,
                SOL_INTERNAL,
                SO_CONTEXT,
                NULL,
                (PINT)&helperDllContextLength
                );

    if ( error != NO_ERROR ) {
        //
        // NOTE(review): this failure is deliberately reported as success
        // (the original code is marked "!!!" with "return error" commented
        // out).
        //
        return NO_ERROR;
    }

    contextLength = sizeof(*Socket) + Socket->LocalAddressLength +
                        Socket->RemoteAddressLength +
                        sizeof(helperDllContextLength) + helperDllContextLength;

    //
    // Small contexts fit in the stack buffer; larger ones go on the heap.
    //
    contextOnHeap = (BOOLEAN)( contextLength > sizeof(contextBuffer) );

    if ( contextOnHeap ) {
        context = ALLOCATE_HEAP( contextLength );
        if ( context == NULL ) {
            return WSAENOBUFS;
        }
    } else {
        context = (PVOID)contextBuffer;
    }

    //
    // Lay out the fixed parts of the context one after another.
    //
    cursor = (PCHAR)context;

    RtlCopyMemory( cursor, Socket, sizeof(*Socket) );
    cursor += sizeof(*Socket);

    *(PULONG)cursor = helperDllContextLength;
    cursor += sizeof(helperDllContextLength);

    RtlCopyMemory( cursor, Socket->LocalAddress, Socket->LocalAddressLength );
    cursor += Socket->LocalAddressLength;

    RtlCopyMemory( cursor, Socket->RemoteAddress, Socket->RemoteAddressLength );
    cursor += Socket->RemoteAddressLength;

    //
    // The helper DLL writes its own context directly into the tail.
    //
    error = Socket->HelperDll->WSHGetSocketInformation(
                Socket->HelperDllContext,
                Socket->Handle,
                Socket->TdiAddressHandle,
                Socket->TdiConnectionHandle,
                SOL_INTERNAL,
                SO_CONTEXT,
                cursor,
                (PINT)&helperDllContextLength
                );

    if ( error == NO_ERROR ) {
        //
        // Hand the complete context to AFD to hold on to.
        //
        status = NtDeviceIoControlFile(
                     (HANDLE)Socket->Handle,
                     SockThreadEvent,
                     NULL,                  // APC Routine
                     NULL,                  // APC Context
                     &ioStatusBlock,
                     IOCTL_AFD_SET_CONTEXT,
                     context,
                     contextLength,
                     NULL,
                     0
                     );

        if ( status == STATUS_PENDING ) {
            SockWaitForSingleObject(
                SockThreadEvent,
                Socket->Handle,
                SOCK_NEVER_CALL_BLOCKING_HOOK,
                SOCK_NO_TIMEOUT
                );
            status = ioStatusBlock.Status;
        }
    }

    if ( contextOnHeap ) {
        FREE_HEAP( context );
    }

    if ( error != NO_ERROR ) {
        return error;
    }

    if ( NT_SUCCESS(status) ) {
        return NO_ERROR;
    }

    return SockNtStatusToSocketError( status );
} // SockSetHandleContext
VOID
SockReferenceSocket (
    IN PSOCKET_INFORMATION Socket
    )
/*++
Routine Description:
    Increments the reference count on the passed-in socket structure while
    holding the global lock.
Arguments:
    Socket - points to the socket information structure. The caller must
        already hold a reference to it (the count must be nonzero).
Return Value:
    None. The caller is responsible for dropping the added reference with
    SockDereferenceSocket() when done with the structure.
--*/
{
    //
    // Acquire the resource that protects socket reference counts.
    //
    SockAcquireGlobalLockExclusive( );

    //
    // The reference ensures that the socket information structure will
    // not be deallocated while somebody is looking at it.
    //
    WS_ASSERT( Socket->ReferenceCount > 0 );
    Socket->ReferenceCount++;

    SockReleaseGlobalLock( );
} // SockReferenceSocket
BOOL
SockWaitForSingleObject (
    IN HANDLE Handle,
    IN SOCKET SocketHandle,
    IN DWORD BlockingHookUsage,
    IN DWORD TimeoutUsage
    )
/*++
Routine Description:
    Does an alertable wait on the specified handle. If the wait completes
    due to an alert, it rewaits.
Arguments:
    Handle - NT system handle to wait on.
    SocketHandle - the socket handle on which we're performing the IO
        we're waiting for. This is necessary to support
        WSACancelBlockingCall().
    BlockingHookUsage - indicates whether to call the thread's blocking
        hook. Possible values are:
            SOCK_ALWAYS_CALL_BLOCKING_HOOK - blocking hook is always
                called if blocking is necessary.
            SOCK_CONDITIONALLY_CALL_BLOCKING_HOOK - blocking hook is
                called if the socket is blocking (i.e. not a nonblocking
                socket).
            SOCK_NEVER_CALL_BLOCKING_HOOK - blocking hook is never
                called.
    TimeoutUsage - determines whether to wait infinitely or for a
        timeout. Possible values are:
            SOCK_NO_TIMEOUT - wait forever for the handle to be
                signalled.
            SOCK_SEND_TIMEOUT - use the socket's send timeout value
                as a timeout.
            SOCK_RECEIVE_TIMEOUT - use the socket's receive timeout
                value as a timeout.
Return Value:
    BOOL - TRUE if the object was signalled within the appropriate
        timeout, and FALSE if the timeout occurred first.
--*/
{
    NTSTATUS status;
    LARGE_INTEGER timeout;
    LARGE_INTEGER endTime;
    LARGE_INTEGER currentTime;
    PSOCKET_INFORMATION socket = NULL;
    LPBLOCKINGCALLBACK blockingCallback = NULL;
    DWORD blockingContext = 0;

    //
    // These default to FALSE so that an unexpected BlockingHookUsage or
    // TimeoutUsage value (which only asserts in debug builds) cannot leave
    // them uninitialized in a free build.
    //
    BOOLEAN callBlockingHook = FALSE;
    BOOLEAN useTimeout = FALSE;

    //
    // First wait for the object for a little while. This handles the
    // usual case where the object is already signalled or is signalled
    // shortly into the wait. We'll only go through the longer, more
    // complex path if we're going to have to wait longer.
    //
    // NT wait timeouts are in 100ns units and a negative value is
    // relative, so this is a 0.5 second relative wait.
    //
    timeout.QuadPart = -(10*1000*500);

    status = NtWaitForSingleObject( Handle, TRUE, &timeout );
    if ( status == STATUS_SUCCESS ) {
        return TRUE;
    }

    //
    // If we need to extract information from the socket, get a pointer
    // to the socket information structure.
    //
    if ( BlockingHookUsage == SOCK_CONDITIONALLY_CALL_BLOCKING_HOOK ||
             BlockingHookUsage == SOCK_ALWAYS_CALL_BLOCKING_HOOK ||
             TimeoutUsage == SOCK_SEND_TIMEOUT ||
             TimeoutUsage == SOCK_RECEIVE_TIMEOUT ) {

        socket = SockFindAndReferenceSocket( SocketHandle, FALSE );

        if ( socket == NULL ) {
            NtWaitForSingleObject( Handle, TRUE, NULL );
            return TRUE;
        }
    }

    //
    // Determine whether we need to call the blocking hook while
    // we're waiting.
    //
    switch ( BlockingHookUsage ) {

    case SOCK_ALWAYS_CALL_BLOCKING_HOOK:

        //
        // We'll assume (for now) that we'll need to call the blocking
        // hook. If we later determine that there is no blocking hook
        // installed, then we obviously cannot call it...
        //
        callBlockingHook = TRUE;
        break;

    case SOCK_CONDITIONALLY_CALL_BLOCKING_HOOK:

        //
        // We'll try to call the blocking hook if this is a blocking socket.
        // (Later we'll determine if there is really a blocking hook
        // installed.)
        //
        callBlockingHook = !socket->NonBlocking;
        break;

    case SOCK_NEVER_CALL_BLOCKING_HOOK:

        callBlockingHook = FALSE;
        break;

    default:

        //
        // Unexpected value: keep the FALSE default.
        //
        WS_ASSERT( FALSE );
        break;
    }

    //
    // Determine if there's really a blocking hook installed. If the
    // upcall fails, we'll just press on regardless.
    //
    if( callBlockingHook == TRUE ) {
        INT result;
        INT error;

        WS_ASSERT( socket != NULL );

        blockingCallback = NULL;

        result = (SockUpcallTable->lpWPUQueryBlockingCallback)(
                     socket->CatalogEntryId,
                     &blockingCallback,
                     &blockingContext,
                     &error
                     );

        if( result == SOCKET_ERROR ) {
            WS_PRINT((
                "SockWaitForSingleObject: WPUQueryBlockingCallback failed %d\n",
                error
                ));
        }

        callBlockingHook = ( blockingCallback != NULL );
    }

    //
    // Determine what our timeout should be, if any. The socket timeouts
    // are scaled by 10*1000 to NT 100ns units (i.e. they are presumably
    // in milliseconds).
    //
    switch ( TimeoutUsage ) {

    case SOCK_NO_TIMEOUT:

        useTimeout = FALSE;
        break;

    case SOCK_SEND_TIMEOUT:

        if ( socket->SendTimeout != 0 ) {
            useTimeout = TRUE;
            timeout = RtlEnlargedIntegerMultiply( socket->SendTimeout, 10*1000 );
        } else {
            useTimeout = FALSE;
        }
        break;

    case SOCK_RECEIVE_TIMEOUT:

        if ( socket->ReceiveTimeout != 0 ) {
            useTimeout = TRUE;
            timeout = RtlEnlargedIntegerMultiply( socket->ReceiveTimeout, 10*1000 );
        } else {
            useTimeout = FALSE;
        }
        break;

    default:

        //
        // Unexpected value: keep the FALSE default (no timeout).
        //
        WS_ASSERT( FALSE );
        break;
    }

    //
    // Dereference the socket if we got a pointer to the socket
    // information structure.
    //
    if ( socket != NULL ) {
        SockDereferenceSocket( socket );
#if DBG
        socket = NULL;
#endif
    }

    //
    // Calculate the end time we'll use when waiting on the handle. The
    // end time is the time at which we must quit waiting on the handle
    // and must instead return from this function.
    //
    if ( useTimeout ) {

        //
        // The end time is the current time plus the timeout.
        //
        status = NtQuerySystemTime( &currentTime );
        WS_ASSERT( NT_SUCCESS(status) );

        endTime.QuadPart = currentTime.QuadPart + timeout.QuadPart;

    } else {

        //
        // We need an infinite timeout. Set the end time to the largest
        // possible time in NT format.
        //
        endTime.LowPart = 0xFFFFFFFF;
        endTime.HighPart = 0x7FFFFFFF;
    }

    //
    // If we're going to be calling a blocking hook, use the shortest
    // possible relative wait (one 100ns tick) since we have to call the
    // blocking hook instead of idly waiting. Otherwise wait until the
    // absolute end time.
    //
    if ( callBlockingHook ) {
        timeout.QuadPart = -1;
    } else {
        timeout = endTime;
    }

    //
    // Remember that we're in a blocking call to prevent other winsock
    // calls from succeeding. Also remember the socket handle on which
    // we're doing the IO.
    //
    WS_ASSERT( !SockThreadIsBlocking );
    SockThreadIsBlocking = TRUE;
    SockThreadSocketHandle = SocketHandle;

    do {

        //
        // If necessary, call the blocking hook function. This gives the
        // routine the opportunity to process the available messages
        // before we complete the wait.
        //
        if ( callBlockingHook ) {
            WS_ASSERT( blockingCallback != NULL );
            if( !(blockingCallback)( blockingContext ) ) {
                WS_ASSERT( SockThreadIoCancelled == TRUE );
            }
        }

        //
        // If the operation was cancelled, reset the timeout to infinite
        // and wait for the cancellation. We don't want to call the
        // blocking hook after the IO is cancelled.
        //
        if ( SockThreadIoCancelled ) {

            timeout.LowPart = 0xFFFFFFFF;
            timeout.HighPart = 0x7FFFFFFF;

        } else {

            //
            // Determine whether we have exceeded the end time. If we
            // have exceeded the end time then we must not wait any
            // longer.
            //
            status = NtQuerySystemTime( &currentTime );
            WS_ASSERT( NT_SUCCESS(status) );

            if ( currentTime.QuadPart > endTime.QuadPart ) {
                status = STATUS_TIMEOUT;
                break;
            }
        }

        //
        // Perform the actual wait on the object handle.
        //
        status = NtWaitForSingleObject( Handle, TRUE, &timeout );
        WS_ASSERT( NT_SUCCESS(status) );
        WS_ASSERT( status != STATUS_TIMEOUT || !SockThreadIoCancelled );

    } while ( status == STATUS_USER_APC ||
              status == STATUS_ALERTED ||
              status == STATUS_TIMEOUT );

    //
    // Reset thread variables.
    //
    SockThreadIsBlocking = FALSE;
    SockThreadSocketHandle = INVALID_SOCKET;

    //
    // Return TRUE if the wait's return code was success; otherwise, we
    // had to timeout the wait so return FALSE.
    //
    if ( status == STATUS_SUCCESS ) {
        return TRUE;
    } else {
        return FALSE;
    }

} // SockWaitForSingleObject
BOOL
SockDefaultBlockingHook (
    VOID
    )
/*++
Routine Description:
    Default blocking hook. Removes at most one pending message for the
    calling thread and, if one was found, translates and dispatches it.
Arguments:
    None.
Return Value:
    BOOL - nonzero if a message was processed (the caller should call the
        hook again), FALSE if no message was available.
--*/
{
    MSG msg;
    BOOL retrievedMessage;

    //
    // Get the next message for this thread, if any. PeekMessage only
    // promises a "nonzero" BOOL on success, so keep the result as a BOOL
    // rather than narrowing it into a BOOLEAN, which could truncate a
    // nonzero value to zero.
    //
    retrievedMessage = PeekMessage( &msg, NULL, 0, 0, PM_REMOVE );

    //
    // Process the message if we got one.
    //
    if ( retrievedMessage ) {
        TranslateMessage( &msg );
        DispatchMessage( &msg );
    }

    //
    // If we got a message, indicate that we want to be called again.
    //
    return retrievedMessage;

} // SockDefaultBlockingHook
BOOLEAN
SockIsSocketConnected (
    IN PSOCKET_INFORMATION Socket
    )
/*++
Routine Description:
    Reports whether the socket is in the connected state. If a connect is
    in progress, pending alerts/APCs are first drained with NtTestAlert()
    so that a connect-completion APC can update the state.
    !!! If this routine is called in a different thread from the one that
    initiated the connect, the connect may have completed without its APC
    having run yet, and this routine would then give an incorrect answer.
Arguments:
    Socket - the socket to examine.
Return Value:
    BOOLEAN - TRUE if Socket->State is SocketStateConnected, else FALSE.
--*/
{
    NTSTATUS status;

    if ( Socket->ConnectInProgress ) {
        for ( ; ; ) {
            status = NtTestAlert( );
            if ( status != STATUS_ALERTED && status != STATUS_USER_APC ) {
                break;
            }
        }
    }

    return (BOOLEAN)( Socket->State == SocketStateConnected );
} // SockIsSocketConnected
INT
SockGetTdiHandles (
    IN PSOCKET_INFORMATION Socket
    )
/*++
Routine Description:
    Fills in whichever of the socket's TDI address and connection handles
    are still NULL by querying AFD (IOCTL_AFD_QUERY_HANDLES).
Arguments:
    Socket - the socket whose TDI handles are needed.
Return Value:
    INT - NO_ERROR if both handles were already present or the query
        succeeded, otherwise the WinSock error mapped from the NT status.
--*/
{
    IO_STATUS_BLOCK ioStatusBlock;
    AFD_HANDLE_INFO handleInfo;
    NTSTATUS status;
    ULONG queryFlags;

    //
    // Ask only for the handles we do not have yet.
    //
    queryFlags =
        ( Socket->TdiAddressHandle == NULL ? AFD_QUERY_ADDRESS_HANDLE : 0 ) |
        ( Socket->TdiConnectionHandle == NULL ? AFD_QUERY_CONNECTION_HANDLE : 0 );

    if ( queryFlags == 0 ) {
        return NO_ERROR;
    }

    status = NtDeviceIoControlFile(
                 (HANDLE)Socket->Handle,
                 SockThreadEvent,
                 NULL,                      // APC Routine
                 NULL,                      // APC Context
                 &ioStatusBlock,
                 IOCTL_AFD_QUERY_HANDLES,
                 &queryFlags,
                 sizeof(queryFlags),
                 &handleInfo,
                 sizeof(handleInfo)
                 );

    //
    // *** Because this routine can be called at APC level from
    //     ConnectCompletionApc(), IOCTL_AFD_QUERY_HANDLES must never pend.
    //     The wait below is kept only as a fallback.
    //
    WS_ASSERT( status != STATUS_PENDING );
    if ( status == STATUS_PENDING ) {
        SockWaitForSingleObject(
            SockThreadEvent,
            Socket->Handle,
            SOCK_NEVER_CALL_BLOCKING_HOOK,
            SOCK_NO_TIMEOUT
            );
        status = ioStatusBlock.Status;
    }

    if ( NT_SUCCESS(status) ) {
        //
        // Store the returned handles, but never overwrite one that is
        // already set.
        //
        if ( Socket->TdiAddressHandle == NULL ) {
            Socket->TdiAddressHandle = handleInfo.TdiAddressHandle;
        }
        if ( Socket->TdiConnectionHandle == NULL ) {
            Socket->TdiConnectionHandle = handleInfo.TdiConnectionHandle;
        }
        return NO_ERROR;
    }

    return SockNtStatusToSocketError( status );
} // SockGetTdiHandles
INT
SockGetInformation (
    IN PSOCKET_INFORMATION Socket,
    IN ULONG InformationType,
    IN PVOID AdditionalInputInfo OPTIONAL,
    IN ULONG AdditionalInputInfoLength,
    IN OUT PBOOLEAN Boolean OPTIONAL,
    IN OUT PULONG Ulong OPTIONAL,
    IN OUT PLARGE_INTEGER LargeInteger OPTIONAL
    )
/*++
Routine Description:
    Queries AFD (IOCTL_AFD_GET_INFORMATION) for one piece of socket
    information and returns it through whichever output parameter is
    supplied.
Arguments:
    Socket - the socket to query.
    InformationType - the AFD information class to retrieve.
    AdditionalInputInfo - optional extra input appended after the
        AFD_INFORMATION header.
    AdditionalInputInfoLength - length of AdditionalInputInfo in bytes.
    Boolean, Ulong, LargeInteger - output parameters. The first one that is
        present (checked in this order) receives the result, and exactly
        one is expected.
Return Value:
    INT - NO_ERROR, WSAENOBUFS if the request buffer cannot be allocated,
        or the WinSock error mapped from the AFD status.
--*/
{
    NTSTATUS status;
    IO_STATUS_BLOCK ioStatusBlock;
    PAFD_INFORMATION afdInfo;
    ULONG afdInfoLength;

    //
    // Allocate space for the I/O buffer: the AFD_INFORMATION header plus
    // any additional input.
    //
    afdInfoLength = sizeof(*afdInfo) + AdditionalInputInfoLength;

    afdInfo = ALLOCATE_HEAP( afdInfoLength );
    if ( afdInfo == NULL ) {
        return WSAENOBUFS;
    }

    //
    // Set up the AFD information block.
    //
    afdInfo->InformationType = InformationType;

    //
    // If there is additional input information, copy it to the input
    // buffer immediately after the header.
    //
    if ( ARGUMENT_PRESENT( AdditionalInputInfo ) ) {
        WS_ASSERT( AdditionalInputInfoLength != 0 );
        RtlCopyMemory( afdInfo + 1, AdditionalInputInfo, AdditionalInputInfoLength );
    }

    //
    // Ask AFD for the information. The result is returned in the
    // AFD_INFORMATION header of the same buffer.
    //
    status = NtDeviceIoControlFile(
                 (HANDLE)Socket->Handle,
                 SockThreadEvent,
                 NULL,                      // APC Routine
                 NULL,                      // APC Context
                 &ioStatusBlock,
                 IOCTL_AFD_GET_INFORMATION,
                 afdInfo,
                 afdInfoLength,
                 afdInfo,
                 sizeof(*afdInfo)
                 );

    //
    // Wait for the operation to complete.
    //
    if ( status == STATUS_PENDING ) {
        SockWaitForSingleObject(
            SockThreadEvent,
            Socket->Handle,
            SOCK_NEVER_CALL_BLOCKING_HOOK,
            SOCK_NO_TIMEOUT
            );
        status = ioStatusBlock.Status;
    }

    if ( !NT_SUCCESS(status) ) {
        FREE_HEAP( afdInfo );
        return SockNtStatusToSocketError( status );
    }

    //
    // Put the return info in the requested parameter. Never write through
    // a missing (NULL) output parameter.
    //
    if ( ARGUMENT_PRESENT( Boolean ) ) {
        *Boolean = afdInfo->Information.Boolean;
    } else if ( ARGUMENT_PRESENT( Ulong ) ) {
        *Ulong = afdInfo->Information.Ulong;
    } else if ( ARGUMENT_PRESENT( LargeInteger ) ) {
        *LargeInteger = afdInfo->Information.LargeInteger;
    } else {
        //
        // Caller error: no output parameter. Flag it in debug builds and
        // simply discard the result otherwise.
        //
        WS_ASSERT( ARGUMENT_PRESENT( LargeInteger ) );
    }

    FREE_HEAP( afdInfo );

    return NO_ERROR;

} // SockGetInformation
INT
SockSetInformation (
    IN PSOCKET_INFORMATION Socket,
    IN ULONG InformationType,
    IN PBOOLEAN Boolean OPTIONAL,
    IN PULONG Ulong OPTIONAL,
    IN PLARGE_INTEGER LargeInteger OPTIONAL
    )
/*++
Routine Description:
    Sets one piece of socket information in AFD
    (IOCTL_AFD_SET_INFORMATION).
Arguments:
    Socket - the socket to modify.
    InformationType - the AFD information class to set.
    Boolean, Ulong, LargeInteger - input values. The first one that is
        present (checked in this order) supplies the value, and exactly one
        is expected.
Return Value:
    INT - NO_ERROR, WSAEINVAL if no input value was supplied, or the
        WinSock error mapped from the AFD status.
--*/
{
    NTSTATUS status;
    IO_STATUS_BLOCK ioStatusBlock;
    AFD_INFORMATION afdInfo;

    //
    // Set up the AFD information block.
    //
    afdInfo.InformationType = InformationType;

    if ( ARGUMENT_PRESENT( Boolean ) ) {
        afdInfo.Information.Boolean = *Boolean;
    } else if ( ARGUMENT_PRESENT( Ulong ) ) {
        afdInfo.Information.Ulong = *Ulong;
    } else if ( ARGUMENT_PRESENT( LargeInteger ) ) {
        afdInfo.Information.LargeInteger = *LargeInteger;
    } else {
        //
        // Caller error: no value to set. Do not dereference a NULL
        // pointer or send an uninitialized value to AFD.
        //
        WS_ASSERT( ARGUMENT_PRESENT( LargeInteger ) );
        return WSAEINVAL;
    }

    //
    // Send the new information to AFD.
    //
    status = NtDeviceIoControlFile(
                 (HANDLE)Socket->Handle,
                 SockThreadEvent,
                 NULL,                      // APC Routine
                 NULL,                      // APC Context
                 &ioStatusBlock,
                 IOCTL_AFD_SET_INFORMATION,
                 &afdInfo,
                 sizeof(afdInfo),
                 NULL,
                 0
                 );

    //
    // Wait for the operation to complete.
    //
    if ( status == STATUS_PENDING ) {
        SockWaitForSingleObject(
            SockThreadEvent,
            Socket->Handle,
            SOCK_NEVER_CALL_BLOCKING_HOOK,
            SOCK_NO_TIMEOUT
            );
        status = ioStatusBlock.Status;
    }

    if ( !NT_SUCCESS(status) ) {
        return SockNtStatusToSocketError( status );
    }

    return NO_ERROR;

} // SockSetInformation
#if DBG
//
// Debug-build-only flag. NOTE(review): not referenced in the visible
// portion of this file; presumably used elsewhere to warn about missing
// WSAStartup calls -- confirm.
//
BOOLEAN WsaStartupWarning = FALSE;
#endif
int
SockEnterApi (
    IN BOOLEAN MustBeStarted,
    IN BOOLEAN BlockingIllegal,
    IN BOOLEAN GetXByYCall
    )
/*++
Routine Description:
    Common prologue for WinSock entry points. It fails the call if the
    process is terminating, or (when required) if startup has not happened
    or termination is in progress. It also initializes per-thread data on
    first use, fails the call if the thread is inside a blocking call and
    that is illegal, and clears the thread's I/O-cancelled flag.
Arguments:
    MustBeStarted - TRUE if startup must already have been called.
    BlockingIllegal - TRUE if the call must fail while the thread is in a
        blocking call.
    GetXByYCall - currently unused; retained for interface compatibility.
Return Value:
    int - NO_ERROR if the API may proceed, otherwise WSANOTINITIALISED,
        WSAENOBUFS, or WSAEINPROGRESS.
--*/
{
    PWINSOCK_TLS_DATA tlsData;

    //
    // Bail if we're already detached from the process.
    //
    if( SockProcessTerminating ) {
        IF_DEBUG(ENTER) {
            WS_PRINT(( "SockEnterApi: process terminating\n" ));
        }
        return WSANOTINITIALISED;
    }

    //
    // Make sure that WSAStartup has been called, if necessary.
    //
    if ( MustBeStarted && (SockWspStartupCount == 0 || SockTerminating) ) {
        IF_DEBUG(ENTER) {
            WS_PRINT(( "SockEnterApi: WSAStartup() not called!\n" ));
        }
        return WSANOTINITIALISED;
    }

    //
    // If this thread has not been initialized, do it now.
    //
    tlsData = GET_THREAD_DATA();
    if ( tlsData == NULL ) {
        if ( !SockThreadInitialize() ) {
            IF_DEBUG(ENTER) {
                WS_PRINT(( "SockEnterApi: SockThreadInitialize failed.\n" ));
            }
            return WSAENOBUFS;
        }
        tlsData = GET_THREAD_DATA();
    }

    //
    // Make sure that we're not in a blocking call, if appropriate.
    //
    if ( BlockingIllegal && tlsData->IsBlocking ) {
        IF_DEBUG(ENTER) {
            WS_PRINT(( "SockEnterApi: in blocking call.\n" ));
        }
        return WSAEINPROGRESS;
    }

    //
    // Initialize the cancelled thread variable. We'll use this to
    // tell whether the operation has been cancelled.
    //
    tlsData->IoCancelled = FALSE;

    //
    // GetXByY-specific thread state is no longer tracked here (the code
    // that set it was disabled); the parameter is intentionally unused.
    //
    (void)GetXByYCall;

    //
    // Everything's cool. Proceed.
    //
    return NO_ERROR;

} // SockEnterApi
#if DBG
VOID
WsAssert(
    IN PVOID FailedAssertion,
    IN PVOID FileName,
    IN ULONG LineNumber
    )
/*++
Routine Description:
    Debug-build assertion handler. When console debugging is enabled, the
    user is prompted on the console to break into the debugger or to
    ignore the assertion. Otherwise, or if console input cannot be read,
    the assertion is passed on to RtlAssert().
Arguments:
    FailedAssertion - text of the failed assertion.
    FileName - name of the source file containing the assertion.
    LineNumber - line number of the assertion.
Return Value:
    None.
--*/
{
    BOOL ok;
    CHAR choice[16];
    DWORD bytes;
    int answer;

    IF_DEBUG(CONSOLE) {

        WS_PRINT(( "\n failed: %s\n at line %ld of %s\n",
                   FailedAssertion, LineNumber, FileName ));

        for ( ; ; ) {

            WS_PRINT(( "[B]reak/[I]gnore? " ));

            bytes = sizeof(choice);

            ok = ReadFile(
                     GetStdHandle(STD_INPUT_HANDLE),
                     &choice,
                     bytes,
                     &bytes,
                     NULL
                     );

            if ( !ok || bytes == 0 ) {
                //
                // Console input is unavailable (read error or end of
                // input). Re-prompting would spin forever, so fall back to
                // the standard RTL assertion handling below.
                //
                break;
            }

            //
            // Cast to unsigned char: passing a negative char to toupper()
            // is undefined behavior.
            //
            answer = toupper( (unsigned char)choice[0] );

            if ( answer == 'I' ) {
                return;
            }

            if ( answer == 'B' ) {
                DbgUserBreakPoint( );
            }
        }
    }

    RtlAssert( FailedAssertion, FileName, LineNumber, NULL );

} // WsAssert
//
// Debug-build state for WsPrintf's output sinks:
//   ConsoleInitialized - set once WsPrintf has allocated and sized the
//       console.
//   DebugFileHandle - log file handle, opened lazily by WsPrintf;
//       INVALID_HANDLE_VALUE until the file has been opened successfully.
//   DebugFileName - name of the log file WsPrintf creates.
//
BOOLEAN ConsoleInitialized = FALSE;
HANDLE DebugFileHandle = INVALID_HANDLE_VALUE;
PCHAR DebugFileName = "msafd.log";
VOID
WsPrintf (
    char *Format,
    ...
    )
/*++
Routine Description:
    Debug-build printf. It formats a message prefixed with "MSAFD: " and
    sends it to the debugger, the console, and/or the log file according to
    the IF_DEBUG(DEBUGGER), IF_DEBUG(CONSOLE) and IF_DEBUG(FILE) flags.
Arguments:
    Format - wsprintf-style format string, followed by its arguments.
Return Value:
    None.
--*/
{
    va_list arglist;
    //
    // wvsprintfA writes at most 1024 bytes, but it starts writing after
    // the "MSAFD: " prefix, so the buffer must hold both.
    //
    char OutputBuffer[1024 + sizeof("MSAFD: ")];
    ULONG length;
    BOOL ret;

    length = (ULONG)wsprintfA( OutputBuffer, "MSAFD: " );

    va_start( arglist, Format );

    wvsprintfA( OutputBuffer + length, Format, arglist );

    va_end( arglist );

    IF_DEBUG(DEBUGGER) {
        DbgPrint( "%s", OutputBuffer );
    }

    IF_DEBUG(CONSOLE) {

        if ( !ConsoleInitialized ) {
            CONSOLE_SCREEN_BUFFER_INFO csbi;
            COORD coord;

            ConsoleInitialized = TRUE;
            (VOID)AllocConsole( );
            (VOID)GetConsoleScreenBufferInfo(
                    GetStdHandle(STD_OUTPUT_HANDLE),
                    &csbi
                    );

            //
            // Keep the window width but make the buffer 20 windows tall
            // so that earlier output can be scrolled back to.
            //
            coord.X = (SHORT)(csbi.srWindow.Right - csbi.srWindow.Left + 1);
            coord.Y = (SHORT)((csbi.srWindow.Bottom - csbi.srWindow.Top + 1) * 20);
            (VOID)SetConsoleScreenBufferSize(
                    GetStdHandle(STD_OUTPUT_HANDLE),
                    coord
                    );
        }

        length = strlen( OutputBuffer );

        ret = WriteFile(
                  GetStdHandle(STD_OUTPUT_HANDLE),
                  (LPVOID )OutputBuffer,
                  length,
                  &length,
                  NULL
                  );
        if ( !ret ) {
            DbgPrint( "WsPrintf: console WriteFile failed: %ld\n",
                          GetLastError( ) );
        }

    }

    IF_DEBUG(FILE) {

        //
        // Open the log file on first use. If opening fails, it is retried
        // on the next call.
        //
        if ( DebugFileHandle == INVALID_HANDLE_VALUE ) {
            DebugFileHandle = CreateFile(
                                  DebugFileName,
                                  GENERIC_READ | GENERIC_WRITE,
                                  FILE_SHARE_READ,
                                  NULL,
                                  CREATE_ALWAYS,
                                  0,
                                  NULL
                                  );
        }

        if ( DebugFileHandle == INVALID_HANDLE_VALUE ) {

            DbgPrint( "WsPrintf: Failed to open winsock debug log file %s: %ld\n",
                          DebugFileName, GetLastError( ) );
        } else {

            length = strlen( OutputBuffer );

            ret = WriteFile(
                      DebugFileHandle,
                      (LPVOID )OutputBuffer,
                      length,
                      &length,
                      NULL
                      );
            if ( !ret ) {
                DbgPrint( "WsPrintf: file WriteFile failed: %ld\n",
                              GetLastError( ) );
            }
        }
    }

} // WsPrintf
#endif
VOID
WsPrintSockaddr (
    IN PSOCKADDR Sockaddr,
    IN PINT SockaddrLength
    )

/*++

Routine Description:

    Debug helper that prints a socket address via WS_PRINT. AF_INET
    addresses print as a dotted quad plus port. Other families print
    only the family number. In non-DBG builds the body is compiled out.

Arguments:

    Sockaddr - The address to print; NULL is reported and ignored.

    SockaddrLength - Points to the length of Sockaddr in bytes; NULL is
        reported and ignored.

Return Value:

    None.

--*/

{
#if DBG
    if ( Sockaddr == NULL ) {
        WS_PRINT(( " NULL addr pointer.\n" ));
        return;
    }

    if ( SockaddrLength == NULL ) {
        WS_PRINT(( " NULL addrlen pointer.\n" ));
        return;
    }

    switch ( Sockaddr->sa_family) {

    case AF_INET: {
        PSOCKADDR_IN sockaddrIn = (PSOCKADDR_IN)Sockaddr;
        PUCHAR portBytes;

        //
        // Compare in signed arithmetic. Against the unsigned sizeof, a
        // negative length would be converted to a huge value and
        // accepted as "long enough".
        //

        if ( *SockaddrLength < (INT)sizeof(SOCKADDR_IN) ) {
            WS_PRINT(( " SHORT AF_INET: len %ld\n", *SockaddrLength ));
            return;
        }

        //
        // sin_port is in network (big-endian) byte order. Assemble it
        // byte-by-byte so the actual port number is printed on any host.
        //

        portBytes = (PUCHAR)&sockaddrIn->sin_port;

        WS_PRINT(( " IP %ld.%ld.%ld.%ld port %ld\n",
                   sockaddrIn->sin_addr.S_un.S_un_b.s_b1,
                   sockaddrIn->sin_addr.S_un.S_un_b.s_b2,
                   sockaddrIn->sin_addr.S_un.S_un_b.s_b3,
                   sockaddrIn->sin_addr.S_un.S_un_b.s_b4,
                   ( portBytes[0] << 8 ) | portBytes[1] ));
        return;
    }

    default:
        WS_PRINT(( " family %lx\n", Sockaddr->sa_family ));
        break;
    }

    return;
#endif

} // WsPrintSockaddr
#if DBG
VOID
WsEnterApiCall (
    IN PCHAR RoutineName,
    IN PVOID Arg1,
    IN PVOID Arg2,
    IN PVOID Arg3,
    IN PVOID Arg4
    )

/*++

Routine Description:

    Debug trace hook run on entry to an API. When ENTER tracing is on,
    prints "---> Routine() args ..." indented by the current per-thread
    nesting depth. The depth is always incremented afterwards;
    WsExitApiCall() decrements it.

Arguments:

    RoutineName - Name of the API being entered.

    Arg1..Arg4 - Up to four argument values, printed in hex.

Return Value:

    None.

--*/

{
    ULONG depth;

    CHECK_HEAP;

    //
    // SockIndentLevel lives in per-thread data, so that data must exist
    // before it is touched. SockEnterApi() also does this, but it runs
    // after this hook. Bail out quietly if the process is going away or
    // the thread cannot be set up.
    //

    if ( GET_THREAD_DATA() == NULL ) {

        if ( SockProcessTerminating ) {
            return;
        }

        if ( !SockThreadInitialize() ) {
            return;
        }
    }

    IF_DEBUG(ENTER) {

        depth = 0;

        while ( depth < SockIndentLevel ) {
            WS_PRINT(( " " ));
            depth++;
        }

        WS_PRINT(( "---> %s() args 0x%lx 0x%lx 0x%lx 0x%lx\n",
                   RoutineName, Arg1, Arg2, Arg3, Arg4 ));
    }

    SockIndentLevel++;

} // WsEnter
//
// WinSock error code -> symbolic name, used by WsGetErrorString() for
// debug tracing. The final NO_ERROR entry is kept last so the table is
// still terminated by NO_ERROR. However, WsGetErrorString() bounds its
// search by the table size, so NO_ERROR itself is also found.
//

struct _ERROR_STRINGS {
    INT ErrorCode;
    PCHAR ErrorString;
} ErrorStrings[] = {
    { (WSABASEERR+4),    "WSAEINTR" },
    { (WSABASEERR+9),    "WSAEBADF" },
    { (WSABASEERR+13),   "WSAEACCES" },
    { (WSABASEERR+14),   "WSAEFAULT" },
    { (WSABASEERR+22),   "WSAEINVAL" },
    { (WSABASEERR+24),   "WSAEMFILE" },
    { (WSABASEERR+35),   "WSAEWOULDBLOCK" },
    { (WSABASEERR+36),   "WSAEINPROGRESS" },
    { (WSABASEERR+37),   "WSAEALREADY" },
    { (WSABASEERR+38),   "WSAENOTSOCK" },
    { (WSABASEERR+39),   "WSAEDESTADDRREQ" },
    { (WSABASEERR+40),   "WSAEMSGSIZE" },
    { (WSABASEERR+41),   "WSAEPROTOTYPE" },
    { (WSABASEERR+42),   "WSAENOPROTOOPT" },
    { (WSABASEERR+43),   "WSAEPROTONOSUPPORT" },
    { (WSABASEERR+44),   "WSAESOCKTNOSUPPORT" },
    { (WSABASEERR+45),   "WSAEOPNOTSUPP" },
    { (WSABASEERR+46),   "WSAEPFNOSUPPORT" },
    { (WSABASEERR+47),   "WSAEAFNOSUPPORT" },
    { (WSABASEERR+48),   "WSAEADDRINUSE" },
    { (WSABASEERR+49),   "WSAEADDRNOTAVAIL" },
    { (WSABASEERR+50),   "WSAENETDOWN" },
    { (WSABASEERR+51),   "WSAENETUNREACH" },
    { (WSABASEERR+52),   "WSAENETRESET" },
    { (WSABASEERR+53),   "WSAECONNABORTED" },
    { (WSABASEERR+54),   "WSAECONNRESET" },
    { (WSABASEERR+55),   "WSAENOBUFS" },
    { (WSABASEERR+56),   "WSAEISCONN" },
    { (WSABASEERR+57),   "WSAENOTCONN" },
    { (WSABASEERR+58),   "WSAESHUTDOWN" },
    { (WSABASEERR+59),   "WSAETOOMANYREFS" },
    { (WSABASEERR+60),   "WSAETIMEDOUT" },
    { (WSABASEERR+61),   "WSAECONNREFUSED" },
    { (WSABASEERR+62),   "WSAELOOP" },
    { (WSABASEERR+63),   "WSAENAMETOOLONG" },
    { (WSABASEERR+64),   "WSAEHOSTDOWN" },
    { (WSABASEERR+65),   "WSAEHOSTUNREACH" },
    { (WSABASEERR+66),   "WSAENOTEMPTY" },
    { (WSABASEERR+67),   "WSAEPROCLIM" },
    { (WSABASEERR+68),   "WSAEUSERS" },
    { (WSABASEERR+69),   "WSAEDQUOT" },
    { (WSABASEERR+70),   "WSAESTALE" },
    { (WSABASEERR+71),   "WSAEREMOTE" },
    { (WSABASEERR+91),   "WSASYSNOTREADY" },
    { (WSABASEERR+92),   "WSAVERNOTSUPPORTED" },
    { (WSABASEERR+93),   "WSANOTINITIALISED" },
    { (WSABASEERR+101),  "WSAEDISCON" },
    { (WSABASEERR+102),  "WSAENOMORE" },
    { (WSABASEERR+103),  "WSAECANCELLED" },
    { (WSABASEERR+104),  "WSAEINVALIDPROCTABLE" },
    { (WSABASEERR+105),  "WSAEINVALIDPROVIDER" },
    { (WSABASEERR+106),  "WSAEPROVIDERFAILEDINIT" },
    { (WSABASEERR+107),  "WSASYSCALLFAILURE" },
    { (WSABASEERR+1001), "WSAHOST_NOT_FOUND" },
    { (WSABASEERR+1002), "WSATRY_AGAIN" },
    { (WSABASEERR+1003), "WSANO_RECOVERY" },
    { (WSABASEERR+1004), "WSANO_DATA" },
    { NO_ERROR,          "NO_ERROR" }
};

PCHAR
WsGetErrorString (
    IN INT Error
    )

/*++

Routine Description:

    Returns the symbolic name of a WinSock error code for debug output.

Arguments:

    Error - The error code to look up.

Return Value:

    PCHAR - The name from ErrorStrings[], or "Unknown" if the code is not
        in the table. NO_ERROR maps to "NO_ERROR". The old loop stopped
        at the NO_ERROR entry and so returned "Unknown" for it.

--*/

{
    INT i;
    INT count = (INT)( sizeof(ErrorStrings) / sizeof(ErrorStrings[0]) );

    for ( i = 0; i < count; i++ ) {
        if ( ErrorStrings[i].ErrorCode == Error ) {
            return ErrorStrings[i].ErrorString;
        }
    }

    return "Unknown";

} // WsGetErrorString
VOID
WsExitApiCall (
    IN PCHAR RoutineName,
    IN INT ReturnValue,
    IN BOOLEAN Failed
    )

/*++

Routine Description:

    Debug trace hook run on exit from an API. Undoes the nesting step
    taken by WsEnterApiCall(). When EXIT tracing is on, prints either the
    return value or, on failure, the thread's last-error code and its
    symbolic name. The last-error value is preserved across this
    routine.

Arguments:

    RoutineName - Name of the API being left.

    ReturnValue - The API's return value (printed when !Failed).

    Failed - TRUE if the API failed; GetLastError() then holds the error.

Return Value:

    None.

--*/

{
    INT savedError;
    ULONG column;

    //
    // Grab the API's last-error value before any tracing work can
    // disturb it. It is put back on every path out of this routine.
    //

    savedError = GetLastError( );

    if ( !SockProcessTerminating && GET_THREAD_DATA() != NULL ) {

        CHECK_HEAP;

        SockIndentLevel--;

        IF_DEBUG(EXIT) {

            for ( column = 0; column < SockIndentLevel; column++ ) {
                WS_PRINT(( " " ));
            }

            if ( Failed ) {
                WS_PRINT(( "<--- %s() FAILED--error %ld (0x%lx) == %s\n",
                           RoutineName, savedError, savedError,
                           WsGetErrorString( savedError ) ));
            } else {
                WS_PRINT(( "<--- %s() returning %ld (0x%lx)\n",
                           RoutineName, ReturnValue, ReturnValue ));
            }
        }
    }

    SetLastError( savedError );

} // WsExitApiCall
//
// Debug-heap bookkeeping. Each block returned by SockAllocateHeap() is
// laid out as [SOCK_HEAP_HEADER][caller bytes][SOCK_HEAP_TAIL]. It is
// linked onto SockHeapListHead so that SockCheckHeap() and SockFreeHeap()
// can verify the guard values. SockAllocateHeap() and SockFreeHeap()
// update the list and the Sock* counters while holding SocketHeapLock
// exclusively.
//
LIST_ENTRY SockHeapListHead;
// Number of successful SockAllocateHeap() calls.
ULONG SockTotalAllocations = 0;
// Number of SockFreeHeap() calls.
ULONG SockTotalFrees = 0;
// Caller bytes currently outstanding (added on allocate, subtracted on free).
ULONG SockTotalBytesAllocated = 0;
RTL_RESOURCE SocketHeapLock;
// Set after SockInitializeDebugData() has initialized the list and lock.
// NOTE(review): the check-then-set in SockCheckHeap()/SockFreeHeap() is not synchronized.
BOOLEAN SockHeapDebugInitialized = FALSE;
// NOTE(review): not referenced in this part of the file.
BOOLEAN SockDebugHeap = FALSE;
// Only referenced by commented-out code in SockCheckHeap().
PVOID SockHeap = NULL;
// Caller and caller's caller of the most recent full SockCheckHeap() pass.
PVOID SockCaller1;
PVOID SockCaller2;
// FALSE makes SockCheckHeap() return without validating anything.
BOOLEAN SockDoHeapCheck = TRUE;
// TRUE makes SockAllocateHeap()/SockFreeHeap() re-run SockCheckHeap() afterwards.
BOOLEAN SockDoubleHeapCheck = FALSE;
//
// Guard values. Codes 1-2 go in the header and 3-5 in the tail. On free,
// each is replaced by its bitwise complement.
//
#define WINSOCK_HEAP_CODE_1 0xabcdef00
#define WINSOCK_HEAP_CODE_2 0x12345678
#define WINSOCK_HEAP_CODE_3 0x87654321
#define WINSOCK_HEAP_CODE_4 0x00fedcba
#define WINSOCK_HEAP_CODE_5 0xa1b2c3d4
//
// Header placed immediately before the caller's bytes.
//
typedef struct _SOCK_HEAP_HEADER {
ULONG HeapCode1;
ULONG HeapCode2;
// Link in SockHeapListHead.
LIST_ENTRY GlobalHeapListEntry;
// Allocation site, as passed to SockAllocateHeap().
PCHAR FileName;
ULONG LineNumber;
// Caller-requested byte count; locates the tail.
ULONG Size;
// Not written by the code here. NOTE(review): presumably pads the header
// to 32 bytes on 32-bit builds -- confirm before changing the layout.
ULONG Pad;
} SOCK_HEAP_HEADER, *PSOCK_HEAP_HEADER;
//
// Tail placed immediately after the caller's bytes. It may be unaligned,
// so it is always copied through a local SOCK_HEAP_TAIL.
//
typedef struct _SOCK_HEAP_TAIL {
// Back-pointer to the owning header (NULL after free).
PSOCK_HEAP_HEADER Header;
ULONG HeapCode3;
ULONG HeapCode4;
ULONG HeapCode5;
} SOCK_HEAP_TAIL, *PSOCK_HEAP_TAIL;
//
// Ring of recently freed headers. Only used by code that is currently
// commented out in SockFreeHeap().
//
#define FREE_LIST_SIZE 64
SOCK_HEAP_HEADER SockRecentFreeList[FREE_LIST_SIZE];
ULONG SockRecentFreeListIndex = 0;
VOID
SockInitializeDebugData (
    VOID
    )

/*++

Routine Description:

    One-time setup for debug-heap tracking: empties the list of live
    allocations and initializes the resource lock that guards it.
    Callers set SockHeapDebugInitialized once this has run.

Arguments:

    None.

Return Value:

    None.

--*/

{
    InitializeListHead( &SockHeapListHead );

    RtlInitializeResource( &SocketHeapLock );

} // SockInitializeDebugData
PVOID
SockAllocateHeap (
    IN ULONG NumberOfBytes,
    PCHAR FileName,
    ULONG LineNumber
    )

/*++

Routine Description:

    Debug-build allocator. Carves NumberOfBytes of caller memory out of
    SockPrivateHeap and brackets it with a SOCK_HEAP_HEADER and a
    SOCK_HEAP_TAIL holding guard values. It records the allocation site,
    links the block onto SockHeapListHead, and updates the counters. The
    heap is validated before the allocation, and again afterwards when
    SockDoubleHeapCheck is set.

Arguments:

    NumberOfBytes - Caller bytes requested (top four bits must be clear).

    FileName - Allocation-site file name, stored in the header.

    LineNumber - Allocation-site line number, stored in the header.

Return Value:

    PVOID - The caller's memory, or NULL if the heap allocation failed.

--*/

{
    PSOCK_HEAP_HEADER block;
    SOCK_HEAP_TAIL UNALIGNED *trailer;
    SOCK_HEAP_TAIL trailerImage;
    PVOID userData = NULL;

    WS_ASSERT( (NumberOfBytes & 0xF0000000) == 0 );
    WS_ASSERT( SockPrivateHeap != NULL );

    SockCheckHeap( );

    RtlAcquireResourceExclusive( &SocketHeapLock, TRUE );

    block = RtlAllocateHeap(
                SockPrivateHeap,
                0,
                NumberOfBytes + sizeof(*block) + sizeof(*trailer)
                );

    if ( block != NULL ) {

        block->HeapCode1 = WINSOCK_HEAP_CODE_1;
        block->HeapCode2 = WINSOCK_HEAP_CODE_2;
        block->FileName = FileName;
        block->LineNumber = LineNumber;
        block->Size = NumberOfBytes;

        //
        // The tail starts right after the caller's bytes, so it may be
        // unaligned. Build it locally and copy it into place bytewise.
        //

        trailer = (SOCK_HEAP_TAIL UNALIGNED *)( (PCHAR)(block + 1) + NumberOfBytes );

        trailerImage.Header = block;
        trailerImage.HeapCode3 = WINSOCK_HEAP_CODE_3;
        trailerImage.HeapCode4 = WINSOCK_HEAP_CODE_4;
        trailerImage.HeapCode5 = WINSOCK_HEAP_CODE_5;

        SockCopyMemory( trailer, &trailerImage, sizeof(trailerImage) );

        InsertTailList( &SockHeapListHead, &block->GlobalHeapListEntry );

        SockTotalAllocations++;
        SockTotalBytesAllocated += block->Size;

        userData = (PVOID)(block + 1);
    }

    RtlReleaseResource( &SocketHeapLock );

    if ( SockDoubleHeapCheck ) {
        SockCheckHeap();
    }

    return userData;

} // SockAllocateHeap
VOID
SockFreeHeap (
    IN PVOID Pointer
    )

/*++

Routine Description:

    Debug-build counterpart of SockAllocateHeap(). Asserts that the
    header and tail guards around Pointer are intact, unlinks the block
    from the tracking list, and updates the counters. It then overwrites
    the guards with the complements of the magic values, which makes a
    stale second free likely to trip the assertions. Finally it returns
    the memory to SockPrivateHeap.

Arguments:

    Pointer - A pointer previously returned by SockAllocateHeap().

Return Value:

    None.

--*/

{
    PSOCK_HEAP_HEADER block = (PSOCK_HEAP_HEADER)Pointer - 1;
    SOCK_HEAP_TAIL UNALIGNED *trailer;
    SOCK_HEAP_TAIL trailerImage;

    WS_ASSERT( SockPrivateHeap != NULL );

    SockCheckHeap( );

    //
    // The tail sits right after the caller's bytes and may be unaligned.
    // It is only ever accessed by copying through trailerImage.
    //

    trailer = (SOCK_HEAP_TAIL UNALIGNED *)( (PCHAR)(block + 1) + block->Size );

    if ( !SockHeapDebugInitialized ) {
        SockInitializeDebugData( );
        SockHeapDebugInitialized = TRUE;
    }

    RtlAcquireResourceExclusive( &SocketHeapLock, TRUE );

    SockCopyMemory( &trailerImage, trailer, sizeof(trailerImage) );

    WS_ASSERT( block->HeapCode1 == WINSOCK_HEAP_CODE_1 );
    WS_ASSERT( block->HeapCode2 == WINSOCK_HEAP_CODE_2 );
    WS_ASSERT( trailerImage.HeapCode3 == WINSOCK_HEAP_CODE_3 );
    WS_ASSERT( trailerImage.HeapCode4 == WINSOCK_HEAP_CODE_4 );
    WS_ASSERT( trailerImage.HeapCode5 == WINSOCK_HEAP_CODE_5 );
    WS_ASSERT( trailerImage.Header == block );

    RemoveEntryList( &block->GlobalHeapListEntry );

    SockTotalFrees++;
    SockTotalBytesAllocated -= block->Size;

    //
    // Poison both guards before handing the memory back.
    //

    RtlZeroMemory( block, sizeof(*block) );
    block->HeapCode1 = (ULONG)~WINSOCK_HEAP_CODE_1;
    block->HeapCode2 = (ULONG)~WINSOCK_HEAP_CODE_2;

    trailerImage.Header = NULL;
    trailerImage.HeapCode3 = (ULONG)~WINSOCK_HEAP_CODE_3;
    trailerImage.HeapCode4 = (ULONG)~WINSOCK_HEAP_CODE_4;
    trailerImage.HeapCode5 = (ULONG)~WINSOCK_HEAP_CODE_5;

    SockCopyMemory( trailer, &trailerImage, sizeof(trailerImage) );

    RtlReleaseResource( &SocketHeapLock );

    RtlFreeHeap( SockPrivateHeap, 0, (PVOID)block );

    if ( SockDoubleHeapCheck ) {
        SockCheckHeap();
    }

} // SockFreeHeap
VOID
SockCheckHeap (
VOID
)
{
PLIST_ENTRY listEntry;
PLIST_ENTRY lastListEntry = NULL;
PSOCK_HEAP_HEADER header;
SOCK_HEAP_TAIL UNALIGNED *tail;
SOCK_HEAP_TAIL localTail;
if ( !SockHeapDebugInitialized ) {
SockInitializeDebugData( );
SockHeapDebugInitialized = TRUE;
//SockHeap = RtlCreateHeap( HEAP_GROWABLE, 0, 0, 0, 0, NULL );
//WS_ASSERT( SockHeap != NULL );
}
if ( !SockDoHeapCheck ) {
return;
}
RtlValidateHeap( SockPrivateHeap, 0, NULL );
RtlAcquireResourceExclusive( &SocketHeapLock, TRUE );
for ( listEntry = SockHeapListHead.Flink;
listEntry != &SockHeapListHead;
listEntry = listEntry->Flink ) {
if ( listEntry == NULL ) {
DbgPrint( "listEntry == NULL, lastListEntry == %lx\n", lastListEntry );
DbgBreakPoint( );
}
header = CONTAINING_RECORD( listEntry, SOCK_HEAP_HEADER, GlobalHeapListEntry );
tail = (SOCK_HEAP_TAIL UNALIGNED *)( (PCHAR)(header + 1) + header->Size );
SockCopyMemory(
&localTail,
tail,
sizeof(localTail)
);
if ( header->HeapCode1 != WINSOCK_HEAP_CODE_1 ) {
DbgPrint( "SockCheckHeap, fail 1, header %lx tail %lx\n", header, tail );
DbgBreakPoint( );
}
if ( header->HeapCode2 != WINSOCK_HEAP_CODE_2 ) {
DbgPrint( "SockCheckHeap, fail 2, header %lx tail %lx\n", header, tail );
DbgBreakPoint( );
}
if ( localTail.HeapCode3 != WINSOCK_HEAP_CODE_3 ) {
DbgPrint( "SockCheckHeap, fail 3, header %lx tail %lx\n", header, tail );
DbgBreakPoint( );
}
if ( localTail.HeapCode4 != WINSOCK_HEAP_CODE_4 ) {
DbgPrint( "SockCheckHeap, fail 4, header %lx tail %lx\n", header, tail );
DbgBreakPoint( );
}
if ( localTail.HeapCode5 != WINSOCK_HEAP_CODE_5 ) {
DbgPrint( "SockCheckHeap, fail 5, header %lx tail %lx\n", header, tail );
DbgBreakPoint( );
}
if ( localTail.Header != header ) {
DbgPrint( "SockCheckHeap, fail 6, header %lx tail %lx\n", header, tail );
DbgBreakPoint( );
}
lastListEntry = listEntry;
}
RtlGetCallersAddress( &SockCaller1, &SockCaller2 );
RtlReleaseResource( &SocketHeapLock );
} // SockCheckHeap
LONG
SockExceptionFilter(
    LPEXCEPTION_POINTERS ExceptionPointers,
    LPSTR SourceFile,
    LONG LineNumber
    )

/*++

Routine Description:

    Debug-build exception filter. Reports the exception code and
    address, plus the file name (without directory) and line of the
    handler that caught it. If no debug output sink is enabled,
    debugger output is switched on first so the report is visible. Any
    fault raised while reporting is swallowed.

Arguments:

    ExceptionPointers - The exception information being filtered.

    SourceFile - Path of the source file containing the handler.

    LineNumber - Line of the handler in SourceFile.

Return Value:

    LONG - Always EXCEPTION_EXECUTE_HANDLER.

--*/

{
    LPSTR baseName;

    //
    // Guard the reporting itself: the process may already be badly
    // damaged.
    //

    try {

        //
        // An exception here indicates a real problem, so make sure the
        // message goes somewhere.
        //

        if ( !( WsDebug & ( WINSOCK_DEBUG_CONSOLE |
                            WINSOCK_DEBUG_FILE |
                            WINSOCK_DEBUG_DEBUGGER ) ) ) {
            WsDebug |= WINSOCK_DEBUG_DEBUGGER;
        }

        //
        // Keep only the part after the last backslash.
        //

        baseName = strrchr( SourceFile, '\\' );
        baseName = ( baseName != NULL ) ? baseName + 1 : SourceFile;

        WS_PRINT((
            "SockExceptionFilter: exception %08lx @ %08lx, caught in %s:%d\n",
            ExceptionPointers->ExceptionRecord->ExceptionCode,
            ExceptionPointers->ExceptionRecord->ExceptionAddress,
            baseName,
            LineNumber
            ));

    } except( EXCEPTION_EXECUTE_HANDLER ) {

        //
        // Nothing useful can be done if reporting itself faults.
        //

        NOTHING;
    }

    return EXCEPTION_EXECUTE_HANDLER;

} // SockExceptionFilter
#endif // if DBG
#if DEBUG_LOCKS
//
// Critical section debugging code.
//
VOID
SockInitializeCriticalSection(
    OUT PSOCK_CRITICAL_SECTION Lock
    )

/*++

Routine Description:

    Prepares a debug-lock wrapper for use. The entire wrapper (counters
    and history slots included) is cleared, and only then is the
    embedded Win32 critical section initialized.

Arguments:

    Lock - The wrapper to initialize.

Return Value:

    None.

--*/

{
    RtlZeroMemory( Lock, sizeof(*Lock) );

    InitializeCriticalSection( &Lock->Lock );

} // SockInitializeCriticalSection
VOID
SockDeleteCriticalSection(
OUT PSOCK_CRITICAL_SECTION Lock
)
{
DeleteCriticalSection( &Lock->Lock );
} // SockDeleteCriticalSection
VOID
SockpEnterCriticalSection(
    IN PSOCK_CRITICAL_SECTION Lock,
    IN PSTR FileName,
    IN LONG LineNumber,
    IN PVOID Caller,
    IN PVOID CallersCaller
    )

/*++

Routine Description:

    Acquires a debug-lock wrapper's critical section and then, while
    holding it, bumps AcquireCount. It also records the acquisition
    site in the next slot of the wrapper's circular DebugInfo history.

Arguments:

    Lock - The wrapper to acquire.

    FileName, LineNumber - Source location of the acquisition.

    Caller, CallersCaller - Return addresses recorded with the entry.

Return Value:

    None.

--*/

{
    LONG index;
    PSOCK_CRITICAL_SECTION_DEBUG_INFO slot;

    WS_ASSERT( !SockProcessTerminating );

    EnterCriticalSection( &Lock->Lock );

    Lock->AcquireCount++;

    //
    // DebugSlot is a running counter. Reducing it as unsigned keeps the
    // index inside DebugInfo[] even after the counter wraps. A signed
    // remainder of a wrapped (negative) counter would index before the
    // array. NOTE(review): DebugSlot's declared type is not visible here.
    // The cast is harmless if it is already unsigned.
    //

    index = (LONG)( (ULONG)( Lock->DebugSlot++ ) % MAX_CRITICAL_SECTION_DEBUG );

    slot = &Lock->DebugInfo[index];
    slot->FileName = FileName;
    slot->LineNumber = LineNumber;
    slot->Caller = Caller;
    slot->CallersCaller = CallersCaller;

} // SockpEnterCriticalSection
VOID
SockpLeaveCriticalSection(
    IN PSOCK_CRITICAL_SECTION Lock,
    IN PSTR FileName,
    IN LONG LineNumber,
    IN PVOID Caller,
    IN PVOID CallersCaller
    )

/*++

Routine Description:

    Records a release in a debug-lock wrapper and then leaves its
    critical section. Bookkeeping happens while the lock is still held.
    The recorded line number has its top bit set, which distinguishes
    releases from acquisitions in the DebugInfo history.

Arguments:

    Lock - The wrapper to release.

    FileName, LineNumber - Source location of the release.

    Caller, CallersCaller - Return addresses recorded with the entry.

Return Value:

    None.

--*/

{
    LONG index;
    PSOCK_CRITICAL_SECTION_DEBUG_INFO slot;

    WS_ASSERT( !SockProcessTerminating );

    Lock->ReleaseCount++;

    //
    // Unsigned reduction keeps the index in range after DebugSlot wraps.
    // NOTE(review): DebugSlot's declared type is not visible here.
    //

    index = (LONG)( (ULONG)( Lock->DebugSlot++ ) % MAX_CRITICAL_SECTION_DEBUG );

    slot = &Lock->DebugInfo[index];
    slot->FileName = FileName;
    slot->LineNumber = LineNumber | 0x80000000;
    slot->Caller = Caller;
    slot->CallersCaller = CallersCaller;

    LeaveCriticalSection( &Lock->Lock );

} // SockpLeaveCriticalSection
VOID
SockAcquireGlobalLockHelper(
    IN PSTR FileName,
    IN LONG LineNumber
    )

/*++

Routine Description:

    Acquires the global SocketLock through the debug-lock wrapper,
    tagging the history entry with the source location and the two
    return addresses reported by RtlGetCallersAddress().

Arguments:

    FileName, LineNumber - Source location of the acquisition.

Return Value:

    None.

--*/

{
    PVOID returnAddress;
    PVOID outerReturnAddress;

    //
    // Must be called from this frame so the addresses describe whoever
    // called this helper.
    //

    RtlGetCallersAddress( &returnAddress, &outerReturnAddress );

    SockpEnterCriticalSection( &SocketLock,
                               FileName,
                               LineNumber,
                               returnAddress,
                               outerReturnAddress );

} // SockAcquireGlobalLockHelper
VOID
SockReleaseGlobalLockHelper(
    IN PSTR FileName,
    IN LONG LineNumber
    )

/*++

Routine Description:

    Releases the global SocketLock through the debug-lock wrapper,
    tagging the history entry with the source location and the two
    return addresses reported by RtlGetCallersAddress().

Arguments:

    FileName, LineNumber - Source location of the release.

Return Value:

    None.

--*/

{
    PVOID returnAddress;
    PVOID outerReturnAddress;

    //
    // Must be called from this frame so the addresses describe whoever
    // called this helper.
    //

    RtlGetCallersAddress( &returnAddress, &outerReturnAddress );

    SockpLeaveCriticalSection( &SocketLock,
                               FileName,
                               LineNumber,
                               returnAddress,
                               outerReturnAddress );

} // SockReleaseGlobalLockHelper
#endif // DEBUG_LOCKS
VOID
WINAPI
SockIoCompletion (
    PVOID ApcContext,
    PIO_STATUS_BLOCK IoStatusBlock,
    DWORD Reserved
    )

/*++

Routine Description:

    APC routine that completes asynchronous WSARecv, WSARecvFrom,
    WSASend, WSASendTo, and WSAIoctl requests issued with a user
    completion routine. It translates the NT completion status into
    WinSock terms and calls the user's routine as:

        CompletionRoutine( dwError, cbTransferred, lpOverlapped, dwFlags );

    where:

        dwError - 0 unless the NT status is an error, in which case it
            is SockNtStatusToSocketError( status ).

        cbTransferred - IoStatusBlock->Information when the status is
            not an error; 0 otherwise.

        lpOverlapped - The WSAOVERLAPPED whose Internal field is
            IoStatusBlock.

        dwFlags - MSG_PARTIAL for STATUS_BUFFER_OVERFLOW and
            STATUS_RECEIVE_PARTIAL, MSG_OOB for STATUS_RECEIVE_EXPEDITED,
            MSG_PARTIAL | MSG_OOB for STATUS_RECEIVE_PARTIAL_EXPEDITED,
            and 0 for everything else.

    User completion routines run only while the issuing thread is in an
    alertable wait. Each nested invocation consumes additional stack.

Arguments:

    ApcContext - The user's LPWSAOVERLAPPED_COMPLETION_ROUTINE.

    IoStatusBlock - The completed I/O status, embedded in the caller's
        WSAOVERLAPPED.

    Reserved - Unused.

Return Value:

    None.

--*/

{
    LPWSAOVERLAPPED_COMPLETION_ROUTINE userRoutine;
    LPWSAOVERLAPPED overlapped;
    DWORD errorCode = 0;
    DWORD bytesTransferred = 0;
    DWORD receiveFlags = 0;

    UNREFERENCED_PARAMETER( Reserved );

    if ( !NT_ERROR(IoStatusBlock->Status) ) {

        bytesTransferred = IoStatusBlock->Information;

        //
        // Success and warning statuses may carry receive-flag meaning.
        //

        switch ( IoStatusBlock->Status ) {

        case STATUS_BUFFER_OVERFLOW:
        case STATUS_RECEIVE_PARTIAL:
            receiveFlags = MSG_PARTIAL;
            break;

        case STATUS_RECEIVE_EXPEDITED:
            receiveFlags = MSG_OOB;
            break;

        case STATUS_RECEIVE_PARTIAL_EXPEDITED:
            receiveFlags = MSG_PARTIAL | MSG_OOB;
            break;

        default:
            break;
        }

    } else {

        errorCode = SockNtStatusToSocketError( IoStatusBlock->Status );
    }

    userRoutine = (LPWSAOVERLAPPED_COMPLETION_ROUTINE)ApcContext;

    overlapped = (LPWSAOVERLAPPED)CONTAINING_RECORD(
                     IoStatusBlock,
                     WSAOVERLAPPED,
                     Internal
                     );

    userRoutine( errorCode, bytesTransferred, overlapped, receiveFlags );

} // SockIoCompletion
BOOL
SockThreadInitialize(
    VOID
    )

/*++

Routine Description:

    Per-thread setup. The routine proceeds in these steps:

      - creates a notification event (initially non-signaled) for
        synchronous I/O;
      - allocates the thread's WINSOCK_TLS_DATA and stores it with
        SET_THREAD_DATA;
      - zeroes the data;
      - in DBG builds, resets the trace indent level;
      - sets SockThreadSocketHandle to INVALID_SOCKET and records the
        event in SockThreadEvent.

    NOTE(review): SockIndentLevel, SockThreadSocketHandle and
    SockThreadEvent are presumably accessors for the thread data; confirm
    in winsockp.h.

Arguments:

    None.

Return Value:

    BOOL - TRUE on success, FALSE on failure. The event is now closed on
        every failure path after it has been created. It previously
        leaked.

--*/

{
    PWINSOCK_TLS_DATA data;
    HANDLE threadEvent;
    NTSTATUS status;

    IF_DEBUG(INIT) {
        WS_PRINT(( "SockThreadInitialize: TEB = %lx\n",
                   NtCurrentTeb( ) ));
    }

    //
    // Create the thread's event.
    //

    status = NtCreateEvent(
                 &threadEvent,
                 EVENT_ALL_ACCESS,
                 NULL,
                 NotificationEvent,
                 FALSE
                 );

    if ( !NT_SUCCESS(status) ) {
        WS_PRINT(( "SockThreadInitialize: NtCreateEvent failed: %X\n", status ));
        return FALSE;
    }

    //
    // Allocate space for per-thread data the DLL will have.
    //

    data = ALLOCATE_THREAD_DATA( sizeof(*data) );

    if ( data == NULL ) {
        WS_PRINT(( "SockThreadInitialize: unable to allocate thread data.\n" ));
        NtClose( threadEvent );
        return FALSE;
    }

    //
    // Store a pointer to this data area in TLS.
    //

    if( !SET_THREAD_DATA(data) ) {
        WS_PRINT(( "SockThreadInitialize: TlsSetValue failed: %ld\n", GetLastError( ) ));
#if !defined(USE_TEB_FIELD)
        SockTlsSlot = 0xFFFFFFFF;
#endif // !USE_TEB_FIELD

        //
        // NOTE(review): `data` is still leaked here. Release it with the
        // counterpart of ALLOCATE_THREAD_DATA, which is not visible in
        // this part of the file.
        //

        NtClose( threadEvent );
        return FALSE;
    }

    //
    // Initialize the thread data.
    //

    RtlZeroMemory( data, sizeof(*data) );

#if DBG
    SockIndentLevel = 0;
#endif

    SockThreadSocketHandle = INVALID_SOCKET;
    SockThreadEvent = threadEvent;

    return TRUE;

} // SockThreadInitialize
BOOLEAN
SockIsAddressConsistentWithConstrainedGroup(
    IN PSOCKET_INFORMATION Socket,
    IN GROUP Group,
    IN PSOCKADDR SocketAddress,
    IN INT SocketAddressLength
    )

/*++

Routine Description:

    Searches all open sockets, validating that the specified address is
    consistent with all sockets associated with the specified constrained
    group identifier. The check is performed by AFD through
    IOCTL_AFD_VALIDATE_GROUP. If the IOCTL pends, this routine waits for
    it without calling the blocking hook.

Arguments:

    Socket - An open socket. Used just to get us "into" AFD, where the
        real work is done.

    Group - The constrained group identifier.

    SocketAddress - The socket address to check.

    SocketAddressLength - The length of SocketAddress.

Return Value:

    BOOLEAN - TRUE if the address is consistent, FALSE otherwise. FALSE
        is also returned if the request buffer cannot be allocated or the
        IOCTL fails.

--*/

{
    NTSTATUS status;
    IO_STATUS_BLOCK ioStatusBlock;
    PAFD_VALIDATE_GROUP_INFO validateInfo;
    ULONG validateInfoLength;

    //
    // Stack buffer for the common case. The union gives it the alignment
    // of AFD_VALIDATE_GROUP_INFO. A bare UCHAR array is only
    // byte-aligned, which can fault on strict-alignment processors when
    // it is accessed through the structure pointer.
    //

    union {
        AFD_VALIDATE_GROUP_INFO Info;
        UCHAR Bytes[sizeof(AFD_VALIDATE_GROUP_INFO) + MAX_FAST_TDI_ADDRESS];
    } validateInfoBuffer;

    WS_ASSERT( Socket != NULL );
    WS_ASSERT( Group != 0 );
    WS_ASSERT( Group != SG_UNCONSTRAINED_GROUP );
    WS_ASSERT( Group != SG_CONSTRAINED_GROUP );
    WS_ASSERT( SocketAddress != NULL );
    WS_ASSERT( SocketAddressLength > 0 );

    //
    // Size the request for the largest TDI address this transport can
    // produce. If it fits, use the stack buffer to avoid a heap
    // allocation; otherwise allocate one.
    //

    validateInfo = &validateInfoBuffer.Info;
    validateInfoLength = sizeof(AFD_VALIDATE_GROUP_INFO) -
                             sizeof(TRANSPORT_ADDRESS) +
                             Socket->HelperDll->MaxTdiAddressLength;

    if( validateInfoLength > sizeof(validateInfoBuffer) ) {

        validateInfo = ALLOCATE_HEAP( validateInfoLength );

        if( validateInfo == NULL ) {
            return FALSE;
        }
    }

    //
    // Convert the address from the sockaddr structure to the appropriate
    // TDI structure. NOTE(review): this assumes the caller has already
    // bounded SocketAddressLength so that the converted address fits in
    // validateInfoLength.
    //

    SockBuildTdiAddress(
        &validateInfo->RemoteAddress,
        SocketAddress,
        SocketAddressLength
        );

    //
    // Let AFD do the dirty work.
    //

    validateInfo->GroupID = (LONG)Group;

    status = NtDeviceIoControlFile(
                 (HANDLE)Socket->Handle,
                 SockThreadEvent,
                 NULL,                   // APC Routine
                 NULL,                   // APC Context
                 &ioStatusBlock,
                 IOCTL_AFD_VALIDATE_GROUP,
                 validateInfo,
                 validateInfoLength,
                 NULL,
                 0
                 );

    if( status == STATUS_PENDING ) {

        SockWaitForSingleObject(
            SockThreadEvent,
            Socket->Handle,
            SOCK_NEVER_CALL_BLOCKING_HOOK,
            SOCK_NO_TIMEOUT
            );

        status = ioStatusBlock.Status;
    }

    if( validateInfo != &validateInfoBuffer.Info ) {
        FREE_HEAP( validateInfo );
    }

    if( !NT_SUCCESS(status) ) {
        return FALSE;
    }

    //
    // Success!
    //

    return TRUE;

} // SockIsAddressConsistentWithConstrainedGroup
VOID
SockCancelIo(
    IN SOCKET Socket
    )

/*++

Routine Description:

    Cancels all IO on the specific socket initiated by the current thread.
    Failures are only reported via WS_PRINT.

Arguments:

    Socket - The socket to cancel.

Return Value:

    None.

--*/

{
    IO_STATUS_BLOCK ioStatusBlock;
    NTSTATUS status;

    WS_ASSERT( Socket != (SOCKET)NULL );
    WS_ASSERT( Socket != INVALID_SOCKET );

    //
    // Pre-set the I/O status. If NtCancelIoFile() itself fails it may
    // not fill in the block, yet the diagnostic below prints
    // ioStatusBlock.Status. Previously that read an uninitialized value.
    //

    ioStatusBlock.Status = STATUS_SUCCESS;
    ioStatusBlock.Information = 0;

    status = NtCancelIoFile(
                 (HANDLE)Socket,
                 &ioStatusBlock
                 );

    WS_ASSERT( status != STATUS_PENDING );

    if( !NT_SUCCESS(status) || !NT_SUCCESS(ioStatusBlock.Status) ) {

        WS_PRINT((
            "SockCancelIo: NtCancelIoFile() failed, %08lx:%08lx\n",
            status,
            ioStatusBlock.Status
            ));
    }

} // SockCancelIo
VOID
SockBuildProtocolInfoForSocket(
    IN PSOCKET_INFORMATION Socket,
    OUT LPWSAPROTOCOL_INFOW ProtocolInfo
    )

/*++

Routine Description:

    Fills in a WSAPROTOCOL_INFOW describing Socket. The structure is
    zeroed first, so every field not assigned below reads as zero.

    Values copied from the socket: catalog entry ID, address family,
    socket type, protocol, service flags, and provider flags. Values
    taken from the socket's helper DLL: min/max sockaddr lengths.
    Constants: version 2, big-endian network byte order, no security
    scheme.

    TODO: ProviderId, dwMessageSize, szProtocol and dwProviderReserved
    are not populated yet (left zero); their source has not been
    determined.

Arguments:

    Socket - The socket to describe.

    ProtocolInfo - Receives the description.

Return Value:

    None.

--*/

{
    WS_ASSERT( Socket != NULL );
    WS_ASSERT( ProtocolInfo != NULL );

    RtlZeroMemory( ProtocolInfo, sizeof(*ProtocolInfo) );

    //
    // Provider identity and the socket's family/type/protocol triple.
    //

    ProtocolInfo->dwCatalogEntryId = Socket->CatalogEntryId;
    ProtocolInfo->iAddressFamily = Socket->AddressFamily;
    ProtocolInfo->iSocketType = Socket->SocketType;
    ProtocolInfo->iProtocol = Socket->Protocol;

    //
    // Capability flags recorded on the socket.
    //

    ProtocolInfo->dwServiceFlags1 = Socket->ServiceFlags1;
    ProtocolInfo->dwProviderFlags = Socket->ProviderFlags;

    //
    // Address size limits come from the helper DLL.
    //

    ProtocolInfo->iMinSockAddr = Socket->HelperDll->MinSockaddrLength;
    ProtocolInfo->iMaxSockAddr = Socket->HelperDll->MaxSockaddrLength;

    //
    // Fixed values.
    //

    ProtocolInfo->iVersion = 2;
    ProtocolInfo->iNetworkByteOrder = BIGENDIAN;
    ProtocolInfo->iSecurityScheme = SECURITY_PROTOCOL_NONE;

} // SockBuildProtocolInfoForSocket