/*++

Copyright (c) 1989  Microsoft Corporation

Module Name:

    blkendp.c

Abstract:

    This module implements routines for managing endpoint blocks.

Author:

    Chuck Lenzmeier (chuckl) 4-Oct-1989

Revision History:

--*/

#include "precomp.h"
#include "blkendp.tmh"
#pragma hdrstop

#define BugCheckFileId SRV_FILE_BLKENDP

//
// When ALLOC_PRAGMA is defined, place the routines listed below in the
// PAGE code section.  Each of these routines invokes PAGED_CODE() on
// entry.
//
#ifdef ALLOC_PRAGMA
#pragma alloc_text( PAGE, SrvAllocateEndpoint )
#pragma alloc_text( PAGE, SrvCheckAndReferenceEndpoint )
#pragma alloc_text( PAGE, SrvCloseEndpoint )
#pragma alloc_text( PAGE, SrvDereferenceEndpoint )
#pragma alloc_text( PAGE, SrvFreeEndpoint )
#pragma alloc_text( PAGE, SrvReferenceEndpoint )
#pragma alloc_text( PAGE, SrvFindNamedEndpoint )
#endif
//
// The text inside the "#if 0" region is never compiled.  It records the
// routines in this file that are deliberately left out of the PAGE list
// above (both of them acquire spin locks).
//
#if 0
NOT PAGEABLE -- EmptyFreeConnectionList
NOT PAGEABLE -- WalkConnectionTable
#endif


VOID
SrvAllocateEndpoint (
    OUT PENDPOINT *Endpoint,
    IN PUNICODE_STRING NetworkName,
    IN PUNICODE_STRING TransportName,
    IN PANSI_STRING TransportAddress,
    IN PUNICODE_STRING DomainName
    )

/*++

Routine Description:

    This function allocates an Endpoint Block from the system nonpaged
    pool and initializes it.  The block is a single allocation: the
    ENDPOINT structure is followed, in order, by the buffers for
    NetworkName, TransportName, ServerName, TransportAddress,
    DomainName and OemDomainName.

    The new block is created in the Active state with a reference count
    of 2 (one for the Active state, one for the caller's pointer), and
    SrvEndpointCount is incremented.

Arguments:

    Endpoint - Returns a pointer to the endpoint block, or NULL if the
        block or its connection table could not be allocated.

    NetworkName - Supplies a pointer to the network name (e.g., NET1).

    TransportName - The fully qualified name of the transport device.
        For example, "\Device\Nbf".

    TransportAddress - The fully qualified address (or name ) of the
        server's endpoint.  This name is used exactly as specified.  For
        NETBIOS-compatible networks, the caller must upcase and
        blank-fill the name.  E.g., "\Device\Nbf\NTSERVERbbbbbbbb".
        A Unicode copy with trailing blanks removed is stored as the
        endpoint's ServerName.

    DomainName - the domain being serviced by this endpoint.  At most
        DNLEN characters are stored; a longer name is truncated so that
        it cannot overrun the space reserved for it.

Return Value:

    None.

--*/

