Leaked source code of windows server 2003
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1468 lines
40 KiB

  1. ///////////////////////////////////////////////////////////////////////////////
  2. //
  3. // Copyright (c) 2000, Microsoft Corp. All rights reserved.
  4. //
  5. // FILE
  6. //
  7. // radprxy.cpp
  8. //
  9. // SYNOPSIS
  10. //
  11. // Defines the reusable RadiusProxy engine. This should have no IAS specific
  12. // dependencies.
  13. //
  14. // MODIFICATION HISTORY
  15. //
  16. // 02/08/2000 Original version.
  17. // 05/30/2000 Eliminate QUESTIONABLE state.
  18. //
  19. ///////////////////////////////////////////////////////////////////////////////
  20. #include <proxypch.h>
  21. #include <radproxyp.h>
  22. #include <radproxy.h>
  23. // Avoid dependencies on ntrtl.h
  24. extern "C" ULONG __stdcall RtlRandom(PULONG seed);
  25. // Extract a 32-bit integer from a buffer.
  26. ULONG ExtractUInt32(const BYTE* p) throw ()
  27. {
  28. return (ULONG)(p[0] << 24) | (ULONG)(p[1] << 16) |
  29. (ULONG)(p[2] << 8) | (ULONG)(p[3] );
  30. }
  31. // Insert a 32-bit integer into a buffer.
  32. void InsertUInt32(BYTE* p, ULONG val) throw ()
  33. {
  34. *p++ = (BYTE)(val >> 24);
  35. *p++ = (BYTE)(val >> 16);
  36. *p++ = (BYTE)(val >> 8);
  37. *p = (BYTE)(val );
  38. }
  39. //
  40. // Layout of a Microsoft State attribute
  41. //
  42. // struct MicrosoftState
  43. // {
  44. // BYTE checksum[4];
  45. // BYTE vendorID[4];
  46. // BYTE version[2];
  47. // BYTE serverAddress[4];
  48. // BYTE sourceID[4];
  49. // BYTE sessionID[4];
  50. // };
  51. //
  52. // Extracts the creators address from a State attribute or INADDR_NONE if this
  53. // isn't a valid Microsoft State attributes.
  54. ULONG ExtractAddressFromState(const RadiusAttribute& state) throw ()
  55. {
  56. if (state.length == 22 &&
  57. !memcmp(state.value + 4, "\x00\x00\x01\x37\x00\x01", 6) &&
  58. IASAdler32(state.value + 4, 18) == ExtractUInt32(state.value))
  59. {
  60. return ExtractUInt32(state.value + 10);
  61. }
  62. return INADDR_NONE;
  63. }
  64. // Returns true if this is an Accounting-On/Off packet.
  65. bool IsNasStateRequest(const RadiusPacket& packet) throw ()
  66. {
  67. const RadiusAttribute* status = FindAttribute(
  68. packet,
  69. RADIUS_ACCT_STATUS_TYPE
  70. );
  71. if (!status) { return false; }
  72. ULONG value = ExtractUInt32(status->value);
  73. return value == 7 || value == 8;
  74. }
  75. RemotePort::RemotePort(
  76. ULONG ipAddress,
  77. USHORT port,
  78. PCSTR sharedSecret
  79. )
  80. : address(ipAddress, port),
  81. secret((const BYTE*)sharedSecret, strlen(sharedSecret))
  82. {
  83. }
  84. RemotePort::RemotePort(const RemotePort& port)
  85. : address(port.address),
  86. secret(port.secret),
  87. nextIdentifier(port.nextIdentifier)
  88. {
  89. }
  90. RemoteServer::RemoteServer(
  91. const RemoteServerConfig& config
  92. )
  93. : guid(config.guid),
  94. authPort(config.ipAddress, config.authPort, config.authSecret),
  95. acctPort(config.ipAddress, config.acctPort, config.acctSecret),
  96. timeout(config.timeout),
  97. maxEvents((LONG)config.maxLost),
  98. blackout(config.blackout),
  99. priority(config.priority),
  100. weight(config.weight),
  101. sendSignature(config.sendSignature),
  102. sendAcctOnOff(config.sendAcctOnOff),
  103. usable(true),
  104. onProbation(false),
  105. eventCount(0),
  106. expiry(0)
  107. {
  108. }
  109. bool RemoteServer::shouldBroadcast() throw ()
  110. {
  111. bool broadcastable = false;
  112. if (!onProbation && !usable)
  113. {
  114. ULONG64 now = GetSystemTime64();
  115. lock.lock();
  116. // Has the blackout interval expired ?
  117. if (now > expiry)
  118. {
  119. // Yes, so set a new expiration.
  120. expiry = now + blackout * 10000i64;
  121. broadcastable = true;
  122. }
  123. lock.unlock();
  124. }
  125. return broadcastable;
  126. }
  127. bool RemoteServer::onReceive(BYTE code) throw ()
  128. {
  129. const bool authoritative = (code != RADIUS_ACCESS_CHALLENGE);
  130. // Did the server transition from unavailable to available?
  131. bool downToUp = false;
  132. lock.lock();
  133. if (onProbation)
  134. {
  135. if (authoritative)
  136. {
  137. // Bump the success count.
  138. if (++eventCount >= maxEvents)
  139. {
  140. // We're off probation w/ a lost count of zero.
  141. onProbation = false;
  142. eventCount = 0;
  143. downToUp = true;
  144. }
  145. // We successfully finished a request, so we can send another.
  146. usable = true;
  147. }
  148. }
  149. else if (usable)
  150. {
  151. if (authoritative)
  152. {
  153. // An authoritative response resets the lost count.
  154. eventCount = 0;
  155. }
  156. }
  157. else
  158. {
  159. // An unavailable server has responded to a broadcast, so put it on
  160. // probation. Set the success count accordingly.
  161. usable = true;
  162. onProbation = true;
  163. eventCount = authoritative ? 1 : 0;
  164. }
  165. lock.unlock();
  166. return downToUp;
  167. }
  168. void RemoteServer::onSend() throw ()
  169. {
  170. if (onProbation)
  171. {
  172. lock.lock();
  173. if (onProbation)
  174. {
  175. // Probationary servers can only send one request at a time.
  176. usable = false;
  177. }
  178. lock.unlock();
  179. }
  180. }
  181. bool RemoteServer::onTimeout() throw ()
  182. {
  183. // Did the server transition from available to unavailable?
  184. bool upToDown = false;
  185. lock.lock();
  186. if (onProbation)
  187. {
  188. // Sudden death for probationary servers. Move it straight to
  189. // unavailable.
  190. usable = false;
  191. onProbation = false;
  192. expiry = GetSystemTime64() + blackout * 10000ui64;
  193. }
  194. else if (usable)
  195. {
  196. // Bump the lost count.
  197. if (++eventCount >= maxEvents)
  198. {
  199. // Server is now unavailable.
  200. usable = false;
  201. expiry = GetSystemTime64() + blackout * 10000ui64;
  202. upToDown = true;
  203. }
  204. }
  205. else
  206. {
  207. // If the server is already unavailable, ignore the timeout.
  208. }
  209. lock.unlock();
  210. return upToDown;
  211. }
  212. void RemoteServer::copyState(const RemoteServer& target) throw ()
  213. {
  214. // Synchronize the ports.
  215. authPort.copyState(target.authPort);
  216. acctPort.copyState(target.acctPort);
  217. // Synchronize server availability.
  218. usable = target.usable;
  219. onProbation = target.onProbation;
  220. eventCount = target.eventCount;
  221. expiry = target.expiry;
  222. }
  223. bool RemoteServer::operator==(const RemoteServer& s) const throw ()
  224. {
  225. return authPort == s.authPort &&
  226. acctPort == s.acctPort &&
  227. priority == s.priority &&
  228. weight == s.weight &&
  229. timeout == s.timeout &&
  230. eventCount == s.eventCount &&
  231. blackout == s.blackout &&
  232. sendSignature == s.sendSignature &&
  233. sendAcctOnOff == s.sendAcctOnOff;
  234. }
  235. //////////
  236. // Used for sorting servers by priority.
  237. //////////
  238. int __cdecl sortServersByPriority(
  239. const RemoteServer* const* server1,
  240. const RemoteServer* const* server2
  241. ) throw ()
  242. {
  243. return (int)(*server1)->priority - (int)(*server2)->priority;
  244. }
  245. ULONG ServerGroup::theSeed;
  246. ServerGroup::ServerGroup(
  247. PCWSTR groupName,
  248. RemoteServer* const* first,
  249. RemoteServer* const* last
  250. )
  251. : servers(first, last),
  252. name(groupName)
  253. {
  254. // We don't allow empty groups.
  255. if (servers.empty()) { _com_issue_error(E_INVALIDARG); }
  256. if (theSeed == 0)
  257. {
  258. FILETIME ft;
  259. GetSystemTimeAsFileTime(&ft);
  260. theSeed = ft.dwLowDateTime | ft.dwHighDateTime;
  261. }
  262. // Sort by priority.
  263. servers.sort(sortServersByPriority);
  264. // Find the end of the top priority servers. This will be useful when doing
  265. // a forced pick.
  266. ULONG topPriority = (*servers.begin())->priority;
  267. for (endTopPriority = servers.begin();
  268. endTopPriority != servers.end();
  269. ++endTopPriority)
  270. {
  271. if ((*endTopPriority)->priority != topPriority) { break; }
  272. }
  273. // Find the max number of servers at any priority level. This will be useful
  274. // when allocating a buffer to hold the candidates.
  275. ULONG maxCount = 0, count = 0, priority = (*servers.begin())->priority;
  276. for (RemoteServer* const* i = begin(); i != end(); ++i)
  277. {
  278. if ((*i)->priority != priority)
  279. {
  280. priority = (*i)->priority;
  281. count = 0;
  282. }
  283. if (++count > maxCount) { maxCount = count; }
  284. }
  285. maxCandidatesSize = maxCount * sizeof(RemoteServer*);
  286. }
  287. RemoteServer* ServerGroup::pickServer(
  288. RemoteServers::iterator first,
  289. RemoteServers::iterator last,
  290. const RemoteServer* avoid
  291. ) throw ()
  292. {
  293. // If the list has exactly one entry, there's nothing to do.
  294. if (last == first + 1) { return *first; }
  295. RemoteServer* const* i;
  296. // Compute the combined weight off all the servers.
  297. ULONG weight = 0;
  298. for (i = first; i != last; ++i)
  299. {
  300. if (*i != avoid)
  301. {
  302. weight += (*i)->weight;
  303. }
  304. }
  305. // Pick a random number from [0, weight)
  306. ULONG offset = (ULONG)
  307. (((ULONG64)RtlRandom(&theSeed) * (ULONG64)weight) >> 31);
  308. // We don't test the last server since if we make it that far we have to use
  309. // it anyway.
  310. --last;
  311. // Iterate through the candidates until we reach the offset.
  312. for (i = first; i != last; ++i)
  313. {
  314. if (*i != avoid)
  315. {
  316. if ((*i)->weight >= offset) { break; }
  317. offset -= (*i)->weight;
  318. }
  319. }
  320. return *i;
  321. }
  322. void ServerGroup::getServersForRequest(
  323. ProxyContext* context,
  324. BYTE packetCode,
  325. const RemoteServer* avoid,
  326. RequestStack& result
  327. ) const
  328. {
  329. // List of candidates.
  330. RemoteServer** first = (RemoteServer**)_alloca(maxCandidatesSize);
  331. RemoteServer** last = first;
  332. // Iterate through the servers.
  333. ULONG maxPriority = (ULONG)-1;
  334. for (RemoteServer* const* i = servers.begin(); i != servers.end(); ++i)
  335. {
  336. // If this test fails, we must have found a higher priority server that's
  337. // usable.
  338. if ((*i)->priority > maxPriority) { break; }
  339. if ((*i)->isUsable())
  340. {
  341. // Don't consider lower priority servers.
  342. maxPriority = (*i)->priority;
  343. // Add this to the list of candidates.
  344. *last++ = *i;
  345. }
  346. else if ((*i)->shouldBroadcast())
  347. {
  348. // It's not available, but it's ready for a broadcast
  349. result.push(new Request(context, *i, packetCode));
  350. }
  351. }
  352. if (first == last)
  353. {
  354. // No usable servers, so look for in progress servers.
  355. maxPriority = (ULONG)-1;
  356. for (RemoteServer* const* i = servers.begin(); i != servers.end(); ++i)
  357. {
  358. // If this test fails, we must have found a higher priority server
  359. // that's in progress.
  360. if ((*i)->priority > maxPriority) { break; }
  361. if ((*i)->isInProgress())
  362. {
  363. // Don't consider lower priority servers.
  364. maxPriority = (*i)->priority;
  365. // Add this to the list of candidates.
  366. *last++ = *i;
  367. }
  368. }
  369. }
  370. if (first != last)
  371. {
  372. // We have at least one candidate, so pick one and add it to the list.
  373. result.push(new Request(
  374. context,
  375. pickServer(first, last, avoid),
  376. packetCode
  377. ));
  378. }
  379. else if (result.empty() && !servers.empty())
  380. {
  381. // We have no candidates and no servers available for broadcast, so just
  382. // force a pick from the top priority servers.
  383. result.push(new Request(
  384. context,
  385. pickServer(servers.begin(), endTopPriority, avoid),
  386. packetCode
  387. ));
  388. }
  389. }
  390. //////////
  391. // Used for sorting and searching groups by name.
  392. //////////
  393. int __cdecl sortGroupsByName(
  394. const ServerGroup* const* group1,
  395. const ServerGroup* const* group2
  396. ) throw ()
  397. {
  398. return wcscmp((*group1)->getName(), (*group2)->getName());
  399. }
  400. int __cdecl findGroupByName(
  401. const void* key,
  402. const ServerGroup* const* group
  403. ) throw ()
  404. {
  405. return wcscmp((PCWSTR)key, (*group)->getName());
  406. }
  407. //////////
  408. // Used for sorting and searching servers by address.
  409. //////////
  410. int __cdecl sortServersByAddress(
  411. const RemoteServer* const* server1,
  412. const RemoteServer* const* server2
  413. )
  414. {
  415. if ((*server1)->getAddress() < (*server2)->getAddress()) { return -1; }
  416. if ((*server1)->getAddress() > (*server2)->getAddress()) { return 1; }
  417. return 0;
  418. }
  419. int __cdecl findServerByAddress(
  420. const void* key,
  421. const RemoteServer* const* server
  422. ) throw ()
  423. {
  424. if ((ULONG_PTR)key < (*server)->getAddress()) { return -1; }
  425. if ((ULONG_PTR)key > (*server)->getAddress()) { return 1; }
  426. return 0;
  427. }
  428. //////////
  429. // Used for sorting and searching servers by guid.
  430. //////////
  431. int __cdecl sortServersByGUID(
  432. const RemoteServer* const* server1,
  433. const RemoteServer* const* server2
  434. ) throw ()
  435. {
  436. return memcmp(&(*server1)->guid, &(*server2)->guid, sizeof(GUID));
  437. }
  438. int __cdecl findServerByGUID(
  439. const void* key,
  440. const RemoteServer* const* server
  441. ) throw ()
  442. {
  443. return memcmp(key, &(*server)->guid, sizeof(GUID));
  444. }
  445. //////////
  446. // Used for sorting accounting servers by port.
  447. //////////
  448. int __cdecl sortServersByAcctPort(
  449. const RemoteServer* const* server1,
  450. const RemoteServer* const* server2
  451. )
  452. {
  453. const sockaddr_in& a1 = (*server1)->acctPort.address;
  454. const sockaddr_in& a2 = (*server2)->acctPort.address;
  455. return memcmp(&a1.sin_port, &a2.sin_port, 6);
  456. }
