|
|
/*++
Copyright (c) 2000 Microsoft Corporation
Module Name:
srv.c
Abstract:
Implements initialization and socket interface for smb server
Author:
Ahmed Mohamed (ahmedm) 1-Feb-2000
Revision History:
--*/
#include "srv.h"
#include <process.h> // for _beginthreadex
#include <mswsock.h>
#include <string.h> // for strlen, strncat
// Socket type passed to socket() when ListenSocket creates the NetBIOS listen socket.
#define PROTOCOL_TYPE SOCK_SEQPACKET
// NOTE(review): not referenced in the visible part of this file -- confirm it is still needed.
#define PLUS_CLUSTER 1
// Return type and calling convention of the thread entry points handed to _beginthreadex.
#define THREADAPI unsigned int WINAPI
// Forward declaration: PacketFree calls SrvCloseEndpoint before its definition below.
void SrvCloseEndpoint(EndPoint_t *endpoint);
// Rebuilds ctx->freelist from the preallocated pools.
// Every one of the MAX_PACKETS descriptors in ctx->packet_pool receives two
// consecutive SRV_PACKET_SIZE slices of ctx->buffer_pool (first as 'buffer',
// second as 'outbuf'), has its overlapped event cleared, and is pushed onto
// the head of the free list.
void PacketReset(SrvCtx_t *ctx)
{
    int count = MAX_PACKETS;
    int idx;
    Packet_t *pkt = (Packet_t *) ctx->packet_pool;
    char *slice = (char *) ctx->buffer_pool;

    ctx->freelist = NULL;

    for (idx = 0; idx < count; idx++, pkt++) {
        pkt->buffer = (LPVOID) slice;
        pkt->ov.hEvent = NULL;
        slice += SRV_PACKET_SIZE;

        pkt->outbuf = (LPVOID) slice;
        slice += SRV_PACKET_SIZE;

        // push onto the head of the free list
        pkt->next = ctx->freelist;
        ctx->freelist = pkt;
    }
}
// Allocates the packet descriptor pool and the buffer pool (two
// SRV_PACKET_SIZE buffers per packet) and builds the free list.
// Returns TRUE on success. On failure returns FALSE and leaves no pool
// allocated.
BOOL PacketInit(SrvCtx_t *ctx)
{
    int npackets = MAX_PACKETS;
    int nbufs = npackets * 2;   // Allocate 2 buffers for each packet

    ctx->packet_pool = xmalloc(sizeof(Packet_t) * npackets);
    if (ctx->packet_pool == NULL) {
        SrvLogError(("Unable to allocate packet pool!\n"));
        return FALSE;
    }

    ctx->buffer_pool = xmalloc(SRV_PACKET_SIZE * nbufs);
    if (ctx->buffer_pool == NULL) {
        xfree(ctx->packet_pool);
        // SrvExit frees packet_pool whenever it is non-NULL, so do not
        // leave a dangling pointer behind (would be a double free).
        ctx->packet_pool = NULL;
        SrvLogError(("Unable to allocate buffer pool!\n"));
        return FALSE;
    }

    PacketReset(ctx);

    return TRUE;
}
// Takes a packet off the server free list and links it onto the endpoint's
// packet list. When the free list is empty the caller is counted as a waiter
// and blocks on ctx->event, then tries again.
// Returns NULL if the server is not running or if the wait fails.
Packet_t *
PacketAlloc(EndPoint_t *ep)
{
    SrvCtx_t *srv;
    Packet_t *pkt;

    ASSERT(ep);
    srv = ep->SrvCtx;

    for (;;) {
        EnterCriticalSection(&srv->cs);

        if (srv->running == FALSE) {
            LeaveCriticalSection(&srv->cs);
            return NULL;
        }

        pkt = srv->freelist;
        if (pkt != NULL) {
            // got one; leave the loop with the lock still held
            break;
        }

        // nothing available: register as a waiter and sleep outside the lock
        srv->waiters++;
        LeaveCriticalSection(&srv->cs);
        if (WaitForSingleObject(srv->event, INFINITE) != WAIT_OBJECT_0) {
            return NULL;
        }
    }

    srv->freelist = pkt->next;

    // Insert into per endpoint packet list
    pkt->endpoint = ep;
    pkt->next = ep->PacketList;
    ep->PacketList = pkt;

    LeaveCriticalSection(&srv->cs);

    return pkt;
}
// Pushes a packet back onto the head of the server free list and, if anyone
// is counted as waiting for a packet, consumes one waiter count and signals
// ctx->event. Takes no lock itself; the callers in this file invoke it with
// ctx->cs held.
void PacketRelease(SrvCtx_t *ctx, Packet_t *p)
{
    p->next = ctx->freelist;
    ctx->freelist = p;

    if (ctx->waiters <= 0) {
        return;
    }

    ctx->waiters--;
    SetEvent(ctx->event);
}
// Returns a packet to the server free list. Under ctx->cs the packet is
// unlinked from its endpoint's packet list and released; if that leaves the
// endpoint with no packets at all, the endpoint itself is closed and freed.
void PacketFree(Packet_t *p)
{
    EndPoint_t *owner = p->endpoint;
    SrvCtx_t *srv;
    Packet_t **link;

    ASSERT(owner);
    srv = owner->SrvCtx;
    ASSERT(srv);

    EnterCriticalSection(&srv->cs);

    // Remove packet from ep list
    for (link = &owner->PacketList; *link != NULL; link = &(*link)->next) {
        if (*link == p) {
            *link = p->next;
            break;
        }
    }

    // head of freelist; wakes a waiter if there is one
    PacketRelease(srv, p);

    if (owner->PacketList == NULL) {
        // last packet gone: free this endpoint
        SrvCloseEndpoint(owner);
    }

    LeaveCriticalSection(&srv->cs);
}
// Handles one received packet on an endpoint.
// Data that IsSmb() rejects is ignored. Otherwise the in/out SMB descriptors
// are set up and the request is dispatched:
//   - dispatch returns ERROR_IO_PENDING: returned to the caller unchanged,
//     nothing is sent here;
//   - dispatch succeeds: the reply in p->out is sent on the endpoint socket
//     (the socket is closed if the send fails or is short);
//   - dispatch fails: the socket is closed.
// Every case except the pending one returns ERROR_SUCCESS.
int ProcessPacket(EndPoint_t *ep, Packet_t *p)
{
    BOOL handled;
    char *reply;
    int replyLen;
    int sent;

    if (!IsSmb(p->buffer, p->len)) {
        return ERROR_SUCCESS;
    }

    // request view
    p->in.smb = (PNT_SMB_HEADER)p->buffer;
    p->in.size = p->len;
    p->in.offset = sizeof(NT_SMB_HEADER);
    p->in.command = p->in.smb->Command;

    // reply view
    p->out.smb = (PNT_SMB_HEADER)p->outbuf;
    p->out.size = SRV_PACKET_SIZE;
    p->out.valid = sizeof(NT_SMB_HEADER);
    InitSmbHeader(p);

    DumpSmb(p->buffer, p->len, TRUE);

    SrvLog(("dispatching Tid:%d Uid:%d Mid:%d Flags:%x Cmd:%d...\n", p->in.smb->Tid, p->in.smb->Uid, p->in.smb->Mid, p->in.smb->Flags2, p->in.command));

    p->tag = 0;

    handled = SrvDispatch(p);

    if (handled == ERROR_IO_PENDING) {
        return ERROR_IO_PENDING;
    }

    if (!handled) {
        SrvLog(("dispatch failed!\n"));
        // did not understand...hangup on virtual circuit...
        SrvLog(("hangup! -- disp failed on sock %s\n", ep->ClientId));
        closesocket(ep->Sock);
        return ERROR_SUCCESS;
    }

    // handled ok: push the reply out
    reply = (char *)p->out.smb;
    replyLen = (int) p->out.valid;
    DumpSmb(reply, replyLen, FALSE);

    SrvLog(("sending...len %d\n", replyLen));

    sent = send(ep->Sock, reply, replyLen, 0);
    if (sent == SOCKET_ERROR || sent != replyLen) {
        SrvLog(("Send clnt failed %d\n", WSAGetLastError()));
        closesocket(ep->Sock);
    }

    return ERROR_SUCCESS;
}
// Worker thread servicing the server's I/O completion port.
// arg is the SrvCtx_t of the server. The thread loops while ctx->running:
// it dequeues a completion, treats the OVERLAPPED pointer as the Packet_t
// that was posted (assumes 'ov' is the first member of Packet_t -- TODO
// confirm against srv.h) and the completion key as the EndPoint_t, processes
// the packet and posts the next read on the same packet.
// A completion with a NULL OVERLAPPED (the kill packet posted by SrvOffline,
// or a failed dequeue) terminates the thread. Always returns 0.
THREADAPI CompletionThread(LPVOID arg)
{
    SrvCtx_t *ctx = (SrvCtx_t *) arg;
    HANDLE port = ctx->comport;
    HANDLE ev;
    Packet_t *p;
    EndPoint_t *endpoint;
    LPOVERLAPPED lpo;
    ULONG_PTR id;
    DWORD len;
    BOOL b;

    // Each thread needs its own event, msg to use
    ev = CreateEvent(NULL, FALSE, FALSE, NULL);

    while (ctx->running) {
        // The outputs are indeterminate when nothing is dequeued, so start
        // from known values.
        lpo = NULL;
        id = 0;
        len = 0;

        b = GetQueuedCompletionStatus(port, &len, &id, &lpo, INFINITE);

        p = (Packet_t *) lpo;
        if (p == NULL) {
            // No packet was dequeued: either the dequeue itself failed or
            // this is the kill packet. There is nothing to free in either
            // case. (The old "!b && !lpo" branch sat after this return, so
            // it was unreachable, and it would have passed NULL to
            // PacketFree.)
            if (!b) {
                SrvLog(("Getqueued failed %d\n", GetLastError()));
            } else {
                SrvLog(("SrvThread exiting, %x...\n", id));
            }
            CloseHandle(ev);
            return 0;
        }

        // NOTE(review): a failed I/O completion (b == FALSE with a packet)
        // deliberately falls through as before: the packet is processed with
        // the reported length and the follow-up ReadFile decides whether it
        // is returned to the pool. Confirm this is the intended handling.

        // todo: when socket is closed, I need to free this endpoint.
        // I need to tag the endpoint with how many packets got scheduled
        // on it, when the refcnt reaches zero, I free it.
        endpoint = (EndPoint_t *) id;

        ASSERT(p->endpoint == endpoint);
        p->ev = ev;
        p->len = len;

        if (ProcessPacket(endpoint, p) != ERROR_IO_PENDING) {
            // schedule next read
            b = ReadFile((HANDLE)endpoint->Sock, p->buffer, SRV_PACKET_SIZE, &len, &p->ov);

            if (!b && GetLastError() != ERROR_IO_PENDING) {
                SrvLog(("SrvThread read ep 0x%x failed %d\n", endpoint, GetLastError()));
                // Return packet to queue
                PacketFree(p);
            }
        }
    }

    CloseHandle(ev);
    SrvLog(("SrvThread exiting, not running...\n"));
    return 0;
}
// Completes a request whose dispatch was left pending (p->tag must be
// ERROR_IO_PENDING on entry): clears the tag, sends the reply held in p->out
// on the packet's endpoint socket and posts the next read on the packet.
// A failed send is only logged. If the read cannot be posted the packet is
// returned to the pool.
void SrvFinalize(Packet_t *p)
{
    EndPoint_t *ep = p->endpoint;
    char *reply;
    DWORD replyLen;
    DWORD sent;
    BOOL posted;

    ASSERT(p->tag == ERROR_IO_PENDING);

    p->tag = 0;

    reply = (char *)p->out.smb;
    replyLen = (DWORD) p->out.valid;
    DumpSmb(reply, replyLen, FALSE);

    SrvLog(("sending...len %d\n", replyLen));
    sent = send(ep->Sock, reply, replyLen, 0);
    if (sent == SOCKET_ERROR || sent != replyLen) {
        SrvLog(("Finalize Send clnt failed <%d>\n", WSAGetLastError()));
    }

    // re-arm the packet for the next request
    posted = ReadFile((HANDLE)ep->Sock, p->buffer, SRV_PACKET_SIZE, &replyLen, &p->ov);
    if (posted || GetLastError() == ERROR_IO_PENDING) {
        return;
    }

    // Return packet to queue
    PacketFree(p);
}
// Tears down an endpoint: releases every packet still on its list back to
// the server free list, unlinks the endpoint from the server's endpoint
// list, closes its socket, logs its user off from the file system and frees
// the endpoint memory. The endpoint must not be used afterwards.
// The server lock (ctx->cs) must be held by the caller.
void SrvCloseEndpoint(EndPoint_t *endpoint)
{
    SrvCtx_t *srv = endpoint->SrvCtx;
    Packet_t *pkt;
    EndPoint_t **link;

    // return all packets to the free list now
    for (pkt = endpoint->PacketList; pkt != NULL; pkt = endpoint->PacketList) {
        endpoint->PacketList = pkt->next;
        PacketRelease(srv, pkt);
    }

    // remove from ctx list
    for (link = &srv->EndPointList; *link != NULL; link = &(*link)->Next) {
        if (*link == endpoint) {
            *link = endpoint->Next;
            break;
        }
    }

    closesocket(endpoint->Sock);

    // We need to inform filesystem that this
    // tree is gone.
    FsLogoffUser(srv->FsCtx, endpoint->LogonId);

    free(endpoint);
}
// Creates the NetBIOS listen socket for the server: builds a unique-name
// address from ctx->nb_local_name, creates an AF_NETBIOS socket (the negated
// nic is passed as the protocol argument -- presumably selecting the adapter;
// verify), binds it and starts listening with a backlog of 5.
// On success the socket is stored in ctx->listen_socket and ERROR_SUCCESS is
// returned; otherwise the Winsock error of the failing step is returned and
// no socket is left open.
DWORD ListenSocket(SrvCtx_t *ctx, int nic)
{
    unsigned char *name = ctx->nb_local_name;
    struct sockaddr_nb addr;
    SOCKET s;
    DWORD status;

    SET_NETBIOS_SOCKADDR(&addr, NETBIOS_UNIQUE_NAME, name, ' ');

    s = socket(AF_NETBIOS, PROTOCOL_TYPE, -nic);
    if (s == INVALID_SOCKET) {
        status = WSAGetLastError();
        SrvLogError(("socket() '%s' nic %d failed with error %d\n", name, nic, status));
        return status;
    }

    // bind socket
    if (bind(s, (struct sockaddr *)&addr, sizeof(addr)) == SOCKET_ERROR) {
        status = WSAGetLastError();
        SrvLogError(("srv nic %d bind() failed with error %d\n",nic, status));
        goto fail;
    }

    // issue listen
    if (listen(s, 5) == SOCKET_ERROR) {
        status = WSAGetLastError();
        SrvLogError(("listen() failed with error %d\n", status));
        goto fail;
    }

    // all is well.
    ctx->listen_socket = s;
    return ERROR_SUCCESS;

fail:
    closesocket(s);
    return status;
}
// Listener thread. arg is the SrvCtx_t of the server.
// While ctx->running it accepts connections on ctx->listen_socket, rejects
// every caller whose NetBIOS name differs (case-insensitively) from the
// local host name, and for each accepted caller creates an endpoint, binds
// its socket to the completion port and posts up to SRV_NUM_WORKERS reads.
// The thread ends when accept() fails or the server stops. Always returns 0.
THREADAPI ListenThread(LPVOID arg)
{
    SrvCtx_t *ctx = (SrvCtx_t *) arg;
    SOCKET listen_socket, msgsock;
    struct sockaddr_nb from;
    int fromlen;
    HANDLE comport;
    EndPoint_t *endpoint;
    char localname[64];

    // Never compare against uninitialized memory if the lookup fails; an
    // empty local name effectively fences off every named caller.
    localname[0] = '\0';
    if (gethostname(localname, sizeof(localname)) == SOCKET_ERROR) {
        SrvLogError(("gethostname() failed with error %d\n", WSAGetLastError()));
        localname[0] = '\0';
    }

    listen_socket = ctx->listen_socket;
    comport = ctx->comport;

    while (ctx->running) {
        int i;
        char *s;

        fromlen = sizeof(from);
        msgsock = accept(listen_socket, (struct sockaddr *)&from, &fromlen);
        if (msgsock == INVALID_SOCKET) {
            if (ctx->running)
                SrvLogError(("accept() error %d\n",WSAGetLastError()));
            break;
        }

        // terminate the caller name and strip the blank padding
        from.snb_name[NETBIOS_NAME_LENGTH-1] = '\0';
        s = strchr(from.snb_name, ' ');
        if (s != NULL)
            *s = '\0';
        SrvLog(("Received call from '%s'\n", from.snb_name));

        // Fence off all nodes except cluster nodes. We ask
        // our resource to check for us. For now we fence off all nodes but this node
        if (_stricmp(localname, from.snb_name)) {
            // sorry, we just close the connection now
            closesocket(msgsock);
            continue;
        }

        // allocate a new endpoint
        endpoint = (EndPoint_t *) malloc(sizeof(*endpoint));
        if (endpoint == NULL) {
            SrvLogError(("Failed allocate failed %d\n", GetLastError()));
            closesocket(msgsock);
            continue;
        }
        memset(endpoint, 0, sizeof(*endpoint));

        // Fill in the endpoint completely BEFORE publishing it: once it is
        // on ctx->EndPointList, SrvOffline may walk the list and close
        // ep->Sock, which must already be the real socket.
        endpoint->Sock = msgsock;
        endpoint->SrvCtx = ctx;
        // NOTE(review): copies sizeof(ClientId) bytes out of from.snb_name;
        // confirm ClientId is no larger than the snb_name array.
        memcpy(endpoint->ClientId, from.snb_name, sizeof(endpoint->ClientId));

        // add endpoint now
        EnterCriticalSection(&ctx->cs);
        endpoint->Next = ctx->EndPointList;
        ctx->EndPointList = endpoint;
        LeaveCriticalSection(&ctx->cs);

        comport = CreateIoCompletionPort((HANDLE)msgsock, comport, (ULONG_PTR)endpoint, 8);
        if (!comport) {
            SrvLogError(("CompletionPort bind Failed %d\n", GetLastError()));
            // SrvCloseEndpoint edits the shared endpoint list and free list
            // and requires the server lock to be held.
            EnterCriticalSection(&ctx->cs);
            SrvCloseEndpoint(endpoint);
            LeaveCriticalSection(&ctx->cs);
            comport = ctx->comport;
            continue;
        }

        for (i = 0; i < SRV_NUM_WORKERS; i++) {
            Packet_t *p;
            BOOL b;
            DWORD nbytes;

            p = PacketAlloc(endpoint);
            if (p == NULL) {
                SrvLog(("Listen thread got null packet, exiting posted...\n"));
                break;
            }

            b = ReadFile((HANDLE) msgsock, p->buffer, SRV_PACKET_SIZE, &nbytes, &p->ov);

            if (!b && GetLastError() != ERROR_IO_PENDING) {
                SrvLog(("Srv ReadFile Failed %d\n", GetLastError()));
                // Return packet to queue. This may also free the endpoint if
                // it was its last packet, so the endpoint is not touched
                // again after this point.
                PacketFree(p);
                break;
            }
        }
    }

    return (0);
}
// Creates and initializes a server context.
//   resHdl, fsHdl - opaque handles stored in the context (resHdl, FsCtx).
//   Hdl           - receives the new context on success; untouched on failure.
// Initializes LSA, Winsock, the context lock, the packet-wait event and the
// packet pools. Returns ERROR_SUCCESS, or an error code after undoing the
// Winsock/lock/event/memory setup done here.
// NOTE(review): the LSA handle obtained by LsaInit is not released on the
// later failure paths; no matching teardown routine is visible in this file.
DWORD SrvInit(PVOID resHdl, PVOID fsHdl, PVOID *Hdl)
{
    SrvCtx_t *ctx;
    DWORD err;
    int wsaErr;

    ctx = (SrvCtx_t *) malloc(sizeof(*ctx));
    if (ctx == NULL) {
        return ERROR_NOT_ENOUGH_MEMORY;
    }

    memset(ctx, 0, sizeof(*ctx));

    ctx->FsCtx = fsHdl;
    ctx->resHdl = resHdl;

    // init lsa now
    err = LsaInit(&ctx->LsaHandle, &ctx->LsaPack);
    if (err != ERROR_SUCCESS) {
        SrvLogError(("LsaInit failed with error %x\n", err));
        free(ctx);
        return err;
    }

    // init winsock now. WSAStartup reports failure through its return value
    // (non-zero error code); it never returns SOCKET_ERROR and
    // WSAGetLastError must not be used after it fails.
    wsaErr = WSAStartup(0x202, &ctx->wsaData);
    if (wsaErr != 0) {
        err = (DWORD) wsaErr;
        SrvLogError(("WSAStartup failed with error %d\n", err));
        free(ctx);
        return err;
    }

    InitializeCriticalSection(&ctx->cs);
    ctx->running = FALSE;
    ctx->waiters = 0;

    // PacketAlloc blocks on this event, so it must exist.
    ctx->event = CreateEvent(NULL, FALSE, FALSE, NULL);
    if (ctx->event == NULL) {
        err = GetLastError();
        SrvLogError(("CreateEvent failed with error %d\n", err));
        goto fail;
    }

    if (PacketInit(ctx) != TRUE) {
        err = ERROR_NO_SYSTEM_RESOURCES;
        goto fail;
    }

    SrvUtilInit(ctx);

    *Hdl = (PVOID) ctx;
    return ERROR_SUCCESS;

fail:
    // Nothing has been handed to the caller yet, so release everything that
    // was set up above instead of leaking it.
    if (ctx->event != NULL) {
        CloseHandle(ctx->event);
    }
    DeleteCriticalSection(&ctx->cs);
    WSACleanup();
    free(ctx);
    return err;
}
DWORD SrvOnline(PVOID Hdl, LPWSTR name, DWORD nic) { SrvCtx_t *ctx = (SrvCtx_t *) Hdl; DWORD err; int i; int nFixedThreads = 1; char localname[128]; SYSTEM_INFO sysinfo;
if (ctx == NULL) { return ERROR_INVALID_PARAMETER; }
if (nic > 0) nic--;
//
// Start up threads in suspended mode
//
if (ctx->running == TRUE) return ERROR_SUCCESS;
// save name to use
if (name != NULL) { // we need to translate name to ascii
i = wcstombs(localname, name, NETBIOS_NAME_LENGTH-1); localname[i] = '\0'; strncpy(ctx->nb_local_name, localname, NETBIOS_NAME_LENGTH); } else { // use local name and append our -crs extension
gethostname(localname, sizeof(localname)); strcat(localname, SRV_NAME_EXTENSION); strncpy(ctx->nb_local_name, localname, NETBIOS_NAME_LENGTH); }
for (i = 0; i < NETBIOS_NAME_LENGTH; i++) { ctx->nb_local_name[i] = (char) toupper(ctx->nb_local_name[i]); } // create completion port
GetSystemInfo(&sysinfo); ctx->comport = CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0, sysinfo.dwNumberOfProcessors*8); if (ctx->comport == INVALID_HANDLE_VALUE) { err = GetLastError(); SrvLogError(("Unable to create completion port %d\n", err)); WSACleanup(); return err; }
// create listen socket
ctx->nic = nic; err = ListenSocket(ctx, nic); if ( err != ERROR_SUCCESS) { WSACleanup(); return err; }
// start up 1 listener/receiver, a few workers, a few senders....
ctx->nThreads = nFixedThreads + SRV_NUM_SENDERS; ctx->hThreads = (HANDLE *) malloc(sizeof(HANDLE) * ctx->nThreads); if (ctx->hThreads == NULL) { WSACleanup(); return ERROR_NOT_ENOUGH_MEMORY; }
for (i = 0; i < nFixedThreads; i++) { ctx->hThreads[i] = (HANDLE) _beginthreadex(NULL, 0, &ListenThread, (LPVOID)ctx, CREATE_SUSPENDED, NULL); } for ( ; i < ctx->nThreads; i++) { ctx->hThreads[i] = (HANDLE) _beginthreadex(NULL, 0, &CompletionThread, (LPVOID)ctx, CREATE_SUSPENDED, NULL); }
ctx->running = TRUE; for (i = 0; i < ctx->nThreads; i++) ResumeThread(ctx->hThreads[i]);
return ERROR_SUCCESS; }
// Takes the server offline. Does nothing (ERROR_SUCCESS) if it is not
// running. Otherwise, in this order: clears ctx->running, closes the listen
// socket and every endpoint socket, posts one kill packet per thread to the
// completion port, waits for all threads only if every post succeeded,
// closes the thread handles and the port, frees the handle array and finally
// closes all remaining endpoints.
// Returns ERROR_INVALID_PARAMETER for a NULL handle, else ERROR_SUCCESS.
DWORD SrvOffline(PVOID Hdl)
{
    SrvCtx_t *ctx = (SrvCtx_t *) Hdl;
    EndPoint_t *ep;
    int posted;
    int t;

    if (ctx == NULL) {
        return ERROR_INVALID_PARAMETER;
    }

    if (!ctx->running) {
        return ERROR_SUCCESS;
    }

    // we shutdown all threads in the completion port
    // we close all currently open sockets
    // we free all memory
    ctx->running = FALSE;
    closesocket(ctx->listen_socket);

    EnterCriticalSection(&ctx->cs);
    ep = ctx->EndPointList;
    while (ep != NULL) {
        closesocket(ep->Sock);
        ep = ep->Next;
    }
    LeaveCriticalSection(&ctx->cs);
    SrvLog(("waiting for threads to die off...\n"));

    // send a kill packet to all threads on the completion port
    posted = 0;
    while (posted < ctx->nThreads) {
        if (!PostQueuedCompletionStatus(ctx->comport, 0, 0, NULL)) {
            SrvLog(("Port queued port failed %d\n", GetLastError()));
            break;
        }
        posted++;
    }

    if (posted == ctx->nThreads) {
        // wait for them to die of natural causes before we kill them...
        WaitForMultipleObjects(ctx->nThreads, ctx->hThreads, TRUE, INFINITE);
    }

    // close handles
    for (t = 0; t < ctx->nThreads; t++) {
        CloseHandle(ctx->hThreads[t]);
    }

    CloseHandle(ctx->comport);

    free((char *)ctx->hThreads);

    // free endpoints; each call unlinks the head, so keep taking the head
    EnterCriticalSection(&ctx->cs);
    for (ep = ctx->EndPointList; ep != NULL; ep = ctx->EndPointList) {
        SrvCloseEndpoint(ep);
    }
    LeaveCriticalSection(&ctx->cs);

    return ERROR_SUCCESS;
}
// Destroys a server context: runs SrvUtilExit, frees the packet and buffer
// pools if they were allocated, then frees the context itself.
// A NULL handle is ignored.
void SrvExit(PVOID Hdl)
{
    SrvCtx_t *ctx = (SrvCtx_t *) Hdl;

    if (ctx == NULL) {
        return;
    }

    SrvUtilExit(ctx);

    // must do this last!
    if (ctx->packet_pool != NULL) {
        xfree(ctx->packet_pool);
    }
    if (ctx->buffer_pool != NULL) {
        xfree(ctx->buffer_pool);
    }

    free(ctx);
}
|