{
    CLONG length;
    PENDPOINT endpoint;
    NTSTATUS status;
    USHORT domainNameLength;

    PAGED_CODE( );

    //
    // Attempt to allocate from nonpaged pool.  The terms below are
    // listed in the order in which the buffers are laid out after the
    // ENDPOINT structure.  One WCHAR's worth of slack (minus one byte)
    // is reserved so that the DomainName buffer, which follows the
    // byte-granular TransportAddress buffer, can be WCHAR aligned.
    //

    length = sizeof(ENDPOINT) +
                NetworkName->Length + sizeof(*NetworkName->Buffer) +
                TransportName->Length + sizeof(*TransportName->Buffer) +
                RtlOemStringToUnicodeSize( TransportAddress ) +
                TransportAddress->Length + sizeof(*TransportAddress->Buffer) +
                (sizeof(WCHAR) - 1) +
                DNLEN * sizeof( *DomainName->Buffer ) +
                DNLEN + sizeof(CHAR);

    endpoint = ALLOCATE_NONPAGED_POOL( length, BlockTypeEndpoint );
    *Endpoint = endpoint;

    if ( endpoint == NULL ) {

        INTERNAL_ERROR (
            ERROR_LEVEL_EXPECTED,
            "SrvAllocateEndpoint: Unable to allocate %d bytes from nonpaged "
                "pool.",
            length,
            NULL
            );

        return;
    }

    IF_DEBUG(HEAP) {
        SrvPrint1( "SrvAllocateEndpoint: Allocated endpoint at %p\n",
                    endpoint );
    }

    //
    // Initialize the endpoint block.  Zero it first.
    //

    RtlZeroMemory( endpoint, length );

    SET_BLOCK_TYPE_STATE_SIZE( endpoint, BlockTypeEndpoint, BlockStateActive, length );
    endpoint->BlockHeader.ReferenceCount = 2;       // allow for Active status
                                                    //  and caller's pointer

    //
    // Allocate connection table.  If that fails, give the block back
    // and report failure to the caller through *Endpoint.
    //

    SrvAllocateTable(
        &endpoint->ConnectionTable,
        6, // !!!
        TRUE
        );
    if ( endpoint->ConnectionTable.Table == NULL ) {
        DEALLOCATE_NONPAGED_POOL( endpoint );
        *Endpoint = NULL;
        return;
    }

    InitializeListHead( &endpoint->FreeConnectionList );
#if SRVDBG29
    UpdateConnectionHistory( "INIT", endpoint, NULL );
#endif

    //
    // Copy the network name, transport name, and server address, and domain
    // name into the block.  Each buffer starts where the previous one's
    // MaximumLength ends.
    //

    endpoint->NetworkName.Length = NetworkName->Length;
    endpoint->NetworkName.MaximumLength =
            (USHORT)(NetworkName->Length + sizeof(*NetworkName->Buffer));
    endpoint->NetworkName.Buffer = (PWCH)(endpoint + 1);
    RtlCopyMemory(
        endpoint->NetworkName.Buffer,
        NetworkName->Buffer,
        NetworkName->Length
        );

    endpoint->TransportName.Length = TransportName->Length;
    endpoint->TransportName.MaximumLength =
            (USHORT)(TransportName->Length + sizeof(*TransportName->Buffer));
    endpoint->TransportName.Buffer =
                            (PWCH)((PCHAR)endpoint->NetworkName.Buffer +
                                    endpoint->NetworkName.MaximumLength);
    RtlCopyMemory(
        endpoint->TransportName.Buffer,
        TransportName->Buffer,
        TransportName->Length
        );

    endpoint->ServerName.MaximumLength = (USHORT)RtlOemStringToUnicodeSize( TransportAddress );
    endpoint->ServerName.Length = 0;
    endpoint->ServerName.Buffer = endpoint->TransportName.Buffer +
                                    endpoint->TransportName.MaximumLength / sizeof( WCHAR );

    endpoint->TransportAddress.Length = TransportAddress->Length;
    endpoint->TransportAddress.MaximumLength =
                                (USHORT)(TransportAddress->Length + 1);
    endpoint->TransportAddress.Buffer =
                            (PCHAR)endpoint->ServerName.Buffer +
                                    endpoint->ServerName.MaximumLength;
    RtlCopyMemory(
        endpoint->TransportAddress.Buffer,
        TransportAddress->Buffer,
        TransportAddress->Length
        );

    //
    // Build the Unicode server name from the OEM transport address.  If
    // the translation fails the name is left as set up above (Length 0
    // over a zeroed buffer).
    //

    status = RtlOemStringToUnicodeString( &endpoint->ServerName, TransportAddress, FALSE );

    if (!NT_SUCCESS(status)) {
        DbgPrint("SRv ENDPOINT Name translation failed status %lx\n",status);
        KdPrint(("SRv ENDPOINT Name translation failed status %lx\n",status));
    }

    //
    // Trim the trailing blanks off the end of servername
    //
    while( endpoint->ServerName.Length &&
        endpoint->ServerName.Buffer[ (endpoint->ServerName.Length / sizeof(WCHAR))-1 ] == L' ' ) {

        endpoint->ServerName.Length -= sizeof( WCHAR );
    }

    //
    // Only DNLEN characters were reserved for the domain name, so never
    // copy more than that.
    //

    domainNameLength = DomainName->Length;
    if ( domainNameLength > DNLEN * sizeof( WCHAR ) ) {
        domainNameLength = (USHORT)(DNLEN * sizeof( WCHAR ));
    }

    //
    // The domain name buffer follows the transport address buffer.  Its
    // position must be derived from the size reserved in this block
    // (endpoint->TransportAddress.MaximumLength), not from the caller's
    // TransportAddress->MaximumLength, which may be larger than what was
    // allocated.  The address is rounded up to a WCHAR boundary; the
    // slack for that was included in 'length' above.
    //

    endpoint->DomainName.Length = domainNameLength;
    endpoint->DomainName.MaximumLength =  DNLEN * sizeof( *endpoint->DomainName.Buffer );
    endpoint->DomainName.Buffer = (PWCH)(
            ( (ULONG_PTR)endpoint->TransportAddress.Buffer +
              endpoint->TransportAddress.MaximumLength +
              (sizeof( WCHAR ) - 1) ) & ~(ULONG_PTR)(sizeof( WCHAR ) - 1) );
    RtlCopyMemory(
        endpoint->DomainName.Buffer,
        DomainName->Buffer,
        domainNameLength
    );

    endpoint->OemDomainName.Length = 0;
    endpoint->OemDomainName.MaximumLength = DNLEN + sizeof( CHAR );
    endpoint->OemDomainName.Buffer = (PCHAR)endpoint->DomainName.Buffer +
                                     endpoint->DomainName.MaximumLength;

    status = RtlUnicodeStringToOemString(
                &endpoint->OemDomainName,
                &endpoint->DomainName,
                FALSE     // Do not allocate the OEM string
                );
    ASSERT( NT_SUCCESS(status) );

    if ( !NT_SUCCESS(status) ) {

        //
        // On builds where the ASSERT is compiled out, do not leave a
        // length behind that does not describe valid buffer contents.
        //

        endpoint->OemDomainName.Length = 0;
        endpoint->OemDomainName.Buffer[0] = '\0';
    }

    //
    // Initialize the network address field.
    //

    endpoint->NetworkAddress.Buffer = endpoint->NetworkAddressData;
    endpoint->NetworkAddress.Length = sizeof( endpoint->NetworkAddressData ) -
                                      sizeof(endpoint->NetworkAddressData[0]);
    endpoint->NetworkAddress.MaximumLength = sizeof( endpoint->NetworkAddressData );

    //
    // Increment the count of endpoints in the server.
    //

    ACQUIRE_LOCK( &SrvEndpointLock );
    SrvEndpointCount++;
    RELEASE_LOCK( &SrvEndpointLock );

    INITIALIZE_REFERENCE_HISTORY( endpoint );

    INCREMENT_DEBUG_STAT( SrvDbgStatistics.EndpointInfo.Allocations );

    return;

} // SrvAllocateEndpoint