// Replaces the current set of server groups with the groups in
// [first, last) and rebuilds the server collections (by address, by GUID and
// the Accounting-On/Off list) from the servers those groups contain.
//
// All new collections are built before the exclusive lock is taken; the lock
// is held only while the collections are swapped in.
//
// Returns true on success, or false if std::bad_alloc was thrown while the
// new collections were being built. In that case nothing has been swapped.
bool ServerGroupManager::setServerGroups(
                            ServerGroup* const* first,
                            ServerGroup* const* last
                            ) throw ()
{
   bool success;
   try
   {
      // Save the new server groups ...
      ServerGroups newGroups(first, last);
      // Sort by name.
      newGroups.sort(sortGroupsByName);
      // Useful iterators. Note that the loops below walk the caller's
      // [first, last) range, not the sorted newGroups collection.
      ServerGroups::iterator i;
      RemoteServers::iterator j;
      // Count the number of servers and accounting servers.
      ULONG count = 0, acctCount = 0;
      for (i = first; i != last; ++i)
      {
         for (j = (*i)->begin(); j != (*i)->end(); ++j)
         {
            ++count;
            if ((*j)->sendAcctOnOff) { ++acctCount; }
         }
      }
      // Reserve space for the servers.
      RemoteServers newServers(count);
      RemoteServers newAcctServers(acctCount);
      // Populate the servers.
      for (i = first; i != last; ++i)
      {
         for (j = (*i)->begin(); j != (*i)->end(); ++j)
         {
            RemoteServer* newServer = *j;
            // Does this server already exist? The current byGuid collection
            // is searched here without taking the monitor.
            RemoteServer* existing = byGuid.search(
                                         (const void*)&newServer->guid,
                                         findServerByGUID
                                         );
            if (existing)
            {
               if (*existing == *newServer)
               {
                  // If it's an exact match, use the existing server.
                  // NOTE(review): this only redirects the local variable;
                  // the group's own server list still refers to *j. Confirm
                  // that is intended.
                  newServer = existing;
               }
               else
               {
                  // Otherwise, copy the state of the existing server.
                  newServer->copyState(*existing);
               }
            }
            newServers.push_back(newServer);
            if (newServer->sendAcctOnOff)
            {
               newAcctServers.push_back(newServer);
            }
         }
      }
      // Sort the servers by address ...
      newServers.sort(sortServersByAddress);
      // ... and GUID. This is a second collection copied from newServers.
      RemoteServers newServersByGuid(newServers);
      newServersByGuid.sort(sortServersByGUID);
      // Everything is ready so now we grab the write lock ...
      monitor.LockExclusive();
      // ... and swap in the collections. The previous contents end up in the
      // local objects, which are destroyed when the try block exits.
      groups.swap(newGroups);
      byAddress.swap(newServers);
      byGuid.swap(newServersByGuid);
      acctServers.swap(newAcctServers);
      monitor.Unlock();
      success = true;
   }
   catch (const std::bad_alloc&)
   {
      // Only allocation failure is handled here.
      success = false;
   }
   return success;
}
  537. RemoteServerPtr ServerGroupManager::findServer(
  538. ULONG address
  539. ) const throw ()
  540. {
  541. monitor.Lock();
  542. RemoteServer* server = byAddress.search(
  543. (const void*)ULongToPtr(address),
  544. findServerByAddress
  545. );
  546. monitor.Unlock();
  547. return server;
  548. }
  549. void ServerGroupManager::getServersByGroup(
  550. ProxyContext* context,
  551. BYTE packetCode,
  552. PCWSTR name,
  553. const RemoteServer* avoid,
  554. RequestStack& result
  555. ) const throw ()
  556. {
  557. monitor.Lock();
  558. ServerGroup* group = groups.search(name, findGroupByName);
  559. if (group)
  560. {
  561. group->getServersForRequest(context, packetCode, avoid, result);
  562. }
  563. monitor.Unlock();
  564. }
  565. void ServerGroupManager::getServersForAcctOnOff(
  566. ProxyContext* context,
  567. RequestStack& result
  568. ) const
  569. {
  570. monitor.Lock();
  571. for (RemoteServer* const* i = acctServers.begin();
  572. i != acctServers.end();
  573. ++i)
  574. {
  575. result.push(new Request(context, *i, RADIUS_ACCOUNTING_REQUEST));
  576. }
  577. monitor.Unlock();
  578. }
  579. RadiusProxyEngine* RadiusProxyEngine::theProxy;
  580. RadiusProxyEngine::RadiusProxyEngine(RadiusProxyClient* source) throw ()
  581. : client(source),
  582. proxyAddress(INADDR_NONE),
  583. pending(Request::hash, 1),
  584. sessions(ServerBinding::hash, 1, 10000, (2 * 60 * 1000), true),
  585. avoid(ServerBinding::hash, 1, 10000, (35 * 60 * 1000), false),
  586. crypto(0)
  587. {
  588. theProxy = this;
  589. // We don't care if this fails. The proxy will just use INADDR_NONE in it's
  590. // proxy-state attribute.
  591. PHOSTENT he = IASGetHostByName(NULL);
  592. if (he)
  593. {
  594. if (he->h_addr_list[0])
  595. {
  596. proxyAddress = *(PULONG)he->h_addr_list[0];
  597. }
  598. LocalFree(he);
  599. }
  600. }
  601. RadiusProxyEngine::~RadiusProxyEngine() throw ()
  602. {
  603. // Block any new reponses.
  604. authSock.close();
  605. acctSock.close();
  606. // Clear the pending request table.
  607. pending.clear();
  608. // Cancel all the timers.
  609. timers.cancelAllTimers();
  610. // At this point all our threads should be done, but let's just make sure.
  611. SwitchToThread();
  612. if (crypto != 0)
  613. {
  614. CryptReleaseContext(crypto, 0);
  615. }
  616. theProxy = NULL;
  617. }
  618. HRESULT RadiusProxyEngine::finalConstruct() throw ()
  619. {
  620. HRESULT hr = S_OK;
  621. if (!CryptAcquireContext(
  622. &crypto,
  623. 0,
  624. 0,
  625. PROV_RSA_FULL,
  626. CRYPT_VERIFYCONTEXT
  627. ))
  628. {
  629. DWORD error = GetLastError();
  630. hr = HRESULT_FROM_WIN32(error);
  631. }
  632. return hr;
  633. }
  634. bool RadiusProxyEngine::setServerGroups(
  635. ServerGroup* const* begin,
  636. ServerGroup* const* end
  637. ) throw ()
  638. {
  639. // We don't open the sockets unless we actually have some server groups
  640. // configured. This is just to be a good corporate citizen.
  641. if (begin != end)
  642. {
  643. if ((!authSock.isOpen() && !authSock.open(this, portAuthentication)) ||
  644. (!acctSock.isOpen() && !acctSock.open(this, portAccounting)))
  645. {
  646. return false;
  647. }
  648. }
  649. return groups.setServerGroups(begin, end);
  650. }
