|
|
//+---------------------------------------------------------------------------
//
// Microsoft Windows
// Copyright (C) Microsoft Corporation, 1992 - 1997.
//
// File: lpcsvr.c
//
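// Contents:   Thread-pool driven LPC server: creates a waitable port,
//             dispatches received messages to caller-supplied callbacks,
//             and tears the server down on request.
//
// Classes:
//
// Functions:  RtlCreateLpcServer, RtlShutdownLpcServer,
//             RtlImpersonateLpcClient, RtlCallbackLpcClient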
//
// History: 12-12-97 RichardW Created
//
//----------------------------------------------------------------------------
#include <ntos.h>
#include <nt.h>
#include <ntrtl.h>
#include <nturtl.h>
#include "lpcsvr.h"
//
// Server lock helpers.  The argument is parenthesized and no trailing
// semicolon is baked into the expansion, so a use such as
// "RtlpLpcLockServer( Server );" behaves like a single ordinary statement
// (including inside an unbraced if/else).
//
#define RtlpLpcLockServer( s )   RtlEnterCriticalSection( &(s)->Lock )
#define RtlpLpcUnlockServer( s ) RtlLeaveCriticalSection( &(s)->Lock )
//
// Map the PrivateContext pointer handed to the server's callbacks back to
// the LPCSVR_CONTEXT record that contains it.
//
#define RtlpLpcContextFromClient( p ) ( CONTAINING_RECORD( (p), LPCSVR_CONTEXT, PrivateContext ) )
//+---------------------------------------------------------------------------
//
//  Function:   RtlpLpcDerefContext
//
//  Synopsis:   Drop one reference on a client context and dispose of the
//              message that carried it.
//
//              The context is destroyed (unlinked from the server's context
//              list if still linked, comm port closed, memory freed) once its
//              RefCount drops below zero.
//
//              The message, if any, is returned to the server's message pool
//              while the pool is below MessagePoolLimit.  Otherwise it is
//              freed.  If the context was already unlinked by
//              RtlShutdownLpcServer (List.Flink == NULL), or a shutdown is
//              pending, the message is freed directly rather than being
//              released to the pool (which shutdown drains).
//
//  Arguments:  [Context] -- context to dereference
//              [Message] -- message to recycle; may be NULL
//
//  History:    2-06-98   RichardW   Created
//
//  Notes:      After the decrement, Context is only touched by the caller
//              that drove the count below zero.  Another thread may free it
//              as soon as our reference is gone.
//
//----------------------------------------------------------------------------
VOID
RtlpLpcDerefContext(
    PLPCSVR_CONTEXT Context,
    PLPCSVR_MESSAGE Message
    )
{
    PLPCSVR_SERVER Server = Context->Server ;
    BOOLEAN Orphaned = FALSE ;

    if ( InterlockedDecrement( &Context->RefCount ) < 0 )
    {
        //
        // All gone, time to clean up.
        //
        RtlpLpcLockServer( Server );

        if ( Context->List.Flink )
        {
            RemoveEntryList( &Context->List );
            Server->ContextCount -- ;
        }
        else
        {
            //
            // Already unlinked by a server-initiated shutdown.
            //
            Orphaned = TRUE ;
        }

        RtlpLpcUnlockServer( Server );

        if ( Context->CommPort )
        {
            NtClose( Context->CommPort );
        }

        RtlFreeHeap( RtlProcessHeap(), 0, Context );
    }

    //
    // Shutdown dereferences contexts without a message.
    //
    if ( Message == NULL )
    {
        return ;
    }

    RtlpLpcLockServer( Server );

    if ( !Orphaned &&
         ( Server->Flags & LPCSVR_SHUTDOWN_PENDING ) == 0 &&
         Server->MessagePoolSize + 1 < Server->MessagePoolLimit )
    {
        Message->Header.Next = Server->MessagePool ;
        Server->MessagePool = Message ;
        Server->MessagePoolSize++ ;
        Message = NULL ;
    }

    RtlpLpcUnlockServer( Server );

    //
    // Pool full, shutting down, or context orphaned: free it outright.
    //
    if ( Message )
    {
        RtlFreeHeap( RtlProcessHeap(), 0, Message );
    }
}
//+---------------------------------------------------------------------------
//
//  Function:   RtlpLpcWorkerThread
//
//  Synopsis:   Process one received LPC message.
//
//              Dispatches it by type to the server's Request, Connect or
//              Rundown callback, then releases the message and the
//              reference it holds via RtlpLpcDerefContext.
//
//  Arguments:  [Parameter] -- the PLPCSVR_MESSAGE to process.  Its
//                             Header.Context names the owning context.
//
//  History:    2-06-98   RichardW   Created
//
//  Notes:      Called either inline on the receiving thread or as a queued
//              work item (see RtlpLpcServerCallback).
//
//----------------------------------------------------------------------------
VOID
RtlpLpcWorkerThread(
    PVOID Parameter
    )
{
    PLPCSVR_MESSAGE Message = (PLPCSVR_MESSAGE) Parameter ;
    PLPCSVR_CONTEXT Context = Message->Header.Context ;
    NTSTATUS Status ;
    BOOLEAN Accept = FALSE ;    // reject unless ConnectFn says otherwise

    switch ( Message->Message.u2.s2.Type & 0xF )
    {
        case LPC_REQUEST:
        case LPC_DATAGRAM:
            DbgPrint("Calling Server's Request function\n");
            Status = Context->Server->Init.RequestFn( &Context->PrivateContext, &Message->Message, &Message->Message );

            if ( NT_SUCCESS( Status ) )
            {
                //
                // A failed reply is not fatal here; the client may have
                // gone away already.
                //
                (VOID) NtReplyPort( Context->CommPort, &Message->Message );
            }
            break;

        case LPC_CONNECTION_REQUEST:
            DbgPrint("Calling Server's Connect function\n");
            Status = Context->Server->Init.ConnectFn( &Context->PrivateContext, &Message->Message, &Accept );

            if ( !NT_SUCCESS( Status ) )
            {
                //
                // The server refused outright: reject the connection and
                // let the final deref below destroy the context.
                //
                (VOID) NtAcceptConnectPort( &Context->CommPort, NULL, &Message->Message, FALSE, NULL, NULL );
                Context->RefCount = 0 ;
                break;
            }

            //
            // If the comm port is still null, then do the accept.  Otherwise,
            // the server called RtlAcceptConnectPort() explicitly, to set up
            // a view.
            //
            if ( Context->CommPort == NULL )
            {
                Status = NtAcceptConnectPort( &Context->CommPort, Context, &Message->Message, Accept, NULL, NULL );

                if ( Accept && NT_SUCCESS( Status ) )
                {
                    Status = NtCompleteConnectPort( Context->CommPort );
                }
                else
                {
                    //
                    // Rejected, or the accept itself failed: the context is
                    // worthless, so make the deref below destroy it.
                    //
                    Context->RefCount = 0 ;
                }
            }
            break;

        case LPC_CLIENT_DIED:
            DbgPrint( "Calling Server's Rundown function\n" );
            Status = Context->Server->Init.RundownFn( &Context->PrivateContext, &Message->Message );

            //
            // Drop the connection's own reference.  Together with the deref
            // below, this normally drives the count below zero and frees
            // the context.
            //
            InterlockedDecrement( &Context->RefCount );
            break;

        default:
            //
            // An unexpected message came through.  Normal LPC servers
            // don't handle the other types of messages.  Drop it.
            //
            break;
    }

    RtlpLpcDerefContext( Context, Message );
}
VOID RtlpLpcServerCallback( PVOID Parameter, BOOLEAN TimedOut ) { PLPCSVR_SERVER Server ; NTSTATUS Status ; PLPCSVR_MESSAGE Message ; PLPCSVR_CONTEXT Context ; PLARGE_INTEGER RealTimeout ; LPCSVR_FILTER_RESULT FilterResult ;
Server = (PLPCSVR_SERVER) Parameter ;
if ( Server->WaitHandle ) { Server->WaitHandle = NULL ; }
while ( 1 ) { DbgPrint("Entering LPC server\n" );
RtlpLpcLockServer( Server );
if ( Server->Flags & LPCSVR_SHUTDOWN_PENDING ) { break; }
if ( Server->MessagePool ) { Message = Server->MessagePool ; Server->MessagePool = Message->Header.Next ; } else { Message = RtlAllocateHeap( RtlProcessHeap(), 0, Server->MessageSize );
}
RtlpLpcUnlockServer( Server );
if ( !Message ) { LARGE_INTEGER SleepInterval ;
SleepInterval.QuadPart = 125 * 10000 ;
NtDelayExecution( FALSE, &SleepInterval ); continue; }
if ( Server->Timeout.QuadPart ) { RealTimeout = &Server->Timeout ; } else { RealTimeout = NULL ; }
Status = NtReplyWaitReceivePortEx( Server->Port, &Context, NULL, &Message->Message, RealTimeout );
DbgPrint("Server: NtReplyWaitReceivePort completed with %x\n", Status );
if ( NT_SUCCESS( Status ) ) { //
// If we timed out, nobody was waiting for us:
//
if ( Status == STATUS_TIMEOUT ) { //
// Set up a general wait that will call back to this function
// when ready.
//
RtlpLpcLockServer( Server );
if ( ( Server->Flags & LPCSVR_SHUTDOWN_PENDING ) == 0 ) {
Status = RtlRegisterWait( &Server->WaitHandle, Server->Port, RtlpLpcServerCallback, Server, 0xFFFFFFFF, WT_EXECUTEONLYONCE ); }
RtlpLpcUnlockServer( Server );
break;
}
if ( Status == STATUS_SUCCESS ) { if ( Context ) { InterlockedIncrement( &Context->RefCount ); } else { //
// New connection. Create a new context record
//
Context = RtlAllocateHeap( RtlProcessHeap(), 0, sizeof( LPCSVR_CONTEXT ) + Server->Init.ContextSize );
if ( !Context ) { HANDLE Bogus ;
Status = NtAcceptConnectPort( &Bogus, NULL, &Message->Message, FALSE, NULL, NULL );
RtlpLpcLockServer( Server );
Message->Header.Next = Server->MessagePool ; Server->MessagePool = Message ;
RtlpLpcUnlockServer( Server );
continue; }
Context->Server = Server ; Context->RefCount = 1 ; Context->CommPort = NULL ;
RtlpLpcLockServer( Server );
InsertTailList( &Server->ContextList, &Context->List ); Server->ContextCount++ ;
RtlpLpcUnlockServer( Server ); }
Message->Header.Context = Context ;
FilterResult = LpcFilterAsync ;
if ( Server->Init.FilterFn ) { FilterResult = Server->Init.FilterFn( Context, &Message->Message );
if (FilterResult == LpcFilterDrop ) { RtlpLpcDerefContext( Context, Message );
continue;
} }
if ( (Server->Flags & LPCSVR_SYNCHRONOUS) || (FilterResult == LpcFilterSync) ) { RtlpLpcWorkerThread( Message ); } else { RtlQueueWorkItem( RtlpLpcWorkerThread, Message, 0 );
} } } else { //
// Error? Better shut down...
//
break; }
}
}
//
// RtlCreateLpcServer
//
// Allocate and initialize a server object, then:
//   - create a waitable LPC port named by PortName, and
//   - register RtlpLpcServerCallback to run when the port is signalled.
//
// Init is copied into the server.  IdleTimeout may be NULL, in which case
// the receive loop waits without a timeout.  Options become the server's
// Flags.
//
// On success the server is returned through *LpcServer.  On failure,
// *LpcServer is NULL and everything allocated here has been released.
//
NTSTATUS
RtlCreateLpcServer(
    POBJECT_ATTRIBUTES PortName,
    PLPCSVR_INITIALIZE Init,
    PLARGE_INTEGER IdleTimeout,
    ULONG MessageSize,
    ULONG Options,
    PVOID * LpcServer
    )
{
    PLPCSVR_SERVER Server ;
    NTSTATUS Status ;

    *LpcServer = NULL ;

    Server = RtlAllocateHeap( RtlProcessHeap(), 0, sizeof( LPCSVR_SERVER ) );

    if ( !Server )
    {
        return STATUS_INSUFFICIENT_RESOURCES ;
    }

    Status = RtlInitializeCriticalSectionAndSpinCount( &Server->Lock, 1000 );

    if ( !NT_SUCCESS( Status ) )
    {
        RtlFreeHeap( RtlProcessHeap(), 0, Server );
        return Status ;
    }

    InitializeListHead( &Server->ContextList );
    Server->ContextCount = 0 ;

    Server->Init = *Init ;

    if ( IdleTimeout )
    {
        Server->Timeout = *IdleTimeout ;
    }
    else
    {
        Server->Timeout.QuadPart = 0 ;
    }

    //
    // Each pooled buffer is an LPCSVR_MESSAGE whose embedded PORT_MESSAGE is
    // grown to MessageSize bytes.
    //
    Server->MessageSize = MessageSize + sizeof( LPCSVR_MESSAGE ) - sizeof( PORT_MESSAGE );

    Server->MessagePool = 0 ;
    Server->MessagePoolSize = 0 ;
    Server->MessagePoolLimit = 4 ;

    Server->Flags = Options ;

    //
    // RtlAllocateHeap was called without zeroing, so explicitly initialize
    // the fields RtlShutdownLpcServer inspects.
    //
    Server->WaitHandle = NULL ;
    Server->ShutdownEvent = NULL ;
    Server->ReceiveThreads = 0 ;

    //
    // Create the LPC port:
    //
    Status = NtCreateWaitablePort( &Server->Port, PortName, MessageSize, MessageSize, MessageSize * 4 );

    if ( !NT_SUCCESS( Status ) )
    {
        RtlDeleteCriticalSection( &Server->Lock );
        RtlFreeHeap( RtlProcessHeap(), 0, Server );
        return Status ;
    }

    //
    // Now, post the handle over to a wait queue
    //
    Status = RtlRegisterWait( &Server->WaitHandle, Server->Port, RtlpLpcServerCallback, Server, 0xFFFFFFFF, WT_EXECUTEONLYONCE );

    if ( !NT_SUCCESS( Status ) )
    {
        NtClose( Server->Port );
        RtlDeleteCriticalSection( &Server->Lock );
        RtlFreeHeap( RtlProcessHeap(), 0, Server );
        return Status ;
    }

    *LpcServer = Server ;

    return Status ;
}
//
// RtlShutdownLpcServer
//
// Stop an LPC server created by RtlCreateLpcServer.
//
// Steps:
//   - Deregister the pending port wait.
//   - Mark the server LPCSVR_SHUTDOWN_PENDING.  If ReceiveThreads is
//     non-zero, wait up to the server timeout on a shutdown event.
//   - Run down every remaining context with a NULL message (a
//     server-initiated shutdown).
//   - Free the pooled message buffers.
//
// Returns:
//   STATUS_PENDING          if a shutdown is already in progress.
//   STATUS_NOT_IMPLEMENTED  if the server has no idle timeout.  Nothing is
//                           torn down in that case.
//   An NtCreateEvent failure status, if the shutdown event cannot be
//   created.
//   STATUS_SUCCESS          otherwise.
//
// NOTE(review): the port, lock and server object itself are not released
// here, matching the original behavior.  A hung receive thread may still
// reference them.
//
NTSTATUS
RtlShutdownLpcServer(
    PVOID LpcServer
    )
{
    PLPCSVR_SERVER Server = (PLPCSVR_SERVER) LpcServer ;
    OBJECT_ATTRIBUTES ObjA ;
    PLIST_ENTRY Scan ;
    PLPCSVR_CONTEXT Context ;
    PLPCSVR_MESSAGE Message ;
    NTSTATUS Status ;

    RtlpLpcLockServer( Server );

    if ( Server->Flags & LPCSVR_SHUTDOWN_PENDING )
    {
        RtlpLpcUnlockServer( Server );
        return STATUS_PENDING ;
    }

    //
    // Refuse before touching anything, so an unsupported request does not
    // leave the server with its wait deregistered.
    //
    if ( Server->Timeout.QuadPart == 0 )
    {
        RtlpLpcUnlockServer( Server );
        return STATUS_NOT_IMPLEMENTED ;
    }

    if ( Server->WaitHandle )
    {
        RtlDeregisterWait( Server->WaitHandle );
        Server->WaitHandle = NULL ;
    }

    //
    // If there are receives still pending, we have to sync
    // with those threads.  To do so, we will tag the shutdown
    // flag, and then wait the timeout amount.
    //
    if ( Server->ReceiveThreads != 0 )
    {
        InitializeObjectAttributes( &ObjA, NULL, 0, 0, 0 );

        Status = NtCreateEvent( &Server->ShutdownEvent, EVENT_ALL_ACCESS, &ObjA, NotificationEvent, FALSE );

        if ( !NT_SUCCESS( Status ) )
        {
            RtlpLpcUnlockServer( Server );
            return Status ;
        }

        Server->Flags |= LPCSVR_SHUTDOWN_PENDING ;

        RtlpLpcUnlockServer( Server );

        Status = NtWaitForSingleObject( Server->ShutdownEvent, FALSE, &Server->Timeout );

        if ( Status == STATUS_TIMEOUT )
        {
            //
            // Hmm, the LPC server thread is hung somewhere,
            // press on
            //
        }

        RtlpLpcLockServer( Server );

        NtClose( Server->ShutdownEvent );
        Server->ShutdownEvent = NULL ;
    }
    else
    {
        Server->Flags |= LPCSVR_SHUTDOWN_PENDING ;
    }

    //
    // The server object is locked, and there are no receives pending (or
    // they appear hung).  Skim through the context list, calling the
    // server's rundown code with a NULL message.  The NULL message marks
    // a server-initiated shutdown.  Clearing List.Flink tells
    // RtlpLpcDerefContext the context is no longer linked.
    //
    while ( !IsListEmpty( &Server->ContextList ) )
    {
        Scan = RemoveHeadList( &Server->ContextList );
        Server->ContextCount-- ;

        Context = CONTAINING_RECORD( Scan, LPCSVR_CONTEXT, List );

        //
        // Pass the same PrivateContext address the worker thread passes.
        //
        Status = Server->Init.RundownFn( &Context->PrivateContext, NULL );

        Context->List.Flink = NULL ;

        RtlpLpcDerefContext( Context, NULL );
    }

    //
    // All contexts have been deleted: clean up the pooled messages.
    //
    while ( Server->MessagePool )
    {
        Message = Server->MessagePool ;
        Server->MessagePool = Message->Header.Next ;
        RtlFreeHeap( RtlProcessHeap(), 0, Message );
    }

    Server->MessagePoolSize = 0 ;

    RtlpLpcUnlockServer( Server );

    return STATUS_SUCCESS ;
}
//
// RtlImpersonateLpcClient
//
// Impersonate the client that sent Message.  Context is the PrivateContext
// pointer the server's callbacks receive.  It is mapped back to its
// LPCSVR_CONTEXT to find that client's communication port.  Returns the
// status of NtImpersonateClientOfPort.
//
NTSTATUS
RtlImpersonateLpcClient(
    PVOID Context,
    PPORT_MESSAGE Message
    )
{
    return NtImpersonateClientOfPort( RtlpLpcContextFromClient( Context )->CommPort,
                                      Message );
}
//
// RtlCallbackLpcClient
//
// Send Callback to the client over its communication port and wait for the
// reply, which is written back into Callback.  When Callback is a separate
// buffer from Request, it first inherits Request's ClientId and MessageId.
// Context is the PrivateContext pointer the server's callbacks receive.
// Returns the status of NtRequestWaitReplyPort.
//
NTSTATUS
RtlCallbackLpcClient(
    PVOID Context,
    PPORT_MESSAGE Request,
    PPORT_MESSAGE Callback
    )
{
    PLPCSVR_CONTEXT ServerContext = RtlpLpcContextFromClient( Context );

    if ( Callback != Request )
    {
        Callback->ClientId = Request->ClientId ;
        Callback->MessageId = Request->MessageId ;
    }

    return NtRequestWaitReplyPort( ServerContext->CommPort, Callback, Callback );
}
|