BOOLEAN SRVFASTCALL
SrvCheckAndReferenceEndpoint (
    PENDPOINT Endpoint
    )

/*++

Routine Description:

    Tests whether an endpoint is in the Active state and, if so, takes
    a reference on it.  The test and the reference are both performed
    while SrvEndpointLock is held, so the state cannot change between
    the two.

Arguments:

    Endpoint - Address of endpoint

Return Value:

    BOOLEAN - TRUE if the endpoint was active (and has been
        referenced), FALSE otherwise (no reference taken).

--*/

{
    BOOLEAN isActive;

    PAGED_CODE( );

    //
    // The endpoint's state field is guarded by SrvEndpointLock.
    //

    ACQUIRE_LOCK( &SrvEndpointLock );

    isActive = (BOOLEAN)( GET_BLOCK_STATE(Endpoint) == BlockStateActive );

    //
    // Only an active endpoint may be referenced.
    //

    if ( isActive ) {
        SrvReferenceEndpoint( Endpoint );
    }

    RELEASE_LOCK( &SrvEndpointLock );

    return isActive;

} // SrvCheckAndReferenceEndpoint


VOID
SrvCloseEndpoint (
    IN PENDPOINT Endpoint
    )

/*++

Routine Description:

    Closes a transport endpoint: marks it Closing, closes every active
    connection in its connection table, empties its free connection
    list, closes its handle(s) and drops the reference that represents
    the open state.  If the endpoint is not Active, nothing is done
    beyond releasing the lock.

    *** This function must be called with SrvEndpointLock held exactly
        once.  The lock is released on exit.

Arguments:

    Endpoint - Supplies a pointer to an Endpoint Block

Return Value:

    None.

--*/