// Forwards one client request to the remote servers of 'serverGroup'.
//
//   context              - opaque client value handed back in onComplete.
//   serverGroup          - name of the server group to use.
//   code                 - RADIUS packet code; only Access-Request and
//                          Accounting-Request are accepted.
//   requestAuthenticator - stored in the packet for Access-Requests; may be
//                          NULL.
//   begin, end           - the request's attributes.
//
// If no request could be sent successfully, the failure is reported to the
// client through onComplete before this function returns:
// resultNotEnoughMemory, resultInvalidRequest, resultUnknownServerGroup (the
// initial value, kept when no requests were generated) or whatever
// sendRequest returned.
void RadiusProxyEngine::forwardRequest(
                            PVOID context,
                            PCWSTR serverGroup,
                            BYTE code,
                            const BYTE* requestAuthenticator,
                            const RadiusAttribute* begin,
                            const RadiusAttribute* end
                            ) throw ()
{
   // Save the request context. We have to handle this carefully since we rely
   // on the ProxyContext object to ensure that onComplete gets called exactly
   // once. If we can't allocate the object, we have to handle it specially.
   ProxyContextPtr ctxt(new (std::nothrow) ProxyContext(context));
   if (!ctxt)
   {
      client->onComplete(
                  resultNotEnoughMemory,
                  context,
                  NULL,
                  code,
                  NULL,
                  NULL
                  );
      return;
   }
   // Stays at this value if no requests get generated below.
   Result retval = resultUnknownServerGroup;
   try
   {
      // Store the in parameters in a RadiusPacket struct.
      RadiusPacket packet;
      packet.code = code;
      packet.begin = const_cast<RadiusAttribute*>(begin);
      packet.end = const_cast<RadiusAttribute*>(end);
      // Generate the list of RADIUS requests to be sent.
      RequestStack requests;
      switch (code)
      {
         case RADIUS_ACCESS_REQUEST:
         {
            // Is this request associated with a particular server?
            RemoteServerPtr server = getServerAffinity(packet);
            if (server)
            {
               // Yes, so that server is the only destination.
               requests.push(new Request(ctxt, server, RADIUS_ACCESS_REQUEST));
            }
            else
            {
               // No, so let the group choose, passing along the server (if
               // any) that should be avoided.
               server = getServerAvoidance(packet);
               groups.getServersByGroup(
                          ctxt,
                          code,
                          serverGroup,
                          server,
                          requests
                          );
            }
            // Put request authenticator in the packet. The request
            // authenticator can be NULL. The authenticator will not be
            // changed.
            packet.authenticator = requestAuthenticator;
            break;
         }
         case RADIUS_ACCOUNTING_REQUEST:
         {
            if (!IsNasStateRequest(packet))
            {
               // Ordinary accounting request: no server to avoid.
               groups.getServersByGroup(
                          ctxt,
                          code,
                          serverGroup,
                          0,
                          requests
                          );
            }
            else
            {
               // Accounting-On/Off goes to every server on the
               // Accounting-On/Off list, regardless of 'serverGroup'.
               groups.getServersForAcctOnOff(
                          ctxt,
                          requests
                          );
               // NAS State requests are always reported as a success since we
               // don't care if it gets to all the destinations.
               context = ctxt->takeOwnership();
               if (context)
               {
                  client->onComplete(
                              resultSuccess,
                              context,
                              NULL,
                              RADIUS_ACCOUNTING_RESPONSE,
                              NULL,
                              NULL
                              );
               }
               // NOTE(review): if any requests were generated, retval is
               // overwritten by the first sendRequest below; the failure path
               // at the end then relies on takeOwnership() not returning the
               // context a second time -- confirm.
               retval = resultSuccess;
            }
            break;
         }
         default:
         {
            retval = resultInvalidRequest;
         }
      }
      if (!requests.empty())
      {
         // First we handle the primary.
         RequestPtr request = requests.pop();
         ctxt->setPrimaryServer(&request->getServer());
         retval = sendRequest(packet, request);
         // Now we broadcast.
         while (!requests.empty())
         {
            request = requests.pop();
            Result result = sendRequest(packet, request);
            if (result == resultSuccess && retval != resultSuccess)
            {
               // This was the first request to succeed so mark it as primary.
               retval = resultSuccess;
               ctxt->setPrimaryServer(&request->getServer());
            }
         }
      }
   }
   catch (const std::bad_alloc&)
   {
      retval = resultNotEnoughMemory;
   }
   if (retval != resultSuccess)
   {
      // If we made it here, then we didn't successfully send a packet to any
      // server, so we have to report the result ourselves.
      context = ctxt->takeOwnership();
      if (context)
      {
         client->onComplete(
                     retval,
                     context,
                     ctxt->getPrimaryServer(),
                     code,
                     NULL,
                     NULL
                     );
      }
   }
}
  796. void RadiusProxyEngine::onRequestAbandoned(
  797. PVOID context,
  798. RemoteServer* server
  799. ) throw ()
  800. {
  801. // Nobody took responsibility for the request, so we time it out.
  802. theProxy->client->onComplete(
  803. resultRequestTimeout,
  804. context,
  805. server,
  806. 0,
  807. NULL,
  808. NULL
  809. );
  810. }
  811. inline void RadiusProxyEngine::reportEvent(
  812. const RadiusEvent& event
  813. ) const throw ()
  814. {
  815. client->onEvent(event);
  816. }
  817. inline void RadiusProxyEngine::reportEvent(
  818. RadiusEvent& event,
  819. RadiusEventType type
  820. ) const throw ()
  821. {
  822. event.eventType = type;
  823. client->onEvent(event);
  824. }
  825. void RadiusProxyEngine::onRequestTimeout(
  826. Request* request
  827. ) throw ()
  828. {
  829. // Erase the pending request. If it's not there, that's okay; it means that
  830. // we received a response, but weren't able to cancel the timer in time.
  831. if (theProxy->pending.erase(request->getRequestID()))
  832. {
  833. // Avoid this server next time.
  834. theProxy->setServerAvoidance(*request);
  835. RadiusEvent event =
  836. {
  837. request->getPortType(),
  838. eventTimeout,
  839. &request->getServer(),
  840. request->getPort().address.address(),
  841. request->getPort().address.port()
  842. };
  843. // Report the protocol event.
  844. theProxy->reportEvent(event);
  845. // Update request state.
  846. if (request->onTimeout())
  847. {
  848. // The server was just marked unavailable, so notify the client.
  849. theProxy->reportEvent(event, eventServerUnavailable);
  850. }
  851. }
  852. }
  853. RemoteServerPtr RadiusProxyEngine::getServerAffinity(
  854. const RadiusPacket& packet
  855. ) throw ()
  856. {
  857. // Find the State attribute.
  858. const RadiusAttribute* attr = FindAttribute(packet, RADIUS_STATE);
  859. if (!attr) { return NULL; }
  860. // Map it to a session.
  861. RadiusRawOctets key = { attr->value, attr->length };
  862. ServerBindingPtr session = sessions.find(key);
  863. if (!session) { return NULL; }
  864. return &session->getServer();
  865. }
  866. void RadiusProxyEngine::setServerAffinity(
  867. const RadiusPacket& packet,
  868. RemoteServer& server
  869. ) throw ()
  870. {
  871. // Is this an Access-Challenge ?
  872. if (packet.code != RADIUS_ACCESS_CHALLENGE) { return; }
  873. // Find the State attribute.
  874. const RadiusAttribute* state = FindAttribute(packet, RADIUS_STATE);
  875. if (!state) { return; }
  876. // Do we already have an entry for this State value.
  877. RadiusRawOctets key = { state->value, state->length };
  878. ServerBindingPtr session = sessions.find(key);
  879. if (session)
  880. {
  881. // Make sure the server matches.
  882. session->setServer(server);
  883. return;
  884. }
  885. // Otherwise, we'll have to create a new one.
  886. try
  887. {
  888. session = new ServerBinding(key, server);
  889. sessions.insert(*session);
  890. }
  891. catch (const std::bad_alloc&)
  892. {
  893. // We don't care if this fails.
  894. }
  895. }
  896. void RadiusProxyEngine::clearServerAvoidance(
  897. const RadiusPacket& packet,
  898. RemoteServer& server
  899. ) throw ()
  900. {
  901. // Is this packet authoritative?
  902. if ((packet.code == RADIUS_ACCESS_ACCEPT) ||
  903. (packet.code == RADIUS_ACCESS_REJECT))
  904. {
  905. // Find the User-Name attribute.
  906. const RadiusAttribute* attr = FindAttribute(packet, RADIUS_USER_NAME);
  907. if (attr != 0)
  908. {
  909. // Map it to a server.
  910. RadiusRawOctets key = { attr->value, attr->length };
  911. ServerBindingPtr avoidance = avoid.find(key);
  912. if (avoidance && (avoidance->getServer() == server))
  913. {
  914. avoid.erase(key);
  915. }
  916. }
  917. }
  918. }
  919. RemoteServerPtr RadiusProxyEngine::getServerAvoidance(
  920. const RadiusPacket& packet
  921. ) throw ()
  922. {
  923. // Find the User-Name attribute.
  924. const RadiusAttribute* attr = FindAttribute(packet, RADIUS_USER_NAME);
  925. if (!attr) { return NULL; }
  926. // Map it to a server.
  927. RadiusRawOctets key = { attr->value, attr->length };
  928. ServerBindingPtr avoidance = avoid.find(key);
  929. if (!avoidance) { return NULL; }
  930. return &avoidance->getServer();
  931. }
  932. void RadiusProxyEngine::setServerAvoidance(const Request& request) throw ()
  933. {
  934. if ((request.getCode() != RADIUS_ACCESS_REQUEST) ||
  935. (request.getUserName().len == 0))
  936. {
  937. return;
  938. }
  939. // Do we already have an entry for this User-Name value.
  940. ServerBindingPtr avoidance = avoid.find(request.getUserName());
  941. if (avoidance)
  942. {
  943. // Make sure the server matches.
  944. avoidance->setServer(request.getServer());
  945. return;
  946. }
  947. // Otherwise, we'll have to create a new one.
  948. try
  949. {
  950. avoidance = new ServerBinding(
  951. request.getUserName(),
  952. request.getServer()
  953. );
  954. avoid.insert(*avoidance);
  955. }
  956. catch (const std::bad_alloc&)
  957. {
  958. // We don't care if this fails.
  959. }
  960. }
  961. void RadiusProxyEngine::onReceive(
  962. UDPSocket& socket,
  963. ULONG_PTR key,
  964. const SOCKADDR_IN& remoteAddress,
  965. BYTE* buffer,
  966. ULONG bufferLength
  967. ) throw ()
  968. {
  969. //////////
  970. // Set up the event struct. We'll fill in the other fields as we go along.
  971. //////////
  972. RadiusEvent event =
  973. {
  974. (RadiusPortType)key,
  975. eventNone,
  976. NULL,
  977. remoteAddress.sin_addr.s_addr,
  978. remoteAddress.sin_port,
  979. buffer,
  980. bufferLength,
  981. 0
  982. };
  983. //////////
  984. // Validate the remote address.
  985. //////////
  986. RemoteServerPtr server = groups.findServer(
  987. remoteAddress.sin_addr.s_addr
  988. );
  989. if (!server)
  990. {
  991. reportEvent(event, eventInvalidAddress);
  992. return;
  993. }
  994. // Use the server as the event context.
  995. event.context = server;
  996. //////////
  997. // Validate the packet type.
  998. //////////
  999. if (bufferLength == 0)
  1000. {
  1001. reportEvent(event, eventUnknownType);
  1002. return;
  1003. }
  1004. switch (MAKELONG(key, buffer[0]))
  1005. {
  1006. case MAKELONG(portAuthentication, RADIUS_ACCESS_ACCEPT):
  1007. reportEvent(event, eventAccessAccept);
  1008. break;
  1009. case MAKELONG(portAuthentication, RADIUS_ACCESS_REJECT):
  1010. reportEvent(event, eventAccessReject);
  1011. break;
  1012. case MAKELONG(portAuthentication, RADIUS_ACCESS_CHALLENGE):
  1013. reportEvent(event, eventAccessChallenge);
  1014. break;
  1015. case MAKELONG(portAccounting, RADIUS_ACCOUNTING_RESPONSE):
  1016. reportEvent(event, eventAccountingResponse);
  1017. break;
  1018. default:
  1019. reportEvent(event, eventUnknownType);
  1020. return;
  1021. }
  1022. //////////
  1023. // Validate that the packet is properly formatted.
  1024. //////////
  1025. RadiusPacket* packet;
  1026. ALLOC_PACKET_FOR_BUFFER(packet, buffer, bufferLength);
  1027. if (!packet)
  1028. {
  1029. reportEvent(event, eventMalformedPacket);
  1030. return;
  1031. }
  1032. // Unpack the attributes.
  1033. UnpackBuffer(buffer, bufferLength, *packet);
  1034. //////////
  1035. // Validate that we were expecting this response.
  1036. //////////
  1037. // Look for our Proxy-State attribute.
  1038. RadiusAttribute* proxyState = FindAttribute(
  1039. *packet,
  1040. RADIUS_PROXY_STATE
  1041. );
  1042. // If we didn't find it OR it's the wrong length OR it doesn't start with
  1043. // our address, then we weren't expecting this packet.
  1044. if (!proxyState ||
  1045. proxyState->length != 8 ||
  1046. memcmp(proxyState->value, &proxyAddress, 4))
  1047. {
  1048. reportEvent(event, eventUnexpectedResponse);
  1049. return;
  1050. }
  1051. // Extract the request ID.
  1052. ULONG requestID = ExtractUInt32(proxyState->value + 4);
  1053. // Don't send the Proxy-State back to our client.
  1054. --packet->end;
  1055. memmove(
  1056. proxyState,
  1057. proxyState + 1,
  1058. (packet->end - proxyState) * sizeof(RadiusAttribute)
  1059. );
  1060. // Look up the request object. We don't remove it yet because we don't know
  1061. // if this is an authentic response.
  1062. RequestPtr request = pending.find(requestID);
  1063. if (!request)
  1064. {
  1065. // If it's not there, we'll assume that this is a packet that's
  1066. // already been reported as a timeout.
  1067. reportEvent(event, eventLateResponse);
  1068. return;
  1069. }
  1070. // Get the actual server we used for the request in case there are multiple
  1071. // servers defined for the same IP address.
  1072. event.context = server = &request->getServer();
  1073. const RemotePort& port = request->getPort();
  1074. // Validate the packet source && identifier.
  1075. if (!(port.address == remoteAddress) ||
  1076. request->getIdentifier() != packet->identifier)
  1077. {
  1078. reportEvent(event, eventUnexpectedResponse);
  1079. return;
  1080. }
  1081. //////////
  1082. // Validate that the packet is authentic.
  1083. //////////
  1084. AuthResult authResult = AuthenticateAndDecrypt(
  1085. request->getAuthenticator(),
  1086. port.secret,
  1087. port.secret.length(),
  1088. buffer,
  1089. bufferLength,
  1090. *packet
  1091. );
  1092. switch (authResult)
  1093. {
  1094. case AUTH_BAD_AUTHENTICATOR:
  1095. reportEvent(event, eventBadAuthenticator);
  1096. return;
  1097. case AUTH_BAD_SIGNATURE:
  1098. reportEvent(event, eventBadSignature);
  1099. return;
  1100. case AUTH_MISSING_SIGNATURE:
  1101. reportEvent(event, eventMissingSignature);
  1102. return;
  1103. }
  1104. //////////
  1105. // At this point, all the tests have passed -- we have the real thing.
  1106. //////////
  1107. if (!pending.erase(requestID))
  1108. {
  1109. // It must have timed out while we were authenticating it.
  1110. reportEvent(event, eventLateResponse);
  1111. return;
  1112. }
  1113. // Update endpoint state.
  1114. if (request->onReceive(packet->code))
  1115. {
  1116. // The server just came up, so notify the client.
  1117. reportEvent(event, eventServerAvailable);
  1118. }
  1119. // Report the round-trip time.
  1120. event.data = request->getRoundTripTime();
  1121. reportEvent(event, eventRoundTrip);
  1122. // Set the server affinity and clear the server avoidance.
  1123. setServerAffinity(*packet, *server);
  1124. clearServerAvoidance(*packet, *server);
  1125. // Take ownership of the context.
  1126. PVOID context = request->getContext().takeOwnership();
  1127. if (context)
  1128. {
  1129. // The magic moment -- we have successfully processed the response.
  1130. client->onComplete(
  1131. resultSuccess,
  1132. context,
  1133. &request->getServer(),
  1134. packet->code,
  1135. packet->begin,
  1136. packet->end
  1137. );
  1138. }
  1139. }
  1140. void RadiusProxyEngine::onReceiveError(
  1141. UDPSocket& socket,
  1142. ULONG_PTR key,
  1143. ULONG errorCode
  1144. ) throw ()
  1145. {
  1146. RadiusEvent event =
  1147. {
  1148. (RadiusPortType)key,
  1149. eventReceiveError,
  1150. NULL,
  1151. socket.getLocalAddress().address(),
  1152. socket.getLocalAddress().port(),
  1153. NULL,
  1154. 0,
  1155. errorCode
  1156. };
  1157. client->onEvent(event);
  1158. }
  1159. RadiusProxyEngine::Result RadiusProxyEngine::sendRequest(
  1160. RadiusPacket& packet,
  1161. Request* request
  1162. ) throw ()
  1163. {
  1164. // Fill in the packet identifier.
  1165. packet.identifier = request->getIdentifier();
  1166. // Get the info for the Signature.
  1167. BOOL sign = request->getServer().sendSignature;
  1168. // Format the Proxy-State attributes.
  1169. BYTE proxyStateValue[8];
  1170. RadiusAttribute proxyState = { RADIUS_PROXY_STATE, 8, proxyStateValue };
  1171. // First our IP address ...
  1172. memcpy(proxyStateValue, &proxyAddress, 4);
  1173. // ... and then the unique request ID.
  1174. InsertUInt32(proxyStateValue + 4, request->getRequestID());
  1175. // Allocate a buffer to hold the packet on the wire.
  1176. PBYTE buffer;
  1177. ALLOC_BUFFER_FOR_PACKET(buffer, &packet, &proxyState, sign);
  1178. if (!buffer) { return resultInvalidRequest; }
  1179. // Get the port for this request.
  1180. const RemotePort& port = request->getPort();
  1181. // Generate the request authenticator if necessary.
  1182. BYTE requestAuthenticator[16];
  1183. if ((packet.code == RADIUS_ACCESS_REQUEST) &&
  1184. (packet.authenticator == 0))
  1185. {
  1186. if (!CryptGenRandom(
  1187. crypto,
  1188. sizeof(requestAuthenticator),
  1189. requestAuthenticator
  1190. ))
  1191. {
  1192. return resultCryptoError;
  1193. }
  1194. packet.authenticator = requestAuthenticator;
  1195. }
  1196. // Pack the buffer. packet.authenticator is used for CHAP when the request
  1197. // authenticator is used for the chap-challenge. It can be null
  1198. PackBuffer(
  1199. port.secret,
  1200. port.secret.length(),
  1201. packet,
  1202. &proxyState,
  1203. sign,
  1204. buffer
  1205. );
  1206. // Save the request authenticator and packet.
  1207. request->setAuthenticator(buffer + 4);
  1208. request->setPacket(packet);
  1209. // Determine the request type.
  1210. bool isAuth = request->isAccReq();
  1211. // Set up the event struct.
  1212. RadiusEvent event =
  1213. {
  1214. (isAuth ? portAuthentication : portAccounting),
  1215. (isAuth ? eventAccessRequest : eventAccountingRequest),
  1216. &request->getServer(),
  1217. port.address.address(),
  1218. port.address.port(),
  1219. buffer,
  1220. packet.length
  1221. };
  1222. // Get the appropriate socket.
  1223. UDPSocket& sock = isAuth ? authSock : acctSock;
  1224. // Insert the pending request before we send it to avoid a race condition.
  1225. pending.insert(*request);
  1226. // The magic moment -- we actually send the request.
  1227. Result result;
  1228. if (sock.send(port.address, buffer, packet.length))
  1229. {
  1230. // Update request state.
  1231. request->onSend();
  1232. // Set a timer to clean up if the server doesn't answer.
  1233. if (timers.setTimer(request, request->getServer().timeout, 0))
  1234. {
  1235. result = resultSuccess;
  1236. }
  1237. else
  1238. {
  1239. // If we can't set at timer we have to remove it from the pending
  1240. // requests table or else it could leak.
  1241. pending.erase(*request);
  1242. result = resultNotEnoughMemory;
  1243. }
  1244. }
  1245. else
  1246. {
  1247. // Update the event with the error data.
  1248. event.eventType = eventSendError;
  1249. event.data = GetLastError();
  1250. // If we received "Port Unreachable" ICMP packet, we'll count this as a
  1251. // timeout since it means the server is unavailable.
  1252. if (event.data == WSAECONNRESET) { request->onTimeout(); }
  1253. // Remove from the pending requests table.
  1254. pending.erase(*request);
  1255. }
  1256. // Report the event ...
  1257. reportEvent(event);
  1258. // ... and the result.
  1259. return result;
  1260. }