Source code of Windows XP (NT5)
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

719 lines
15 KiB

  1. /*++
  2. Copyright (c) 2000 Microsoft Corporation
  3. Module Name:
  4. srv.c
  5. Abstract:
  6. Implements initialization and socket interface for smb server
  7. Author:
  8. Ahmed Mohamed (ahmedm) 1-Feb-2000
  9. Revision History:
  10. --*/
  11. #include "srv.h"
  12. #include <process.h> // for _beginthreadex
  13. #include <mswsock.h>
  14. #define PROTOCOL_TYPE SOCK_SEQPACKET
  15. #define PLUS_CLUSTER 1
  16. #define THREADAPI unsigned int WINAPI
  17. void
  18. SrvCloseEndpoint(EndPoint_t *endpoint);
  19. void
  20. PacketReset(SrvCtx_t *ctx)
  21. {
  22. int i, npackets, nbufs;
  23. Packet_t *p;
  24. char *buf;
  25. npackets = MAX_PACKETS;
  26. nbufs = npackets * 2;
  27. ctx->freelist = NULL;
  28. p = (Packet_t *) ctx->packet_pool;
  29. buf = (char *) ctx->buffer_pool;
  30. for (i = 0; i < npackets; i++) {
  31. p->buffer = (LPVOID) buf;
  32. p->ov.hEvent = NULL;
  33. buf += SRV_PACKET_SIZE;
  34. p->outbuf = (LPVOID) buf;
  35. buf += SRV_PACKET_SIZE;
  36. p->next = ctx->freelist;
  37. ctx->freelist = p;
  38. p++;
  39. }
  40. }
  41. BOOL
  42. PacketInit(SrvCtx_t *ctx)
  43. {
  44. int npackets, nbufs;
  45. // Allocate 2 buffers for each packet
  46. npackets = MAX_PACKETS;
  47. nbufs = npackets * 2;
  48. ctx->packet_pool = xmalloc(sizeof(Packet_t) * npackets);
  49. if (ctx->packet_pool == NULL) {
  50. SrvLogError(("Unable to allocate packet pool!\n"));
  51. return FALSE;
  52. }
  53. ctx->buffer_pool = xmalloc(SRV_PACKET_SIZE * nbufs);
  54. if (ctx->buffer_pool == NULL) {
  55. xfree(ctx->packet_pool);
  56. SrvLogError(("Unable to allocate buffer pool!\n"));
  57. return FALSE;
  58. }
  59. PacketReset(ctx);
  60. return TRUE;
  61. }
  62. Packet_t *
  63. PacketAlloc(EndPoint_t *ep)
  64. {
  65. // allocate a packet from free list, if no packet is available then
  66. // we set the wanted flag and wait on event
  67. SrvCtx_t *ctx;
  68. Packet_t *p;
  69. ASSERT(ep);
  70. ctx = ep->SrvCtx;
  71. retry:
  72. EnterCriticalSection(&ctx->cs);
  73. if (ctx->running == FALSE) {
  74. LeaveCriticalSection(&ctx->cs);
  75. return NULL;
  76. }
  77. if (p = ctx->freelist) {
  78. ctx->freelist = p->next;
  79. } else {
  80. ctx->waiters++;
  81. LeaveCriticalSection(&ctx->cs);
  82. if (WaitForSingleObject(ctx->event, INFINITE) != WAIT_OBJECT_0) {
  83. return NULL;
  84. }
  85. goto retry;
  86. }
  87. // Insert into per endpoint packet list
  88. p->endpoint = ep;
  89. p->next = ep->PacketList;
  90. ep->PacketList = p;
  91. LeaveCriticalSection(&ctx->cs);
  92. return p;
  93. }
  94. void
  95. PacketRelease(SrvCtx_t *ctx, Packet_t *p)
  96. {
  97. p->next = ctx->freelist;
  98. ctx->freelist = p;
  99. if (ctx->waiters > 0) {
  100. ctx->waiters--;
  101. SetEvent(ctx->event);
  102. }
  103. }
  104. void
  105. PacketFree(Packet_t *p)
  106. {
  107. EndPoint_t *ep;
  108. SrvCtx_t *ctx;
  109. Packet_t **last;
  110. ep = p->endpoint;
  111. ASSERT(ep);
  112. ctx = ep->SrvCtx;
  113. ASSERT(ctx);
  114. // insert packet into head of freelist. if wanted flag is set, we signal event
  115. EnterCriticalSection(&ctx->cs);
  116. // Remove packet from ep list
  117. last = &ep->PacketList;
  118. while (*last != NULL) {
  119. if ((*last) == p) {
  120. *last = p->next;
  121. break;
  122. }
  123. last = &(*last)->next;
  124. }
  125. PacketRelease(ctx, p);
  126. if (ep->PacketList == NULL) {
  127. // Free this endpoint
  128. SrvCloseEndpoint(ep);
  129. }
  130. LeaveCriticalSection(&ctx->cs);
  131. }
  132. int
  133. ProcessPacket(EndPoint_t *ep, Packet_t *p)
  134. {
  135. BOOL disp;
  136. if (IsSmb(p->buffer, p->len)) {
  137. p->in.smb = (PNT_SMB_HEADER)p->buffer;
  138. p->in.size = p->len;
  139. p->in.offset = sizeof(NT_SMB_HEADER);
  140. p->in.command = p->in.smb->Command;
  141. p->out.smb = (PNT_SMB_HEADER)p->outbuf;
  142. p->out.size = SRV_PACKET_SIZE;
  143. p->out.valid = sizeof(NT_SMB_HEADER);
  144. InitSmbHeader(p);
  145. DumpSmb(p->buffer, p->len, TRUE);
  146. SrvLog(("dispatching Tid:%d Uid:%d Mid:%d Flags:%x Cmd:%d...\n",
  147. p->in.smb->Tid, p->in.smb->Uid, p->in.smb->Mid,
  148. p->in.smb->Flags2, p->in.command));
  149. p->tag = 0;
  150. disp = SrvDispatch(p);
  151. if (disp == ERROR_IO_PENDING) {
  152. return ERROR_IO_PENDING;
  153. }
  154. // If we handled it ok...
  155. if (disp) {
  156. char *buffer;
  157. int len;
  158. int rc;
  159. buffer = (char *)p->out.smb;
  160. len = (int) p->out.valid;
  161. DumpSmb(buffer, len, FALSE);
  162. SrvLog(("sending...len %d\n", len));
  163. rc = send(ep->Sock, buffer, len, 0);
  164. if (rc == SOCKET_ERROR || rc != len) {
  165. SrvLog(("Send clnt failed %d\n", WSAGetLastError()));
  166. closesocket(ep->Sock);
  167. }
  168. } else {
  169. SrvLog(("dispatch failed!\n"));
  170. // did not understand...hangup on virtual circuit...
  171. SrvLog(("hangup! -- disp failed on sock %s\n", ep->ClientId));
  172. closesocket(ep->Sock);
  173. }
  174. }
  175. return ERROR_SUCCESS;
  176. }
  177. THREADAPI
  178. CompletionThread(LPVOID arg)
  179. {
  180. Packet_t* p;
  181. DWORD len;
  182. ULONG_PTR id;
  183. LPOVERLAPPED lpo;
  184. SrvCtx_t *ctx = (SrvCtx_t *) arg;
  185. HANDLE port = ctx->comport;
  186. EndPoint_t *endpoint;
  187. HANDLE ev;
  188. ev = CreateEvent(NULL, FALSE, FALSE, NULL);
  189. // Each thread needs its own event, msg to use
  190. while(ctx->running) {
  191. BOOL b;
  192. b = GetQueuedCompletionStatus (
  193. port,
  194. &len,
  195. &id,
  196. &lpo,
  197. INFINITE
  198. );
  199. p = (Packet_t *) lpo;
  200. if (p == NULL) {
  201. SrvLog(("SrvThread exiting, %x...\n", id));
  202. CloseHandle(ev);
  203. return 0;
  204. }
  205. if (!b && !lpo) {
  206. SrvLog(("Getqueued failed %d\n",GetLastError()));
  207. CloseHandle(ev);
  208. PacketFree(p);
  209. return 0;
  210. }
  211. // todo: when socket is closed, I need to free this endpoint.
  212. // I need to tag the endpoint with how many packets got scheduled
  213. // on it, when the refcnt reachs zero, I free it.
  214. endpoint = (EndPoint_t *) id;
  215. ASSERT(p->endpoint == endpoint);
  216. p->ev = ev;
  217. p->len = len;
  218. if (ProcessPacket(endpoint, p) != ERROR_IO_PENDING) {
  219. // schedule next read
  220. b = ReadFile ((HANDLE)endpoint->Sock,
  221. p->buffer,
  222. SRV_PACKET_SIZE,
  223. &len,
  224. &p->ov);
  225. if (!b && GetLastError () != ERROR_IO_PENDING) {
  226. SrvLog(("SrvThread read ep 0x%x failed %d\n", endpoint, GetLastError()));
  227. // Return packet to queue
  228. PacketFree(p);
  229. }
  230. }
  231. }
  232. CloseHandle(ev);
  233. SrvLog(("SrvThread exiting, not running...\n"));
  234. return 0;
  235. }
  236. void
  237. SrvFinalize(Packet_t *p)
  238. {
  239. char *buffer;
  240. DWORD len, rc;
  241. EndPoint_t *endpoint = p->endpoint;
  242. ASSERT(p->tag == ERROR_IO_PENDING);
  243. p->tag = 0;
  244. buffer = (char *)p->out.smb;
  245. len = (DWORD) p->out.valid;
  246. DumpSmb(buffer, len, FALSE);
  247. SrvLog(("sending...len %d\n", len));
  248. rc = send(endpoint->Sock, buffer, len, 0);
  249. if (rc == SOCKET_ERROR || rc != len) {
  250. SrvLog(("Finalize Send clnt failed <%d>\n", WSAGetLastError()));
  251. }
  252. rc = ReadFile ((HANDLE)endpoint->Sock,
  253. p->buffer,
  254. SRV_PACKET_SIZE,
  255. &len,
  256. &p->ov);
  257. if (!rc && GetLastError () != ERROR_IO_PENDING) {
  258. // Return packet to queue
  259. PacketFree(p);
  260. }
  261. }
  262. void
  263. SrvCloseEndpoint(EndPoint_t *endpoint)
  264. {
  265. EndPoint_t **p;
  266. Packet_t *packet;
  267. // lock must be held
  268. while (packet = endpoint->PacketList) {
  269. endpoint->PacketList = packet->next;
  270. // return to free list now
  271. PacketRelease(endpoint->SrvCtx, packet);
  272. }
  273. // remove from ctx list
  274. p = &endpoint->SrvCtx->EndPointList;
  275. while (*p != NULL) {
  276. if (*p == endpoint) {
  277. *p = endpoint->Next;
  278. break;
  279. }
  280. p = &(*p)->Next;
  281. }
  282. closesocket(endpoint->Sock);
  283. // We need to inform filesystem that this
  284. // tree is gone.
  285. FsLogoffUser(endpoint->SrvCtx->FsCtx, endpoint->LogonId);
  286. free(endpoint);
  287. }
  288. DWORD
  289. ListenSocket(SrvCtx_t *ctx, int nic)
  290. {
  291. DWORD err = ERROR_SUCCESS;
  292. SOCKET listen_socket = INVALID_SOCKET;
  293. struct sockaddr_nb local;
  294. unsigned char *srvname = ctx->nb_local_name;
  295. SET_NETBIOS_SOCKADDR(&local, NETBIOS_UNIQUE_NAME, srvname, ' ');
  296. listen_socket = socket(AF_NETBIOS, PROTOCOL_TYPE, -nic);
  297. if (listen_socket == INVALID_SOCKET){
  298. err = WSAGetLastError();
  299. SrvLogError(("socket() '%s' nic %d failed with error %d\n",
  300. srvname, nic, err));
  301. return err;
  302. }
  303. //
  304. // bind socket
  305. //
  306. if (bind(listen_socket,(struct sockaddr*)&local,sizeof(local)) == SOCKET_ERROR) {
  307. err = WSAGetLastError();
  308. SrvLogError(("srv nic %d bind() failed with error %d\n",nic, err));
  309. closesocket(listen_socket);
  310. return err;
  311. }
  312. // issue listen
  313. if (listen(listen_socket,5) == SOCKET_ERROR) {
  314. err = WSAGetLastError();
  315. SrvLogError(("listen() failed with error %d\n", err));
  316. closesocket(listen_socket);
  317. return err;
  318. }
  319. // all is well.
  320. ctx->listen_socket = listen_socket;
  321. return ERROR_SUCCESS;
  322. }
  323. THREADAPI
  324. ListenThread(LPVOID arg)
  325. {
  326. SOCKET listen_socket, msgsock;
  327. struct sockaddr_nb from;
  328. int fromlen;
  329. HANDLE comport;
  330. SrvCtx_t *ctx = (SrvCtx_t *) arg;
  331. EndPoint_t *endpoint;
  332. char localname[64];
  333. gethostname(localname, sizeof(localname));
  334. listen_socket = ctx->listen_socket;
  335. comport = ctx->comport;
  336. while(ctx->running) {
  337. int i;
  338. fromlen =sizeof(from);
  339. msgsock = accept(listen_socket,(struct sockaddr*)&from, &fromlen);
  340. if (msgsock == INVALID_SOCKET) {
  341. if (ctx->running)
  342. SrvLogError(("accept() error %d\n",WSAGetLastError()));
  343. break;
  344. }
  345. from.snb_name[NETBIOS_NAME_LENGTH-1] = '\0';
  346. {
  347. char *s = strchr(from.snb_name, ' ');
  348. if (s != NULL) *s = '\0';
  349. }
  350. SrvLog(("Received call from '%s'\n", from.snb_name));
  351. // Fence off all nodes except cluster nodes. We ask
  352. // our resource to check for us. For now we fence off all nodes but the this node
  353. if (_stricmp(localname, from.snb_name)) {
  354. // sorry, we just close the connection now
  355. closesocket(msgsock);
  356. continue;
  357. }
  358. // allocate a new endpoint
  359. endpoint = (EndPoint_t *) malloc(sizeof(*endpoint));
  360. if (endpoint == NULL) {
  361. SrvLogError(("Failed allocate failed %d\n", GetLastError()));
  362. closesocket(msgsock);
  363. continue;
  364. }
  365. memset(endpoint, 0, sizeof(*endpoint));
  366. // add endpoint now
  367. EnterCriticalSection(&ctx->cs);
  368. endpoint->Next = ctx->EndPointList;
  369. ctx->EndPointList = endpoint;
  370. LeaveCriticalSection(&ctx->cs);
  371. endpoint->Sock = msgsock;
  372. endpoint->SrvCtx = ctx;
  373. memcpy(endpoint->ClientId, from.snb_name, sizeof(endpoint->ClientId));
  374. comport = CreateIoCompletionPort((HANDLE)msgsock, comport,
  375. (ULONG_PTR)endpoint, 8);
  376. if (!comport) {
  377. SrvLogError(("CompletionPort bind Failed %d\n", GetLastError()));
  378. SrvCloseEndpoint(endpoint);
  379. comport = ctx->comport;
  380. continue;
  381. }
  382. for (i = 0; i < SRV_NUM_WORKERS; i++) {
  383. Packet_t *p;
  384. BOOL b;
  385. DWORD nbytes;
  386. p = PacketAlloc(endpoint);
  387. if (p == NULL) {
  388. SrvLog(("Listen thread got null packet, exiting posted...\n"));
  389. break;
  390. }
  391. b = ReadFile (
  392. (HANDLE) msgsock,
  393. p->buffer,
  394. SRV_PACKET_SIZE,
  395. &nbytes,
  396. &p->ov);
  397. if (!b && GetLastError () != ERROR_IO_PENDING) {
  398. SrvLog(("Srv ReadFile Failed %d\n",
  399. GetLastError()));
  400. // Return packet to queue
  401. PacketFree(p);
  402. break;
  403. }
  404. }
  405. }
  406. return (0);
  407. }
  408. DWORD
  409. SrvInit(PVOID resHdl, PVOID fsHdl, PVOID *Hdl)
  410. {
  411. SrvCtx_t *ctx;
  412. DWORD err;
  413. ctx = (SrvCtx_t *) malloc(sizeof(*ctx));
  414. if (ctx == NULL) {
  415. return ERROR_NOT_ENOUGH_MEMORY;
  416. }
  417. memset(ctx, 0, sizeof(*ctx));
  418. ctx->FsCtx = fsHdl;
  419. ctx->resHdl = resHdl;
  420. // init lsa now
  421. err = LsaInit(&ctx->LsaHandle, &ctx->LsaPack);
  422. if (err != ERROR_SUCCESS) {
  423. SrvLogError(("LsaInit failed with error %x\n", err));
  424. free(ctx);
  425. return err;
  426. }
  427. // init winsock now
  428. if (WSAStartup(0x202,&ctx->wsaData) == SOCKET_ERROR) {
  429. err = WSAGetLastError();
  430. SrvLogError(("WSAStartup failed with error %d\n", err));
  431. free(ctx);
  432. return err;
  433. }
  434. InitializeCriticalSection(&ctx->cs);
  435. ctx->running = FALSE;
  436. ctx->event = CreateEvent(NULL, FALSE, FALSE, NULL);
  437. ctx->waiters = 0;
  438. if (PacketInit(ctx) != TRUE) {
  439. WSACleanup();
  440. return ERROR_NO_SYSTEM_RESOURCES;
  441. }
  442. SrvUtilInit(ctx);
  443. *Hdl = (PVOID) ctx;
  444. return ERROR_SUCCESS;
  445. }
  446. DWORD
  447. SrvOnline(PVOID Hdl, LPWSTR name, DWORD nic)
  448. {
  449. SrvCtx_t *ctx = (SrvCtx_t *) Hdl;
  450. DWORD err;
  451. int i;
  452. int nFixedThreads = 1;
  453. char localname[128];
  454. SYSTEM_INFO sysinfo;
  455. if (ctx == NULL) {
  456. return ERROR_INVALID_PARAMETER;
  457. }
  458. if (nic > 0)
  459. nic--;
  460. //
  461. // Start up threads in suspended mode
  462. //
  463. if (ctx->running == TRUE)
  464. return ERROR_SUCCESS;
  465. // save name to use
  466. if (name != NULL) {
  467. // we need to translate name to ascii
  468. i = wcstombs(localname, name, NETBIOS_NAME_LENGTH-1);
  469. localname[i] = '\0';
  470. strncpy(ctx->nb_local_name, localname, NETBIOS_NAME_LENGTH);
  471. } else {
  472. // use local name and append our -crs extension
  473. gethostname(localname, sizeof(localname));
  474. strcat(localname, SRV_NAME_EXTENSION);
  475. strncpy(ctx->nb_local_name, localname, NETBIOS_NAME_LENGTH);
  476. }
  477. for (i = 0; i < NETBIOS_NAME_LENGTH; i++) {
  478. ctx->nb_local_name[i] = (char) toupper(ctx->nb_local_name[i]);
  479. }
  480. // create completion port
  481. GetSystemInfo(&sysinfo);
  482. ctx->comport = CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0,
  483. sysinfo.dwNumberOfProcessors*8);
  484. if (ctx->comport == INVALID_HANDLE_VALUE) {
  485. err = GetLastError();
  486. SrvLogError(("Unable to create completion port %d\n", err));
  487. WSACleanup();
  488. return err;
  489. }
  490. // create listen socket
  491. ctx->nic = nic;
  492. err = ListenSocket(ctx, nic);
  493. if ( err != ERROR_SUCCESS) {
  494. WSACleanup();
  495. return err;
  496. }
  497. // start up 1 listener/receiver, a few workers, a few senders....
  498. ctx->nThreads = nFixedThreads + SRV_NUM_SENDERS;
  499. ctx->hThreads = (HANDLE *) malloc(sizeof(HANDLE) * ctx->nThreads);
  500. if (ctx->hThreads == NULL) {
  501. WSACleanup();
  502. return ERROR_NOT_ENOUGH_MEMORY;
  503. }
  504. for (i = 0; i < nFixedThreads; i++) {
  505. ctx->hThreads[i] = (HANDLE)
  506. _beginthreadex(NULL, 0, &ListenThread, (LPVOID)ctx, CREATE_SUSPENDED, NULL);
  507. }
  508. for ( ; i < ctx->nThreads; i++) {
  509. ctx->hThreads[i] = (HANDLE)
  510. _beginthreadex(NULL, 0, &CompletionThread, (LPVOID)ctx, CREATE_SUSPENDED, NULL);
  511. }
  512. ctx->running = TRUE;
  513. for (i = 0; i < ctx->nThreads; i++)
  514. ResumeThread(ctx->hThreads[i]);
  515. return ERROR_SUCCESS;
  516. }
  517. DWORD
  518. SrvOffline(PVOID Hdl)
  519. {
  520. int i;
  521. SrvCtx_t *ctx = (SrvCtx_t *) Hdl;
  522. if (ctx == NULL) {
  523. return ERROR_INVALID_PARAMETER;
  524. }
  525. // we shutdown all threads in the completion port
  526. // we close all currently open sockets
  527. // we free all memory
  528. if (ctx->running) {
  529. EndPoint_t *ep;
  530. ctx->running = FALSE;
  531. closesocket(ctx->listen_socket);
  532. EnterCriticalSection(&ctx->cs);
  533. for (ep = ctx->EndPointList; ep; ep = ep->Next)
  534. closesocket(ep->Sock);
  535. LeaveCriticalSection(&ctx->cs);
  536. SrvLog(("waiting for threads to die off...\n"));
  537. // send a kill packet to all threads on the completion port
  538. for (i = 0; i < ctx->nThreads; i++) {
  539. if (!PostQueuedCompletionStatus(ctx->comport, 0, 0, NULL)) {
  540. SrvLog(("Port queued port failed %d\n", GetLastError()));
  541. break;
  542. }
  543. }
  544. if (i == ctx->nThreads) {
  545. // wait for them to die of natural causes before we kill them...
  546. WaitForMultipleObjects(ctx->nThreads, ctx->hThreads, TRUE, INFINITE);
  547. }
  548. // close handles
  549. for (i = 0; i < ctx->nThreads; i++) {
  550. CloseHandle(ctx->hThreads[i]);
  551. }
  552. CloseHandle(ctx->comport);
  553. free((char *)ctx->hThreads);
  554. // free endpoints
  555. EnterCriticalSection(&ctx->cs);
  556. while (ep = ctx->EndPointList)
  557. SrvCloseEndpoint(ep);
  558. LeaveCriticalSection(&ctx->cs);
  559. }
  560. return ERROR_SUCCESS;
  561. }
  562. void
  563. SrvExit(PVOID Hdl)
  564. {
  565. SrvCtx_t *ctx = (SrvCtx_t *) Hdl;
  566. if (ctx != NULL) {
  567. SrvUtilExit(ctx);
  568. // must do this last!
  569. if (ctx->packet_pool)
  570. xfree(ctx->packet_pool);
  571. if (ctx->buffer_pool)
  572. xfree(ctx->buffer_pool);
  573. free(ctx);
  574. }
  575. }