{
    USHORT tableIndex;
    PCONNECTION activeConnection;

    PAGED_CODE( );

    ASSERT( ExIsResourceAcquiredExclusiveLite(&RESOURCE_OF(SrvEndpointLock)) );

    //
    // Nothing to do unless the endpoint is still active; just honor
    // the contract that the lock is released on exit.
    //

    if ( GET_BLOCK_STATE(Endpoint) != BlockStateActive ) {
        RELEASE_LOCK( &SrvEndpointLock );
        return;
    }

    IF_DEBUG(BLOCK1) SrvPrint1( "Closing endpoint at %p\n", Endpoint );

    SET_BLOCK_STATE( Endpoint, BlockStateClosing );

    //
    // Close all active connections.  Starting the index at -1 makes
    // the first walk begin with table slot 0.  WalkConnectionTable
    // returns a referenced pointer to the next active connection, or
    // NULL when there are no more.
    //

    tableIndex = (USHORT)-1;

    for ( activeConnection = WalkConnectionTable( Endpoint, &tableIndex );
          activeConnection != NULL;
          activeConnection = WalkConnectionTable( Endpoint, &tableIndex ) ) {

        //
        // Holding the endpoint lock across the connection close
        // causes lock level problems, so drop it for the duration.
        // The reference obtained from the walk keeps the connection
        // valid meanwhile.
        //

        RELEASE_LOCK( &SrvEndpointLock );

#if SRVDBG29
        UpdateConnectionHistory( "CEND", Endpoint, activeConnection );
#endif
        activeConnection->DisconnectReason = DisconnectEndpointClosing;
        SrvCloseConnection( activeConnection, FALSE );

        ACQUIRE_LOCK( &SrvEndpointLock );

        //
        // Give back the reference taken by the walk.
        //

        SrvDereferenceConnection( activeConnection );
    }

    //
    // Close all free connections.
    //

    EmptyFreeConnectionList( Endpoint );

    //
    // The endpoint lock is not needed past this point.
    //

    RELEASE_LOCK( &SrvEndpointLock );

    //
    // Close the endpoint file handle.  This causes all pending
    // requests to be aborted.  It also deregisters all event
    // handlers.
    //
    // *** A separate reference to the file object is held in addition
    //     to the handle.  That reference is not released here; it is
    //     dropped in SrvDereferenceEndpoint once all activity on the
    //     endpoint has ceased.
    //

    SRVDBG_RELEASE_HANDLE( Endpoint->EndpointHandle, "END", 2, Endpoint );
    SrvNtClose( Endpoint->EndpointHandle, FALSE );

    if ( Endpoint->IsConnectionless ) {
        SRVDBG_RELEASE_HANDLE( Endpoint->NameSocketHandle, "END", 2, Endpoint );
        SrvNtClose( Endpoint->NameSocketHandle, FALSE );
    }

    //
    // Drop the reference that stood for "endpoint is open".
    //

    SrvDereferenceEndpoint( Endpoint );

    INCREMENT_DEBUG_STAT( SrvDbgStatistics.EndpointInfo.Closes );

    return;

} // SrvCloseEndpoint


