Leaked source code of windows server 2003
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

1098 lines
29 KiB

/*++
Copyright (c) 1989-2001 Microsoft Corporation
Module Name:
smbsvc.c
Abstract:
This the user-mode proxy for the kernel mode DNS resolver.
features:
1. Multi-threaded
NBT4 use a single-threaded design. The DNS name resolution is a performance blocker.
When a connection request is being served by LmhSvc, all the other requests requiring
DNS resolution will be blocked.
2. IPv6 and IPv4 compatiable
3. can be run either as a service or standalone executable (for debug purpose)
When started as service, debug spew is sent to debugger.
When started as a standalong executable, the debug spew is sent to either
the console or debugger. smbhelper.c contain the _main for the standardalone
executable.
Author:
Jiandong Ruan
Revision History:
--*/
#include "precomp.h"
#include "smbtrace.h"
#include <nt.h>
#include <ntrtl.h>
#include <nturtl.h>
#include <windows.h>
#include <shellapi.h>
#include <stdio.h>
#include <stdlib.h>
#include <locale.h>
#include <tdi.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <mstcpip.h>
#include <ipexport.h>
#include <winsock2.h>
#include <icmpapi.h>
#include "svclib.h"
#include "ping6.h"
#include "smbsvc.tmh"
#if DBG
//
// In order to disable the compiler optimization,
// don't use static on SmbDebug.
// (When static is used, compiler may find that
// SmbDebug isn't changed in this module, it will
// remove the "if (SmbDebug)" )
//
static int (*SmbDbgPrint)(char *fmt,...) = NULL;
# define KDPRINT(y) \
do { \
int (*MyDbgPrint)(char *fmt,...) = SmbDbgPrint; \
if (NULL != MyDbgPrint) { \
MyDbgPrint y; \
MyDbgPrint(": %d of %s\n", __LINE__, __FILE__); \
} \
} while(0)
#else
# define KDPRINT(y)
#endif
static HANDLE OpenSmb(LPWSTR Name);
#define DEFAULT_RECV_SIZE (0x2000) // Icmp recv buffer size
#define DEFAULT_TTL 32
#define DEFAULT_TOS 0
#define DEFAULT_TIMEOUT 2000L // default timeout set to 2 secs.
static HANDLE hTerminateEvent;
static BYTE SendBuffer[] = "SMBEcho";
static IP_OPTION_INFORMATION SendOpts = {
DEFAULT_TTL, DEFAULT_TOS, 0, 0, NULL
};
#ifdef __USE_GETADDRINFO__
//
// Please use getaddrinfo whenever it has a UNICODE version
//
VOID
ReturnAllIPAddress(
PSMB_DNS_BUFFER dns,
struct addrinfo *res
)
{
SOCKET_ADDRESS SelectedAddr;
struct addrinfo *p = res;
BOOL bFoundIPv4 = FALSE;
for (p = res; p; p = p->ai_next) {
if (p->ai_family == PF_INET6) {
//
// Save the last slot for IPv4 address
//
if (!bFoundIPv4 && dns->IpAddrsNum == SMB_MAX_IPADDRS_PER_HOST - 1) {
continue;
}
//
// Skip all the multicast address
//
if (IN6_IS_ADDR_MULTICAST(&((struct sockaddr_in6*)p->ai_addr)->sin6_addr)) {
continue;
}
//
// Only allow supported addresses in
//
if (!(dns->RequestType & SMB_DNS_AAAA_GLOBAL) &&
!SMB_IS_ADDRESS_ALLOWED(((struct sockaddr_in6*)p->ai_addr)->sin6_addr.u.Byte)) {
continue;
}
SelectedAddr.lpSockdddr = p->ai_addr;
SelectedAddr.iSockaddrLength = sizeof (sockaddr_in6);
} else if (p->ai_family == PF_INET) {
SelectedAddr.lpSockdddr = p->ai_addr;
SelectedAddr.iSockaddrLength = sizeof (sockaddr_in);
bFoundIPv4 = TRUE;
} else {
ASSERT (0);
continue;
}
SmbReturnIPAddress (dns, &SelectedAddr);
}
return p;
}
VOID
SmbGetHostByName(
PSMB_DNS_BUFFER dns
)
{
struct addrinfo hints = {0};
struct addrinfo *res;
CHAR nodename[DNS_NAME_BUFFER_LENGTH];
KDPRINT(("Resolving %ws", dns->Name));
dns->Resolved = FALSE;
dns->IpAddrsNum = 0;
if (WideCharToMultiByte(
CP_THREAD_ACP,
WC_NO_BEST_FIT_CHARS,
dns->Name,
dns->NameLen,
nodename,
DNS_NAME_BUFFER_LENGTH,
NULL,
NULL
) == 0) {
KDPRINT(("WideCharToMultiByte return %d", GetLastError()));
return;
}
hints.ai_family = 0;
hints.ai_flags = AI_CANONNAME;
hints.ai_socktype = SOCK_STREAM;
if (getaddrinfo(nodename, NULL, &hints, &res)) {
KDPRINT(("getaddrinfo return %d", GetLastError()));
return;
}
if (NULL == res) {
//
// This should not happen since getaddrinfo returns success above
//
ASSERT(0);
return;
}
dns->NameLen = MultiByteToWideChar(
CP_THREAD_ACP,
MB_ERR_INVALID_CHARS,
res->ai_canonname,
-1,
dns->Name,
DNS_NAME_BUFFER_LENGTH
);
if (0 == dns->NameLen) {
KDPRINT(("MultiByteToWideChar return %d", GetLastError()));
}
dns->Resolved = TRUE;
ReturnAllIPAddress(dns, res);
freeaddrinfo(res);
}
#else
VOID
SmbCopyDnsName(
IN OUT PSMB_DNS_BUFFER dns,
IN WCHAR *name
)
/*++
Routine Description:
This routine update the Name and NameLen field of dns.
If the input name is too long, nothing will happen
Arguments:
Return Value:
--*/
{
int len;
if (NULL == name) {
return;
}
len = wcslen(name) + 1;
if ( len > DNS_NAME_BUFFER_LENGTH ) {
return;
}
memcpy (dns->Name, name, len * sizeof(WCHAR));
dns->NameLen = len - 1;
KDPRINT(("DNS name is updated with %ws", dns->Name));
SmbTrace(SMB_TRACE_DNS, ("\tFQDN %ws", dns->Name));
}
VOID
SmbReturnIPAddress (
IN OUT PSMB_DNS_BUFFER dns,
IN PSOCKET_ADDRESS pSelectedAddr
)
{
struct sockaddr_in *pv4addr = NULL;
struct sockaddr_in6 *pv6addr = NULL;
PSMB_IP_ADDRESS pSmbIpAddress = NULL;
#ifdef DBG
CHAR name_buffer[48];
#endif
//
// Update the return buffer (returned to the kernel mode driver)
//
dns->Resolved = TRUE;
if (dns->IpAddrsNum >= SMB_MAX_IPADDRS_PER_HOST) {
return;
}
pSmbIpAddress = &dns->IpAddrsList[dns->IpAddrsNum];
dns->IpAddrsNum++;
if (pSelectedAddr->lpSockaddr->sa_family == AF_INET) {
pSmbIpAddress->sin_family = SMB_AF_INET;
pv4addr = (struct sockaddr_in*)(pSelectedAddr->lpSockaddr);
RtlCopyMemory(&pSmbIpAddress->ip4.sin4_addr, &pv4addr->sin_addr, sizeof(pv4addr->sin_addr));
SmbTrace(SMB_TRACE_DNS, ("\treturn %!ipaddr!", pv4addr->sin_addr.S_un.S_addr));
} else {
pSmbIpAddress->sin_family = SMB_AF_INET6;
pv6addr = (struct sockaddr_in6*)(pSelectedAddr->lpSockaddr);
RtlCopyMemory(pSmbIpAddress->ip6.sin6_addr, &pv6addr->sin6_addr, sizeof(pv6addr->sin6_addr));
pSmbIpAddress->ip6.sin6_scope_id = pv6addr->sin6_scope_id;
SmbTrace(SMB_TRACE_DNS, ("\treturn %!IPV6ADDR!", &pv6addr->sin6_addr));
}
}
int
SortIPAddrs(
IN int af,
OUT LPVOID Addrs,
IN u_int NumAddrs,
IN u_int width,
OUT SOCKET_ADDRESS_LIST ** pAddrlist
)
/*++
Routine Description:
Lifted from %SDXROOT%\net\sockets\winsock2\ws2_32\src\addrinfo.cpp
Sort addresses of the same family.
A wrapper around a sort Ioctl.
If the Ioctl isn't implemented, the sort is a no-op.
Arguments:
Return Value:
--*/
{
DWORD status = NO_ERROR;
DWORD bytesReturned = 0;
DWORD size = 0;
DWORD i = 0;
PSOCKADDR paddr = NULL;
UINT countAddrs = NumAddrs;
LPSOCKET_ADDRESS_LIST paddrlist = NULL;
//
// build SOCKET_ADDRESS_LIST
// - allocate
// - fill in with pointers into SOCKADDR array
//
size = FIELD_OFFSET( SOCKET_ADDRESS_LIST, Address[countAddrs] );
paddrlist = (SOCKET_ADDRESS_LIST *)LocalAlloc(LMEM_FIXED, size);
if ( !paddrlist ) {
status = WSA_NOT_ENOUGH_MEMORY;
goto Done;
}
for ( i=0; i<countAddrs; i++ ) {
paddr = (PSOCKADDR) (((PBYTE)Addrs) + i * width);
paddrlist->Address[i].lpSockaddr = paddr;
paddrlist->Address[i].iSockaddrLength = width;
}
paddrlist->iAddressCount = countAddrs;
//
// sort if multiple addresses and able to open socket
// - open socket of desired type for sort
// - sort, if sort fails just return unsorted
//
if ( countAddrs > 1 ) {
SOCKET s = INVALID_SOCKET;
s = socket( af, SOCK_DGRAM, 0 );
if ( s == INVALID_SOCKET ) {
goto Done;
}
status = WSAIoctl(
s,
SIO_ADDRESS_LIST_SORT,
(LPVOID)paddrlist,
size,
(LPVOID)paddrlist,
size,
& bytesReturned,
NULL,
NULL );
closesocket(s);
status = NO_ERROR;
}
Done:
if ( status == NO_ERROR ) {
*pAddrlist = paddrlist;
}
return status;
}
typedef enum {
SMB_DNS_UNRESOLVED = 0,
SMB_DNS_BUFFER_TOO_SMALL,
SMB_DNS_RESOLVED,
SMB_DNS_REACHABLE
} SMB_DNS_STATUS;
typedef struct {
union {
SOCKADDR_IN *pV4Addrs;
SOCKADDR_IN6 *pV6Addrs;
PVOID pAddrs;
};
u_int NumSlots;
u_int NumAddrs;
} SMB_ADDR_LIST, *PSMB_ADDR_LIST;
int
SaveV6Address(
IN struct in6_addr * NewAddr,
IN OUT PSMB_ADDR_LIST pSmbAddrList
)
/*++
Routine Description:
Save an address into our array of addresses,
possibly growing the array.
Arguments:
Return Value:
Returns FALSE on failure. (Couldn't grow array.)
--*/
{
SOCKADDR_IN6 *addr6 = NULL;
DWORD dwSize = 0;
DWORD dwNewSlots = 0;
SOCKADDR_IN6 *NewAddrs = NULL;
USHORT ServicePort = 0;
//
// add another sockaddr to array if not enough space
//
if (pSmbAddrList->NumSlots == pSmbAddrList->NumAddrs) {
if (pSmbAddrList->pV6Addrs) {
ASSERT (pSmbAddrList->NumSlots);
dwNewSlots = 2 * pSmbAddrList->NumSlots;
dwSize = dwNewSlots * sizeof (SOCKADDR_IN6);
NewAddrs = (SOCKADDR_IN6 *) LocalReAlloc(pSmbAddrList->pV6Addrs, dwSize, LMEM_ZEROINIT);
} else {
dwNewSlots = SMB_MAX_IPADDRS_PER_HOST;
dwSize = dwNewSlots * sizeof (SOCKADDR_IN6);
NewAddrs = (SOCKADDR_IN6 *) LocalAlloc(LPTR, dwSize);
}
if (NewAddrs == NULL)
return FALSE;
pSmbAddrList->pV6Addrs = NewAddrs;
pSmbAddrList->NumSlots = dwNewSlots;
}
// fill in IP6 sockaddr
addr6 = pSmbAddrList->pV6Addrs + (pSmbAddrList->NumAddrs++);
addr6->sin6_family = AF_INET6;
addr6->sin6_port = ServicePort;
memcpy(&addr6->sin6_addr, NewAddr, sizeof(*NewAddr));
return TRUE;
}
int
SaveV4Address(
IN struct in_addr * NewAddr,
IN OUT PSMB_ADDR_LIST pSmbAddrList
)
/*++
Routine Description:
Save an address into our array of addresses,
possibly growing the array.
Arguments:
Return Value:
Returns FALSE on failure. (Couldn't grow array.)
--*/
{
SOCKADDR_IN *addr = NULL;
DWORD dwSize = 0;
DWORD dwNewSlots = 0;
SOCKADDR_IN *NewAddrs = NULL;
USHORT ServicePort = 0;
//
// add another sockaddr to array if not enough space
//
if (pSmbAddrList->NumSlots == pSmbAddrList->NumAddrs) {
if (pSmbAddrList->pV4Addrs) {
ASSERT (pSmbAddrList->NumSlots);
dwNewSlots = 2 * pSmbAddrList->NumSlots;
dwSize = dwNewSlots * sizeof (SOCKADDR_IN);
NewAddrs = (SOCKADDR_IN *) LocalReAlloc(pSmbAddrList->pV4Addrs, dwSize, LMEM_ZEROINIT);
} else {
dwNewSlots = SMB_MAX_IPADDRS_PER_HOST;
dwSize = dwNewSlots * sizeof (SOCKADDR_IN);
NewAddrs = (SOCKADDR_IN *) LocalAlloc(LPTR, dwSize);
}
if (NewAddrs == NULL)
return FALSE;
pSmbAddrList->pV4Addrs = NewAddrs;
pSmbAddrList->NumSlots = dwNewSlots;
}
// fill in IP sockaddr
addr = pSmbAddrList->pV4Addrs + (pSmbAddrList->NumAddrs++);
addr->sin_family = AF_INET;
addr->sin_port = ServicePort;
memcpy(&addr->sin_addr, NewAddr, sizeof(*NewAddr));
return TRUE;
}
VOID
CleanupAddrList (
IN OUT PSMB_ADDR_LIST pSmbAddrList
)
{
if (pSmbAddrList->pAddrs) {
LocalFree (pSmbAddrList->pAddrs);
pSmbAddrList->pAddrs = NULL;
}
RtlZeroMemory (pSmbAddrList, sizeof(SMB_ADDR_LIST));
}
typedef struct {
SMB_ADDR_LIST V4AddrList;
SMB_ADDR_LIST V6AddrList;
LPSOCKET_ADDRESS_LIST pSortedV4AddrList;
LPSOCKET_ADDRESS_LIST pSortedV6AddrList;
} SMB_DNS_RESULT, *PSMB_DNS_RESULT;
VOID
CleanupDnsResult (
IN OUT PSMB_DNS_RESULT pSmbDnsResult
)
{
CleanupAddrList (&pSmbDnsResult->V4AddrList);
CleanupAddrList (&pSmbDnsResult->V6AddrList);
if (pSmbDnsResult->pSortedV4AddrList) {
LocalFree (pSmbDnsResult->pSortedV4AddrList);
pSmbDnsResult->pSortedV4AddrList = NULL;
}
if (pSmbDnsResult->pSortedV6AddrList) {
LocalFree (pSmbDnsResult->pSortedV6AddrList);
pSmbDnsResult->pSortedV6AddrList = NULL;
}
}
#define SMB_INIT_WSA_SIZE (sizeof(WSAQUERYSETW) + 2048)
//
// use WSALookupXXXX APIs since we don't have a UNICODE version of getaddrinfo
//
DWORD
LookupDns(
IN WCHAR * pwchName,
IN GUID * pgProvider,
IN OUT PSMB_DNS_BUFFER dns,
IN OUT PSMB_DNS_RESULT pSmbDnsResult
)
/*++
Routine Description:
Arguments:
Return Value:
WINERROR
--*/
{
PCSADDR_INFO pcsadr = NULL;
struct sockaddr_in *pv4addr = NULL;
struct sockaddr_in6 *pv6addr = NULL;
DWORD dwError = ERROR_SUCCESS;
HANDLE hRnR = NULL;
DWORD i = 0;
PWSAQUERYSETW pwsaq = NULL;
DWORD dwSize = 0;
DWORD dwCurrentSize = 0;
#ifdef DBG
CHAR name_buffer[48];
#endif
//
// Allocate the memory and setup the request
//
dwSize = dwCurrentSize = SMB_INIT_WSA_SIZE;
pwsaq = (PWSAQUERYSETW)LocalAlloc(LPTR, dwSize);
if (NULL == pwsaq) {
SmbTrace(SMB_TRACE_DNS, ("\tOut of memory"));
goto cleanup;
}
pwsaq->dwSize = sizeof(*pwsaq);
pwsaq->lpszServiceInstanceName = pwchName;
pwsaq->lpServiceClassId = pgProvider;
pwsaq->dwNameSpace = NS_DNS;
//
// Start the lookup
//
dwError = WSALookupServiceBeginW (pwsaq, LUP_RETURN_NAME| LUP_RETURN_ADDR, &hRnR);
if (dwError != NO_ERROR) {
dwError = GetLastError();
SmbTrace(SMB_TRACE_DNS, ("Error: %!winerr!", dwError));
goto cleanup;
}
while(1) {
dwSize = dwCurrentSize;
dwError = WSALookupServiceNextW (hRnR, 0, &dwSize, pwsaq);
if (dwError != NO_ERROR) {
dwError = GetLastError();
SmbTrace(SMB_TRACE_DNS, ("Error: %!winerr!", dwError));
if (dwError != WSAEFAULT) {
break;
}
if (dwSize <= SMB_INIT_WSA_SIZE) {
ASSERT(0);
SmbTrace(SMB_TRACE_DNS, ("\tInvalid buffer size %d returned by DNS", dwSize));
break;
}
//
// Realloc the buffer using the suggested size
//
LocalFree(pwsaq);
pwsaq = NULL;
pwsaq = (PWSAQUERYSETW)LocalAlloc(LPTR, dwSize);
if (NULL == pwsaq) {
SmbTrace(SMB_TRACE_DNS, ("\tOut of memory"));
goto cleanup;
}
dwCurrentSize = dwSize;
continue;
}
//
// Pick up the canonical name
//
SmbCopyDnsName(dns, pwsaq->lpszServiceInstanceName);
for (i = 0, pcsadr = pwsaq->lpcsaBuffer; i < pwsaq->dwNumberOfCsAddrs; i++, pcsadr++) {
switch (pcsadr->RemoteAddr.lpSockaddr->sa_family) {
case AF_INET:
if (pcsadr->RemoteAddr.iSockaddrLength >= sizeof(struct sockaddr_in)) {
pv4addr = (struct sockaddr_in*)(pcsadr->RemoteAddr.lpSockaddr);
SaveV4Address (&pv4addr->sin_addr, &pSmbDnsResult->V4AddrList);
SmbTrace(SMB_TRACE_DNS, ("\t%!ipaddr!", pv4addr->sin_addr.S_un.S_addr));
}
break;
case AF_INET6:
if (pcsadr->RemoteAddr.iSockaddrLength >= sizeof(struct sockaddr_in6)) {
pv6addr = (struct sockaddr_in6*)(pcsadr->RemoteAddr.lpSockaddr);
//
// Skip all the multicast address
//
if (IN6_IS_ADDR_MULTICAST(&pv6addr->sin6_addr)) {
SmbTrace(SMB_TRACE_DNS, ("\tSkip %!IPV6ADDR!", &pv6addr->sin6_addr));
continue;
}
if (!(dns->RequestType & SMB_DNS_AAAA_GLOBAL) &&
!SMB_IS_ADDRESS_ALLOWED(pv6addr->sin6_addr.u.Byte)) {
SmbTrace(SMB_TRACE_DNS, ("\tSkip %!IPV6ADDR!", &pv6addr->sin6_addr));
continue;
}
SaveV6Address (&pv6addr->sin6_addr, &pSmbDnsResult->V6AddrList);
SmbTrace(SMB_TRACE_DNS, ("\t%!IPV6ADDR!", &pv6addr->sin6_addr));
}
break;
default:
KDPRINT(("Skip non-IP address"));
break;
}
}
}
WSALookupServiceEnd(hRnR);
hRnR = NULL;
cleanup:
if (NULL != hRnR) {
WSALookupServiceEnd(hRnR);
hRnR = NULL;
}
LocalFree(pwsaq); // LocalFree can handle NULL case
pwsaq = NULL;
return dwError;
}
static GUID guidHostnameV6 = SVCID_DNS_TYPE_AAAA;
static GUID guidHostnameV4 = SVCID_DNS_TYPE_A;
VOID
SmbGetHostByName(
PSMB_DNS_BUFFER dns
)
/*++
Routine Description:
Resolve the name through DNS.
Arguments:
Return Value:
--*/
{
DWORD dwError = 0;
int i;
SMB_DNS_RESULT SmbDnsResult = { 0 };
dns->Resolved = FALSE;
dns->IpAddrsNum = 0;
if (dns->NameLen > DNS_MAX_NAME_LENGTH) {
SmbTrace(SMB_TRACE_DNS, ("Error: name too long %d", dns->NameLen));
KDPRINT(("Receive invalid request"));
goto cleanup;
}
dns->Name[dns->NameLen] = L'\0';
SmbTrace(SMB_TRACE_DNS, ("Resolving %ws", dns->Name));
if (dns->RequestType & SMB_DNS_AAAA_GLOBAL) {
dns->RequestType |= SMB_DNS_AAAA;
SmbTrace(SMB_TRACE_DNS, ("\tGlobal IPv6 Address is on"));
}
if (dns->RequestType & SMB_DNS_AAAA) {
KDPRINT(("Looking for AAAA record for %ws", dns->Name));
SmbTrace(SMB_TRACE_DNS, ("\tLookup AAAA record"));
dwError = LookupDns(dns->Name, &guidHostnameV6, dns, &SmbDnsResult);
SmbTrace(SMB_TRACE_DNS, ("\tFound %d IPv6 address", SmbDnsResult.V6AddrList.NumAddrs));
if (SmbDnsResult.V6AddrList.NumAddrs > 0) {
SortIPAddrs(
AF_INET6,
(LPVOID)SmbDnsResult.V6AddrList.pAddrs,
SmbDnsResult.V6AddrList.NumAddrs,
sizeof(SOCKADDR_IN6),
&SmbDnsResult.pSortedV6AddrList
);
}
}
if (dns->RequestType & SMB_DNS_A) {
KDPRINT(("Looking for A record for %ws", dns->Name));
SmbTrace(SMB_TRACE_DNS, ("\tLookup A record"));
dwError = LookupDns(dns->Name, &guidHostnameV4, dns, &SmbDnsResult);
SmbTrace(SMB_TRACE_DNS, ("\tFound %d IPv4 address", SmbDnsResult.V4AddrList.NumAddrs));
if (SmbDnsResult.V4AddrList.NumAddrs > 0) {
SortIPAddrs(
AF_INET,
(LPVOID)SmbDnsResult.V4AddrList.pAddrs,
SmbDnsResult.V4AddrList.NumAddrs,
sizeof(SOCKADDR_IN),
&SmbDnsResult.pSortedV4AddrList
);
}
}
cleanup:
if (SmbDnsResult.pSortedV6AddrList) {
for (i = 0; i < SmbDnsResult.pSortedV6AddrList->iAddressCount; i++) {
SmbReturnIPAddress (dns, &SmbDnsResult.pSortedV6AddrList->Address[i]);
}
}
if (SmbDnsResult.pSortedV4AddrList) {
//
// Make sure we have an IPv4 address to failover
//
if (dns->IpAddrsNum >= SMB_MAX_IPADDRS_PER_HOST &&
SmbDnsResult.pSortedV4AddrList->iAddressCount > 0) {
dns->IpAddrsNum = SMB_MAX_IPADDRS_PER_HOST / 2;
SmbTrace(SMB_TRACE_DNS, ("\tTruncate IPv6 address to give room to IPv4"));
}
for (i = 0; i < SmbDnsResult.pSortedV4AddrList->iAddressCount; i++) {
SmbReturnIPAddress (dns, &SmbDnsResult.pSortedV4AddrList->Address[i]);
}
}
SmbTrace(SMB_TRACE_DNS, ("\tNum of IP address returned: %d", dns->IpAddrsNum));
CleanupDnsResult (&SmbDnsResult);
}
#endif // __USE_GETADDRINFO__
DWORD
SmbGetHostThread(
LPVOID ctx
)
{
NTSTATUS status;
DWORD Error;
HANDLE WaitObject[2];
SMB_DNS_BUFFER Dns;
IO_STATUS_BLOCK iosb;
#define TERMINATE_EVENT (WAIT_OBJECT_0)
#define IO_EVENT (WAIT_OBJECT_0+1)
SmbTrace(SMB_TRACE_DNS, ("Thread start: detecting NetbiosSmb driver"));
WaitObject[TERMINATE_EVENT] = hTerminateEvent;
while(1) {
DWORD dwLocalError = ERROR_SUCCESS;
WaitObject[IO_EVENT] = OpenSmb(DD_SMB6_EXPORT_NAME);
if (WaitObject[IO_EVENT] != NULL) {
break;
}
Error = GetLastError();
//
// Wait for 60 seconds and try again
//
dwLocalError = WaitForSingleObject(hTerminateEvent, 60000);
if (dwLocalError != WAIT_TIMEOUT) {
return Error;
}
}
SmbTrace(SMB_TRACE_DNS, ("Thread start: NetbiosSmb driver detected"));
while (1) {
status = NtDeviceIoControlFile(
WaitObject[IO_EVENT], // Handle
NULL, // Wait Event
NULL, // ApcRoutine
NULL, // ApcContext
&iosb, // IoStatusBlock
IOCTL_SMB_DNS, // IoControlCode
&Dns, // InputBuffer
sizeof(Dns), // InputBufferSize
&Dns, // OutputBuffer
sizeof(Dns) // OutputBufferSize
);
if (status == STATUS_PENDING) {
Error = WaitForMultipleObjects(
2,
WaitObject,
FALSE,
INFINITE
);
if (Error == TERMINATE_EVENT) {
IO_STATUS_BLOCK iosb2;
KDPRINT(("Receive terminate event, cancelling request"));
status = NtCancelIoFile(WaitObject[IO_EVENT], &iosb2);
KDPRINT(("NtCancelIoFile: status 0x%08lx, iosb.Status 0x%08lx", status, iosb.Status));
break;
}
ASSERT(Error == IO_EVENT);
status = iosb.Status;
ASSERT(status != STATUS_PENDING);
}
SmbTrace(SMB_TRACE_DNS, ("%!status!", status));
//
// Bail out immediately if the underlying driver tell us quota exceeded
//
if (status == STATUS_QUOTA_EXCEEDED) {
KDPRINT(("NtDeviceIoControlFile: STATUS_QUOTA_EXCEEDED"));
break;
}
if (status != STATUS_SUCCESS) {
KDPRINT(("NtDeviceIoControlFile: status = 0x%08lx", status));
Error = WaitForSingleObject(hTerminateEvent, 5000);
if (Error == WAIT_OBJECT_0) {
KDPRINT(("Receive terminate event"));
break;
}
ASSERT(Error == WAIT_TIMEOUT);
} else {
ASSERT(iosb.Information == sizeof(Dns));
SmbGetHostByName(&Dns);
}
}
NtClose(WaitObject[IO_EVENT]);
SmbTrace(SMB_TRACE_DNS, ("Thread exit"));
KDPRINT(("Thread exit"));
return ERROR_SUCCESS;
}
LONG ThreadNumber;
HANDLE hThread[DNS_MAX_RESOLVER];
VOID
SmbSetTraceRoutine(
int (*trace)(char *,...)
)
{
#if DBG
SmbDbgPrint = trace;
#endif
}
DWORD
SmbStartService(
LONG NumWorker,
SMBSVC_UPDATE_STATUS HeartBeating
)
{
WSADATA WsaData;
int i;
DWORD Error = ERROR_SUCCESS;
if (NumWorker > DNS_MAX_RESOLVER || NumWorker < 0) {
SmbTrace(SMB_TRACE_DNS, ("Error: %d worker threads", NumWorker));
return ERROR_INVALID_PARAMETER;
}
if (NumWorker == 0) {
NumWorker = DNS_MAX_RESOLVER;
}
SmbTrace(SMB_TRACE_DNS, ("Starting %d worker threads", NumWorker));
if (WSAStartup(MAKEWORD(2, 0), &WsaData) == SOCKET_ERROR) {
Error = WSAGetLastError();
KDPRINT (("Failed to startup Winsock2"));
SmbTrace(SMB_TRACE_DNS, ("Error: %!winerr!", Error));
return Error;
}
hTerminateEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
if (NULL == hTerminateEvent) {
Error = GetLastError();
KDPRINT(("Create hTerminateEvent returns %d", Error));
SmbTrace(SMB_TRACE_DNS, ("Error: %!winerr!", Error));
return Error;
}
if (NULL != HeartBeating) {
HeartBeating();
}
Error = ERROR_SUCCESS;
for (i = 0; i < NumWorker; i++) {
hThread[i] = CreateThread(
NULL,
8192,
(LPTHREAD_START_ROUTINE)SmbGetHostThread,
NULL,
0,
NULL
);
if (NULL != HeartBeating) {
HeartBeating();
}
if (NULL == hThread[i]) {
Error = GetLastError();
KDPRINT(("CreateThread returns %d", Error));
SmbTrace(SMB_TRACE_DNS, ("Error: %!winerr!", Error));
break;
}
}
ThreadNumber = i;
if (i > 0) {
Error = ERROR_SUCCESS;
}
SmbTrace(SMB_TRACE_DNS, ("Thread Num: %d: %!winerr!", ThreadNumber, Error));
return Error;
}
VOID
SmbStopService(
SMBSVC_UPDATE_STATUS HeartBeating
)
{
DWORD Error;
LONG i;
SmbTrace(SMB_TRACE_DNS, ("Stop request received"));
if (NULL == hTerminateEvent) {
SmbTrace(SMB_TRACE_DNS, ("Already Stopped"));
return;
}
SmbTrace(SMB_TRACE_DNS, ("Stopping"));
KDPRINT(("Send terminate event"));
SetEvent(hTerminateEvent);
do {
Error = WaitForMultipleObjects(
ThreadNumber,
hThread,
TRUE,
1000
);
if (NULL != HeartBeating) {
HeartBeating();
}
} while(Error == WAIT_TIMEOUT);
for (i = 0; i < ThreadNumber; i++) {
CloseHandle(hThread[i]);
hThread[i] = NULL;
}
ThreadNumber = 0;
CloseHandle(hTerminateEvent);
hTerminateEvent = NULL;
if (NULL != HeartBeating) {
HeartBeating();
}
SmbTrace(SMB_TRACE_DNS, ("Stopped"));
}
HANDLE
OpenSmb(
LPWSTR Name
)
{
UNICODE_STRING ucName;
OBJECT_ATTRIBUTES ObAttr;
HANDLE StreamHandle;
IO_STATUS_BLOCK IoStatusBlock;
NTSTATUS status;
RtlInitUnicodeString(&ucName, Name);
InitializeObjectAttributes(
&ObAttr,
&ucName,
OBJ_CASE_INSENSITIVE,
(HANDLE) NULL,
(PSECURITY_DESCRIPTOR) NULL
);
status = NtCreateFile (
&StreamHandle,
SYNCHRONIZE | FILE_READ_DATA | FILE_WRITE_DATA,
&ObAttr,
&IoStatusBlock,
NULL,
FILE_ATTRIBUTE_NORMAL,
FILE_SHARE_READ | FILE_SHARE_WRITE,
FILE_OPEN_IF,
0,
NULL,
0
);
if (status != STATUS_SUCCESS) {
SetLastError(RtlNtStatusToDosError(status));
KDPRINT(("Open device %ws: status = 0x%08lx", Name, status));
return NULL;
}
SetLastError(ERROR_SUCCESS);
return StreamHandle;
}