VOID SRVFASTCALL
SrvDereferenceEndpoint (
    IN PENDPOINT Endpoint
    )

/*++

Routine Description:

    Drops one reference from an endpoint.  When the last reference
    goes away the endpoint is removed from the global endpoint list,
    its file object reference(s) are released and the block is freed.
    If it was the last endpoint in the server, SrvEndpointEvent is set.

Arguments:

    Endpoint - Address of endpoint

Return Value:

    None.

--*/

{
    ULONG remainingEndpoints;

    PAGED_CODE( );

    //
    // The reference count is manipulated under SrvEndpointLock.
    //

    ACQUIRE_LOCK( &SrvEndpointLock );

    IF_DEBUG(REFCNT) {
        SrvPrint2( "Dereferencing endpoint %p; old refcnt %lx\n",
                    Endpoint, Endpoint->BlockHeader.ReferenceCount );
    }

    ASSERT( GET_BLOCK_TYPE( Endpoint ) == BlockTypeEndpoint );
    ASSERT( (LONG)Endpoint->BlockHeader.ReferenceCount > 0 );
    UPDATE_REFERENCE_HISTORY( Endpoint, TRUE );

    if ( --Endpoint->BlockHeader.ReferenceCount != 0 ) {

        //
        // Other references remain; the block stays.
        //

        RELEASE_LOCK( &SrvEndpointLock );
        return;
    }

    //
    // That was the last reference, so the block is to be deleted.  An
    // endpoint that is still Active should never get here, because the
    // Active state itself holds a reference.
    //

    ASSERT( GET_BLOCK_STATE(Endpoint) != BlockStateActive );

    //
    // One fewer endpoint in the server.  Capture the new count while
    // the lock is still held.
    //

    ASSERT( SrvEndpointCount >= 1 );

    remainingEndpoints = --SrvEndpointCount;

    RELEASE_LOCK( &SrvEndpointLock );

    //
    // If no endpoints are left, set the endpoint event.
    //
    // NOTE(review): the event is set before the list removal and free
    // below have run; confirm that waiters on SrvEndpointEvent do not
    // depend on that cleanup having completed.
    //

    if ( remainingEndpoints == 0 ) {
        KeSetEvent( &SrvEndpointEvent, 0, FALSE );
    }

    //
    // Take the endpoint off the global list of endpoints.
    //

    SrvRemoveEntryOrderedList( &SrvEndpointList, Endpoint );

    //
    // Release the file object pointer(s).  (The corresponding handles
    // were closed in SrvCloseEndpoint.)
    //

    ObDereferenceObject( Endpoint->FileObject );

    if ( Endpoint->IsConnectionless ) {
        ObDereferenceObject( Endpoint->NameSocketFileObject );
    }

    //
    // Finally release the endpoint block's storage.
    //

    SrvFreeEndpoint( Endpoint );

    return;

} // SrvDereferenceEndpoint


VOID
SrvFreeEndpoint (
    IN PENDPOINT Endpoint
    )

/*++

Routine Description:

    Releases the storage owned by an Endpoint Block: the IPX maximum
    packet size array and the connection table, when present, and then
    the block itself.

Arguments:

    Endpoint - Address of endpoint

Return Value:

    None.

--*/

{
    PAGED_CODE( );

    //
    // Stamp the block header as garbage/dead so that a stale pointer
    // is recognizable.  These two statements are prefixed with the
    // DEBUG macro.
    //

    DEBUG SET_BLOCK_TYPE_STATE_SIZE( Endpoint, BlockTypeGarbage, BlockStateDead, -1 );
    DEBUG Endpoint->BlockHeader.ReferenceCount = (ULONG)-1;
    TERMINATE_REFERENCE_HISTORY( Endpoint );

    //
    // Release the subordinate allocations before the block itself.
    //

    if ( Endpoint->IpxMaxPacketSizeArray != NULL ) {
        FREE_HEAP( Endpoint->IpxMaxPacketSizeArray );
    }

    if ( Endpoint->ConnectionTable.Table != NULL ) {
        SrvFreeTable( &Endpoint->ConnectionTable );
    }

    DEALLOCATE_NONPAGED_POOL( Endpoint );

    //
    // Only the pointer value is printed below; the freed block is not
    // touched again.
    //

    IF_DEBUG(HEAP) {
        SrvPrint1( "SrvFreeEndpoint: Freed endpoint block at %p\n", Endpoint );
    }

    INCREMENT_DEBUG_STAT( SrvDbgStatistics.EndpointInfo.Frees );

    return;

} // SrvFreeEndpoint


VOID
SrvReferenceEndpoint (
    PENDPOINT Endpoint
    )

/*++

Routine Description:

    Adds one reference to an endpoint block.  The endpoint is expected
    to be an Active endpoint block that already has at least one
    reference; this is asserted, not enforced.

Arguments:

    Endpoint - Address of endpoint

Return Value:

    None.

--*/

{
    PAGED_CODE( );

    //
    // The reference count is only modified under SrvEndpointLock.
    //

    ACQUIRE_LOCK( &SrvEndpointLock );

    ASSERT( (LONG)Endpoint->BlockHeader.ReferenceCount > 0 );
    ASSERT( GET_BLOCK_TYPE(Endpoint) == BlockTypeEndpoint );
    ASSERT( GET_BLOCK_STATE(Endpoint) == BlockStateActive );
    UPDATE_REFERENCE_HISTORY( Endpoint, FALSE );

    Endpoint->BlockHeader.ReferenceCount++;

    IF_DEBUG(REFCNT) {
        SrvPrint2( "Referencing endpoint %p; new refcnt %lx\n",
                    Endpoint, Endpoint->BlockHeader.ReferenceCount );
    }

    RELEASE_LOCK( &SrvEndpointLock );

    return;

} // SrvReferenceEndpoint

BOOLEAN
SrvFindNamedEndpoint(
    IN PUNICODE_STRING ServerName,
    OUT PBOOLEAN RemapPipeNames OPTIONAL
)
/*++

Routine Description:

    This routine returns TRUE if any endpoint is supporting 'ServerName'.

    The global endpoint list is scanned under a shared acquisition of
    SrvEndpointLock.  Endpoints that are not Active or that are
    connectionless are ignored; when RemapPipeNames is supplied,
    endpoints marked IsNoNetBios are ignored as well.  An endpoint
    supports the name when, compared case-insensitively,

        - the name equals the endpoint's server name or domain name, or
        - the name begins with the endpoint's server name or domain
          name, subject to the length and '.' separator rules noted in
          the body.

Arguments:

    ServerName - The name to look for.

    RemapPipeNames - If supplied, receives FALSE when no endpoint
        matches, otherwise whether the matching endpoint's
        RemapPipeNames field is TRUE.

Return Value:

    BOOLEAN - TRUE if a supporting endpoint was found.

--*/
{
    PLIST_ENTRY entry;
    PENDPOINT candidate;
    PENDPOINT match = NULL;
    UNICODE_STRING truncatedName;
    BOOLEAN wantRemap;
    BOOLEAN hit;

    PAGED_CODE( );

    wantRemap = (BOOLEAN)( ARGUMENT_PRESENT( RemapPipeNames ) ? TRUE : FALSE );

    if( wantRemap ) {
        *RemapPipeNames = FALSE;
    }

    //
    // Find an endpoint block supporting the specified name.
    //

    ACQUIRE_LOCK_SHARED( &SrvEndpointLock );

    for( entry = SrvEndpointList.ListHead.Flink;
         entry != &SrvEndpointList.ListHead;
         entry = entry->Flink ) {

        candidate = CONTAINING_RECORD(
                        entry,
                        ENDPOINT,
                        GlobalEndpointListEntry
                        );

        //
        // Skip any inappropriate endpoints
        //
        if( GET_BLOCK_STATE( candidate ) != BlockStateActive ) {
            continue;
        }

        if( candidate->IsConnectionless ) {
            continue;
        }

        if( wantRemap && candidate->IsNoNetBios ) {
            continue;
        }

        //
        // Literal match against the endpoint's server name?
        //
        hit = RtlEqualUnicodeString( ServerName, &candidate->ServerName, TRUE );

        //
        // ServerName may be something like
        //      server.dns.company.com
        // while the endpoint netbios name is only 'server'.  Compare
        // the leading part of ServerName against the endpoint name.  A
        // short endpoint name must be followed by '.' in ServerName; an
        // endpoint name of exactly NETBIOS_NAME_LEN - 1 characters
        // matches on the prefix alone.
        //
        if( !hit && candidate->ServerName.Length < ServerName->Length ) {

            truncatedName = *ServerName;
            truncatedName.Length = candidate->ServerName.Length;

            if( RtlEqualUnicodeString( &candidate->ServerName, &truncatedName, TRUE ) ) {

                if( candidate->ServerName.Length < ((NETBIOS_NAME_LEN - 1) * sizeof(WCHAR)) ) {

                    hit = (BOOLEAN)( ServerName->Buffer[ truncatedName.Length / sizeof( WCHAR ) ] == L'.' );

                } else {

                    hit = (BOOLEAN)( candidate->ServerName.Length ==
                                        (NETBIOS_NAME_LEN - 1) * sizeof(WCHAR) );
                }
            }
        }

        //
        // The two domain name tests below exist for backward
        // compatibility: certain components use the domain name to
        // talk to the server, and given the way name resolution
        // records are set up that used to work before the name checks
        // above were introduced.
        //
        if( !hit ) {
            hit = RtlEqualUnicodeString( ServerName, &candidate->DomainName, TRUE );
        }

        //
        // Same prefix idea as for the server name, applied to the
        // endpoint's domain name.
        //
        // NOTE(review): the length thresholds here differ from the
        // server name case.  With a positive NETBIOS_NAME_LEN the
        // 'else' arm can never succeed (a length greater than
        // NETBIOS_NAME_LEN characters cannot equal NETBIOS_NAME_LEN - 1
        // characters), so in practice a domain prefix always requires
        // a following '.'.  Confirm whether that is intended.
        //
        if( !hit && candidate->DomainName.Length < ServerName->Length ) {

            truncatedName = *ServerName;
            truncatedName.Length = candidate->DomainName.Length;

            if( RtlEqualUnicodeString( &candidate->DomainName, &truncatedName, TRUE ) ) {

                if( candidate->DomainName.Length <= (NETBIOS_NAME_LEN * sizeof(WCHAR)) ) {

                    hit = (BOOLEAN)( ServerName->Buffer[ truncatedName.Length / sizeof( WCHAR ) ] == L'.' );

                } else {

                    hit = (BOOLEAN)( candidate->DomainName.Length ==
                                        (NETBIOS_NAME_LEN - 1) * sizeof(WCHAR) );
                }
            }
        }

        if( hit ) {
            match = candidate;
            break;
        }
    }

    //
    // Report the pipe name remapping setting of the endpoint that
    // matched, while the lock still protects it.
    //
    if( wantRemap && match != NULL ) {
        *RemapPipeNames = ( match->RemapPipeNames == TRUE );
    }

    RELEASE_LOCK( &SrvEndpointLock );

    return match != NULL;
}


VOID
EmptyFreeConnectionList (
    IN PENDPOINT Endpoint
    )

/*++

Routine Description:

    Detaches every entry from an endpoint's free connection list and
    closes each of the detached connections with
    SrvCloseFreeConnection.

Arguments:

    Endpoint - Address of endpoint

Return Value:

    None.

--*/

{
    PLIST_ENTRY listHead = &Endpoint->FreeConnectionList;
    PLIST_ENTRY current;
    PLIST_ENTRY next;
    KIRQL oldIrql;

    //
    // *** The TDI connect handler in the FSD serializes access to the
    //     free connection list with a spin lock only, and does not
    //     check the endpoint state.  To synchronize with it, the first
    //     entry is captured and the list head reset as one atomic step
    //     under that spin lock.
    //

    ACQUIRE_GLOBAL_SPIN_LOCK( Fsd, &oldIrql );

    current = listHead->Flink;
    InitializeListHead( listHead );
#if SRVDBG29
    UpdateConnectionHistory( "CLOS", Endpoint, NULL );
#endif

    RELEASE_GLOBAL_SPIN_LOCK( Fsd, oldIrql );

    //
    // Walk the captured chain.  Its last entry still points back at
    // the list head, which ends the walk (immediately, if the list was
    // empty).  The forward link is read before the connection is
    // closed, so the entry is not touched after the close.
    //

    while ( current != listHead ) {

        next = current->Flink;

        SrvCloseFreeConnection(
            CONTAINING_RECORD(
                current,
                CONNECTION,
                EndpointFreeListEntry
                )
            );

        current = next;
    }

    return;

} // EmptyFreeConnectionList


PCONNECTION
WalkConnectionTable (
    IN PENDPOINT Endpoint,
    IN OUT PUSHORT Index
    )

/*++

Routine Description:

    Finds the next active connection in an endpoint's connection
    table, searching from the slot after *Index.  The search runs with
    all of the endpoint spin locks held.

Arguments:

    Endpoint - Address of endpoint

    Index - On input, the table slot after which to start searching
        ((USHORT)-1 starts at slot 0).  On output, if a connection is
        returned, the slot in which it was found; otherwise unchanged.

Return Value:

    PCONNECTION - A referenced pointer to the connection found, or
        NULL if there are no more active connections.

--*/

{
    USHORT lockIndex;
    USHORT slot;
    PTABLE_HEADER table;
    PCONNECTION candidate;
    PCONNECTION found = NULL;
    KIRQL oldIrql;

    //
    // Take every endpoint spin lock in ascending order.  The first
    // acquisition saves the previous IRQL; the rest use the DPC-level
    // form.
    //

    ACQUIRE_SPIN_LOCK( &ENDPOINT_SPIN_LOCK(0), &oldIrql );
    for ( lockIndex = 1; lockIndex < ENDPOINT_LOCK_COUNT ; lockIndex++ ) {
        ACQUIRE_DPC_SPIN_LOCK( &ENDPOINT_SPIN_LOCK(lockIndex) );
    }

    table = &Endpoint->ConnectionTable;

    //
    // Scan forward for an occupied slot whose owner is an active
    // connection, and reference it while the locks are held.
    //

    for ( slot = *Index + 1; slot < table->TableSize; slot++ ) {

        candidate = (PCONNECTION)table->Table[slot].Owner;

        if ( candidate == NULL ) {
            continue;
        }

        if ( GET_BLOCK_STATE(candidate) != BlockStateActive ) {
            continue;
        }

        *Index = slot;
        SrvReferenceConnectionLocked( candidate );
        found = candidate;
        break;
    }

    //
    // Drop the spin locks in the reverse of the order taken.
    //

    for ( lockIndex = ENDPOINT_LOCK_COUNT-1 ; lockIndex > 0  ; lockIndex-- ) {
        RELEASE_DPC_SPIN_LOCK( &ENDPOINT_SPIN_LOCK(lockIndex) );
    }
    RELEASE_SPIN_LOCK( &ENDPOINT_SPIN_LOCK(0), oldIrql );

    return found;

} // WalkConnectionTable