Source code of Windows XP (NT5)
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1255 lines
34 KiB

  1. ///////////////////////////////////////////////////////////////////////////////
  2. //
  3. // Copyright (c) 2000, Microsoft Corp. All rights reserved.
  4. //
  5. // FILE
  6. //
  7. // radprxy.cpp
  8. //
  9. // SYNOPSIS
  10. //
  11. // Defines the reusable RadiusProxy engine. This should have no IAS specific
  12. // dependencies.
  13. //
  14. // MODIFICATION HISTORY
  15. //
  16. // 02/08/2000 Original version.
  17. // 05/30/2000 Eliminate QUESTIONABLE state.
  18. //
  19. ///////////////////////////////////////////////////////////////////////////////
  20. #include <proxypch.h>
  21. #include <radproxyp.h>
  22. #include <radproxy.h>
  23. // Avoid dependencies on ntrtl.h
  24. extern "C" ULONG __stdcall RtlRandom(PULONG seed);
  25. // Extract a 32-bit integer from a buffer.
  26. ULONG ExtractUInt32(const BYTE* p) throw ()
  27. {
  28. return (ULONG)(p[0] << 24) | (ULONG)(p[1] << 16) |
  29. (ULONG)(p[2] << 8) | (ULONG)(p[3] );
  30. }
  31. // Insert a 32-bit integer into a buffer.
  32. void InsertUInt32(BYTE* p, ULONG val) throw ()
  33. {
  34. *p++ = (BYTE)(val >> 24);
  35. *p++ = (BYTE)(val >> 16);
  36. *p++ = (BYTE)(val >> 8);
  37. *p = (BYTE)(val );
  38. }
  39. //
  40. // Layout of a Microsoft State attribute
  41. //
  42. // struct MicrosoftState
  43. // {
  44. // BYTE checksum[4];
  45. // BYTE vendorID[4];
  46. // BYTE version[2];
  47. // BYTE serverAddress[4];
  48. // BYTE sourceID[4];
  49. // BYTE sessionID[4];
  50. // };
  51. //
  52. // Extracts the creators address from a State attribute or INADDR_NONE if this
  53. // isn't a valid Microsoft State attributes.
  54. ULONG ExtractAddressFromState(const RadiusAttribute& state) throw ()
  55. {
  56. if (state.length == 22 &&
  57. !memcmp(state.value + 4, "\x00\x00\x01\x37\x00\x01", 6) &&
  58. IASAdler32(state.value + 4, 18) == ExtractUInt32(state.value))
  59. {
  60. return ExtractUInt32(state.value + 10);
  61. }
  62. return INADDR_NONE;
  63. }
  64. // Returns true if this is an Accounting-On/Off packet.
  65. bool IsNasStateRequest(const RadiusPacket& packet) throw ()
  66. {
  67. const RadiusAttribute* status = FindAttribute(
  68. packet,
  69. RADIUS_ACCT_STATUS_TYPE
  70. );
  71. if (!status) { return false; }
  72. ULONG value = ExtractUInt32(status->value);
  73. return value == 7 || value == 8;
  74. }
  75. RemotePort::RemotePort(
  76. ULONG ipAddress,
  77. USHORT port,
  78. PCSTR sharedSecret
  79. )
  80. : address(ipAddress, port),
  81. secret((const BYTE*)sharedSecret, strlen(sharedSecret))
  82. {
  83. }
  84. RemotePort::RemotePort(const RemotePort& port)
  85. : address(port.address),
  86. secret(port.secret),
  87. nextIdentifier(port.nextIdentifier)
  88. {
  89. }
  90. RemoteServer::RemoteServer(
  91. const RemoteServerConfig& config
  92. )
  93. : guid(config.guid),
  94. authPort(config.ipAddress, config.authPort, config.authSecret),
  95. acctPort(config.ipAddress, config.acctPort, config.acctSecret),
  96. timeout(config.timeout),
  97. maxLost((LONG)config.maxLost),
  98. blackout(config.blackout),
  99. priority(config.priority),
  100. weight(config.weight),
  101. sendSignature(config.sendSignature),
  102. sendAcctOnOff(config.sendAcctOnOff),
  103. state(AVAILABLE),
  104. lostCount(0),
  105. expiry(0)
  106. {
  107. }
  108. //////////
  109. // We can make this inline since it's never called from outside this file.
  110. //////////
  111. inline bool RemoteServer::changeState(State from, State to) throw ()
  112. {
  113. return (State)InterlockedCompareExchange(
  114. (LPLONG)&state,
  115. to,
  116. from
  117. ) == from;
  118. }
  119. bool RemoteServer::shouldBroadcast() throw ()
  120. {
  121. // We only broadcast to unavailable servers and then only if we can lock the
  122. // server.
  123. if (changeState(UNAVAILABLE, LOCKED))
  124. {
  125. ULONG64 now = GetSystemTime64();
  126. // Has the blackout interval expired ?
  127. if (now > expiry)
  128. {
  129. // Yes, so set a new expiration ...
  130. expiry = now + blackout * 10000i64;
  131. // ... and unlock the server. If this transition fails, then the
  132. // server must have been marked available, so we don't want to send a
  133. // broadcast.
  134. return changeState(LOCKED, UNAVAILABLE);
  135. }
  136. // No, so just unlock and return false.
  137. changeState(LOCKED, UNAVAILABLE);
  138. }
  139. return false;
  140. }
  141. bool RemoteServer::onReceive() throw ()
  142. {
  143. // Reset the lost count.
  144. lostCount = 0;
  145. // Force a state change to AVAILABLE.
  146. State previous = (State)InterlockedExchange((LPLONG)&state, AVAILABLE);
  147. // Did we transition?
  148. return previous == UNAVAILABLE;
  149. }
  150. bool RemoteServer::onTimeout() throw ()
  151. {
  152. bool retval = false;
  153. // Have we exceeded the maxLost?
  154. if (InterlockedIncrement(&lostCount) >= maxLost)
  155. {
  156. // Yes, so lock the server ...
  157. if (changeState(AVAILABLE, LOCKED))
  158. {
  159. // ... and set the blackout interval.
  160. expiry = GetSystemTime64() + blackout * 10000i64;
  161. if (changeState(LOCKED, UNAVAILABLE))
  162. {
  163. // The server transitioned from AVAILABLE to UNAVAILABLE.
  164. retval = true;
  165. }
  166. }
  167. }
  168. return retval;
  169. }
  170. void RemoteServer::copyState(const RemoteServer& target) throw ()
  171. {
  172. // Synchronize the ports.
  173. authPort.copyState(target.authPort);
  174. acctPort.copyState(target.acctPort);
  175. // Synchronize server availability.
  176. state = target.state;
  177. lostCount = target.lostCount;
  178. expiry = target.expiry;
  179. }
  180. bool RemoteServer::operator==(const RemoteServer& s) const throw ()
  181. {
  182. return authPort == s.authPort &&
  183. acctPort == s.acctPort &&
  184. priority == s.priority &&
  185. weight == s.weight &&
  186. timeout == s.timeout &&
  187. maxLost == s.maxLost &&
  188. blackout == s.blackout &&
  189. sendSignature == s.sendSignature &&
  190. sendAcctOnOff == s.sendAcctOnOff;
  191. }
  192. //////////
  193. // Used for sorting servers by priority.
  194. //////////
  195. int __cdecl sortServersByPriority(
  196. const RemoteServer* const* server1,
  197. const RemoteServer* const* server2
  198. ) throw ()
  199. {
  200. return (int)(*server1)->priority - (int)(*server2)->priority;
  201. }
  202. ULONG ServerGroup::theSeed;
  203. ServerGroup::ServerGroup(
  204. PCWSTR groupName,
  205. RemoteServer* const* first,
  206. RemoteServer* const* last
  207. )
  208. : servers(first, last),
  209. name(groupName)
  210. {
  211. // We don't allow empty groups.
  212. if (servers.empty()) { _com_issue_error(E_INVALIDARG); }
  213. if (theSeed == 0)
  214. {
  215. FILETIME ft;
  216. GetSystemTimeAsFileTime(&ft);
  217. theSeed = ft.dwLowDateTime | ft.dwHighDateTime;
  218. }
  219. // Sort by priority.
  220. servers.sort(sortServersByPriority);
  221. // Find the end of the top priority servers. This will be useful when doing
  222. // a forced pick.
  223. ULONG topPriority = (*servers.begin())->priority;
  224. for (endTopPriority = servers.begin();
  225. endTopPriority != servers.end();
  226. ++endTopPriority)
  227. {
  228. if ((*endTopPriority)->priority != topPriority) { break; }
  229. }
  230. // Find the max number of servers at any priority level. This will be useful
  231. // when allocating a buffer to hold the candidates.
  232. ULONG maxCount = 0, count = 0, priority = (*servers.begin())->priority;
  233. for (RemoteServer* const* i = begin(); i != end(); ++i)
  234. {
  235. if ((*i)->priority != priority)
  236. {
  237. priority = (*i)->priority;
  238. count = 0;
  239. }
  240. if (++count > maxCount) { maxCount = count; }
  241. }
  242. maxCandidatesSize = maxCount * sizeof(RemoteServer*);
  243. }
  244. RemoteServer* ServerGroup::pickServer(
  245. RemoteServer* const* first,
  246. RemoteServer* const* last
  247. ) throw ()
  248. {
  249. // If the list has exactly one entry, there's nothing to do.
  250. if (last == first + 1) { return *first; }
  251. RemoteServer* const* i;
  252. // Compute the combined weight off all the servers.
  253. ULONG weight = 0;
  254. for (i = first; i != last; ++i) { weight += (*i)->weight; }
  255. // Pick a random number from [0, weight)
  256. ULONG offset = (ULONG)
  257. (((ULONG64)RtlRandom(&theSeed) * (ULONG64)weight) >> 31);
  258. // We don't test the last server since if we make it that far we have to use
  259. // it anyway.
  260. --last;
  261. // Iterate through the candidates until we reach the offset.
  262. for (i = first; i != last; ++i)
  263. {
  264. if ((*i)->weight >= offset) { break; }
  265. offset -= (*i)->weight;
  266. }
  267. return *i;
  268. }
  269. void ServerGroup::getServersForRequest(
  270. ProxyContext* context,
  271. BYTE packetCode,
  272. RequestStack& result
  273. ) const
  274. {
  275. // List of candidates.
  276. RemoteServer** first = (RemoteServer**)_alloca(maxCandidatesSize);
  277. RemoteServer** last = first;
  278. // Iterate through the servers.
  279. ULONG maxPriority = (ULONG)-1;
  280. for (RemoteServer* const* i = servers.begin(); i != servers.end(); ++i)
  281. {
  282. // If this test fails, we must have found a higher priority server that's
  283. // available.
  284. if ((*i)->priority > maxPriority) { break; }
  285. if ((*i)->isAvailable())
  286. {
  287. // Don't consider lower priority servers.
  288. maxPriority = (*i)->priority;
  289. // Add this to the list of candidates.
  290. *last++ = *i;
  291. }
  292. else if ((*i)->shouldBroadcast())
  293. {
  294. // It's not available, but it's ready for a broadcast
  295. result.push(new Request(context, *i, packetCode));
  296. }
  297. }
  298. if (first != last)
  299. {
  300. // We have at least one candidate, so pick one and add it to the list.
  301. result.push(new Request(
  302. context,
  303. pickServer(first, last),
  304. packetCode
  305. ));
  306. }
  307. else if (result.empty() && !servers.empty())
  308. {
  309. // We have no candidates and no servers available for broadcast, so just
  310. // force a pick from the top priority servers.
  311. result.push(new Request(
  312. context,
  313. pickServer(servers.begin(), endTopPriority),
  314. packetCode
  315. ));
  316. }
  317. }
  318. //////////
  319. // Used for sorting and searching groups by name.
  320. //////////
  321. int __cdecl sortGroupsByName(
  322. const ServerGroup* const* group1,
  323. const ServerGroup* const* group2
  324. ) throw ()
  325. {
  326. return wcscmp((*group1)->getName(), (*group2)->getName());
  327. }
  328. int __cdecl findGroupByName(
  329. const void* key,
  330. const ServerGroup* const* group
  331. ) throw ()
  332. {
  333. return wcscmp((PCWSTR)key, (*group)->getName());
  334. }
  335. //////////
  336. // Used for sorting and searching servers by address.
  337. //////////
  338. int __cdecl sortServersByAddress(
  339. const RemoteServer* const* server1,
  340. const RemoteServer* const* server2
  341. )
  342. {
  343. if ((*server1)->getAddress() < (*server2)->getAddress()) { return -1; }
  344. if ((*server1)->getAddress() > (*server2)->getAddress()) { return 1; }
  345. return 0;
  346. }
  347. int __cdecl findServerByAddress(
  348. const void* key,
  349. const RemoteServer* const* server
  350. ) throw ()
  351. {
  352. if ((ULONG_PTR)key < (*server)->getAddress()) { return -1; }
  353. if ((ULONG_PTR)key > (*server)->getAddress()) { return 1; }
  354. return 0;
  355. }
  356. //////////
  357. // Used for sorting and searching servers by guid.
  358. //////////
  359. int __cdecl sortServersByGUID(
  360. const RemoteServer* const* server1,
  361. const RemoteServer* const* server2
  362. ) throw ()
  363. {
  364. return memcmp(&(*server1)->guid, &(*server2)->guid, sizeof(GUID));
  365. }
  366. int __cdecl findServerByGUID(
  367. const void* key,
  368. const RemoteServer* const* server
  369. ) throw ()
  370. {
  371. return memcmp(key, &(*server)->guid, sizeof(GUID));
  372. }
  373. //////////
  374. // Used for sorting accounting servers by port.
  375. //////////
  376. int __cdecl sortServersByAcctPort(
  377. const RemoteServer* const* server1,
  378. const RemoteServer* const* server2
  379. )
  380. {
  381. const sockaddr_in& a1 = (*server1)->acctPort.address;
  382. const sockaddr_in& a2 = (*server2)->acctPort.address;
  383. return memcmp(&a1.sin_port, &a2.sin_port, 6);
  384. }
  385. bool ServerGroupManager::setServerGroups(
  386. ServerGroup* const* first,
  387. ServerGroup* const* last
  388. ) throw ()
  389. {
  390. bool success;
  391. try
  392. {
  393. // Save the new server groups ...
  394. ServerGroups newGroups(first, last);
  395. // Sort by name.
  396. newGroups.sort(sortGroupsByName);
  397. // Useful iterators.
  398. ServerGroups::iterator i;
  399. RemoteServers::iterator j;
  400. // Count the number of servers and accounting servers.
  401. ULONG count = 0, acctCount = 0;
  402. for (i = first; i != last; ++i)
  403. {
  404. for (j = (*i)->begin(); j != (*i)->end(); ++j)
  405. {
  406. ++count;
  407. if ((*j)->sendAcctOnOff) { ++acctCount; }
  408. }
  409. }
  410. // Reserve space for the servers.
  411. RemoteServers newServers(count);
  412. RemoteServers newAcctServers(acctCount);
  413. // Populate the servers.
  414. for (i = first; i != last; ++i)
  415. {
  416. for (j = (*i)->begin(); j != (*i)->end(); ++j)
  417. {
  418. RemoteServer* newServer = *j;
  419. // Does this server already exist?
  420. RemoteServer* existing = byGuid.search(
  421. (const void*)&newServer->guid,
  422. findServerByGUID
  423. );
  424. if (existing)
  425. {
  426. if (*existing == *newServer)
  427. {
  428. // If it's an exact match, use the existing server.
  429. newServer = existing;
  430. }
  431. else
  432. {
  433. // Otherwise, copy the state of the existing server.
  434. newServer->copyState(*existing);
  435. }
  436. }
  437. newServers.push_back(newServer);
  438. if (newServer->sendAcctOnOff)
  439. {
  440. newAcctServers.push_back(newServer);
  441. }
  442. }
  443. }
  444. // Sort the servers by address ...
  445. newServers.sort(sortServersByAddress);
  446. // ... and GUID.
  447. RemoteServers newServersByGuid(newServers);
  448. newServersByGuid.sort(sortServersByGUID);
  449. // Everything is ready so now we grab the write lock ...
  450. monitor.LockExclusive();
  451. // ... and swap in the collections.
  452. groups.swap(newGroups);
  453. byAddress.swap(newServers);
  454. byGuid.swap(newServersByGuid);
  455. acctServers.swap(newAcctServers);
  456. monitor.Unlock();
  457. success = true;
  458. }
  459. catch (...)
  460. {
  461. success = false;
  462. }
  463. return success;
  464. }
  465. RemoteServerPtr ServerGroupManager::findServer(
  466. ULONG address
  467. ) const throw ()
  468. {
  469. monitor.Lock();
  470. RemoteServer* server = byAddress.search(
  471. (const void*)ULongToPtr(address),
  472. findServerByAddress
  473. );
  474. monitor.Unlock();
  475. return server;
  476. }
  477. void ServerGroupManager::getServersByGroup(
  478. ProxyContext* context,
  479. BYTE packetCode,
  480. PCWSTR name,
  481. RequestStack& result
  482. ) const throw ()
  483. {
  484. monitor.Lock();
  485. ServerGroup* group = groups.search(name, findGroupByName);
  486. if (group)
  487. {
  488. group->getServersForRequest(context, packetCode, result);
  489. }
  490. monitor.Unlock();
  491. }
  492. void ServerGroupManager::getServersForAcctOnOff(
  493. ProxyContext* context,
  494. RequestStack& result
  495. ) const throw ()
  496. {
  497. monitor.Lock();
  498. for (RemoteServer* const* i = acctServers.begin();
  499. i != acctServers.end();
  500. ++i)
  501. {
  502. result.push(new Request(context, *i, RADIUS_ACCOUNTING_REQUEST));
  503. }
  504. monitor.Unlock();
  505. }
  506. RadiusProxyEngine* RadiusProxyEngine::theProxy;
  507. RadiusProxyEngine::RadiusProxyEngine(RadiusProxyClient* source) throw ()
  508. : client(source),
  509. proxyAddress(INADDR_NONE),
  510. pending(Request::hash, 1),
  511. sessions(Session::hash, 1, 10000, 120000)
  512. {
  513. theProxy = this;
  514. // We don't care if this fails. The proxy will just use INADDR_NONE in it's
  515. // proxy-state attribute.
  516. PHOSTENT he = IASGetHostByName(NULL);
  517. if (he)
  518. {
  519. if (he->h_addr_list[0])
  520. {
  521. proxyAddress = *(PULONG)he->h_addr_list[0];
  522. }
  523. LocalFree(he);
  524. }
  525. }
  526. RadiusProxyEngine::~RadiusProxyEngine() throw ()
  527. {
  528. // Block any new reponses.
  529. authSock.close();
  530. acctSock.close();
  531. // Clear the pending request table.
  532. pending.clear();
  533. // Cancel all the timers.
  534. timers.cancelAllTimers();
  535. // At this point all our threads should be done, but let's just make sure.
  536. SwitchToThread();
  537. theProxy = NULL;
  538. }
  539. bool RadiusProxyEngine::setServerGroups(
  540. ServerGroup* const* begin,
  541. ServerGroup* const* end
  542. ) throw ()
  543. {
  544. // We don't open the sockets unless we actually have some server groups
  545. // configured. This is just to be a good corporate citizen.
  546. if (begin != end)
  547. {
  548. if ((!authSock.isOpen() && !authSock.open(this, portAuthentication)) ||
  549. (!acctSock.isOpen() && !acctSock.open(this, portAccounting)))
  550. {
  551. return false;
  552. }
  553. }
  554. return groups.setServerGroups(begin, end);
  555. }
  556. void RadiusProxyEngine::forwardRequest(
  557. PVOID context,
  558. PCWSTR serverGroup,
  559. BYTE code,
  560. const BYTE* requestAuthenticator,
  561. const RadiusAttribute* begin,
  562. const RadiusAttribute* end
  563. ) throw ()
  564. {
  565. // Save the request context. We have to handle this carefully since we rely
  566. // on the ProxyContext object to ensure that onComplete gets called exactly
  567. // one. If we can't allocate the object, we have to handle it specially.
  568. ProxyContextPtr ctxt(new (std::nothrow) ProxyContext(context));
  569. if (!ctxt)
  570. {
  571. client->onComplete(
  572. resultNotEnoughMemory,
  573. context,
  574. NULL,
  575. code,
  576. NULL,
  577. NULL
  578. );
  579. return;
  580. }
  581. Result retval;
  582. // Store the in parameters in a RadiusPacket struct.
  583. RadiusPacket packet;
  584. packet.code = code;
  585. packet.begin = const_cast<RadiusAttribute*>(begin);
  586. packet.end = const_cast<RadiusAttribute*>(end);
  587. // Generate the list of RADIUS requests to be sent.
  588. RequestStack requests;
  589. switch (code)
  590. {
  591. case RADIUS_ACCESS_REQUEST:
  592. {
  593. // Is this request associated with a particular server?
  594. RemoteServerPtr server = getServerAffinity(packet);
  595. if (server)
  596. {
  597. requests.push(new Request(ctxt, server, RADIUS_ACCESS_REQUEST));
  598. }
  599. else
  600. {
  601. groups.getServersByGroup(
  602. ctxt,
  603. code,
  604. serverGroup,
  605. requests
  606. );
  607. retval = resultUnknownServerGroup;
  608. }
  609. // put request authenticator in the packet.
  610. // the request authenticator can be NULL
  611. // the authenticator will not be changed.
  612. packet.authenticator = requestAuthenticator;
  613. break;
  614. }
  615. case RADIUS_ACCOUNTING_REQUEST:
  616. {
  617. if (!IsNasStateRequest(packet))
  618. {
  619. groups.getServersByGroup(
  620. ctxt,
  621. code,
  622. serverGroup,
  623. requests
  624. );
  625. retval = resultUnknownServerGroup;
  626. }
  627. else
  628. {
  629. groups.getServersForAcctOnOff(
  630. ctxt,
  631. requests
  632. );
  633. // NAS State requests are always reported as a success since we
  634. // don't care if it gets to all the destinations.
  635. context = ctxt->takeOwnership();
  636. if (context)
  637. {
  638. client->onComplete(
  639. resultSuccess,
  640. context,
  641. NULL,
  642. RADIUS_ACCOUNTING_RESPONSE,
  643. NULL,
  644. NULL
  645. );
  646. }
  647. retval = resultSuccess;
  648. }
  649. break;
  650. }
  651. default:
  652. {
  653. retval = resultInvalidRequest;
  654. }
  655. }
  656. if (!requests.empty())
  657. {
  658. // First we handle the primary.
  659. RequestPtr request = requests.pop();
  660. ctxt->setPrimaryServer(&request->getServer());
  661. retval = sendRequest(packet, request);
  662. // Now we broadcast.
  663. while (!requests.empty())
  664. {
  665. request = requests.pop();
  666. Result result = sendRequest(packet, request);
  667. if (result == resultSuccess && retval != resultSuccess)
  668. {
  669. // This was the first request to succeed so mark it as primary.
  670. retval = resultSuccess;
  671. ctxt->setPrimaryServer(&request->getServer());
  672. }
  673. }
  674. }
  675. if (retval != resultSuccess)
  676. {
  677. // If we made it here, then we didn't successfully send a packet to any
  678. // server, so we have to report the result ourself.
  679. context = ctxt->takeOwnership();
  680. if (context)
  681. {
  682. client->onComplete(
  683. retval,
  684. context,
  685. ctxt->getPrimaryServer(),
  686. code,
  687. NULL,
  688. NULL
  689. );
  690. }
  691. }
  692. }
  693. void RadiusProxyEngine::onRequestAbandoned(
  694. PVOID context,
  695. RemoteServer* server
  696. ) throw ()
  697. {
  698. // Nobody took responsibility for the request, so we time it out.
  699. theProxy->client->onComplete(
  700. resultRequestTimeout,
  701. context,
  702. server,
  703. 0,
  704. NULL,
  705. NULL
  706. );
  707. }
  708. inline void RadiusProxyEngine::reportEvent(
  709. const RadiusEvent& event
  710. ) const throw ()
  711. {
  712. client->onEvent(event);
  713. }
  714. inline void RadiusProxyEngine::reportEvent(
  715. RadiusEvent& event,
  716. RadiusEventType type
  717. ) const throw ()
  718. {
  719. event.eventType = type;
  720. client->onEvent(event);
  721. }
  722. void RadiusProxyEngine::onRequestTimeout(
  723. Request* request
  724. ) throw ()
  725. {
  726. // Erase the pending request. If it's not there, that's okay; it means that
  727. // we received a response, but weren't able to cancel the timer in time.
  728. if (theProxy->pending.erase(request->getRequestID()))
  729. {
  730. RadiusEvent event =
  731. {
  732. request->getPortType(),
  733. eventTimeout,
  734. &request->getServer(),
  735. request->getPort().address.address(),
  736. request->getPort().address.port()
  737. };
  738. // Report the protocol event.
  739. theProxy->reportEvent(event);
  740. // Update request state.
  741. if (request->onTimeout())
  742. {
  743. // The server was just marked unavailable, so notify the client.
  744. theProxy->reportEvent(event, eventServerUnavailable);
  745. }
  746. }
  747. }
  748. RemoteServerPtr RadiusProxyEngine::getServerAffinity(
  749. const RadiusPacket& packet
  750. ) throw ()
  751. {
  752. // Find the State attribute.
  753. const RadiusAttribute* attr = FindAttribute(packet, RADIUS_STATE);
  754. if (!attr) { return NULL; }
  755. // Map it to a session.
  756. RadiusRawOctets key = { attr->value, attr->length };
  757. SessionPtr session = sessions.find(key);
  758. if (!session) { return NULL; }
  759. return &session->getServer();
  760. }
  761. void RadiusProxyEngine::setServerAffinity(
  762. const RadiusPacket& packet,
  763. RemoteServer& server
  764. ) throw ()
  765. {
  766. // Is this an Access-Challenge ?
  767. if (packet.code != RADIUS_ACCESS_CHALLENGE) { return; }
  768. // Find the State attribute.
  769. const RadiusAttribute* state = FindAttribute(packet, RADIUS_STATE);
  770. if (!state) { return; }
  771. // Do we already have an entry for this State value.
  772. RadiusRawOctets key = { state->value, state->length };
  773. SessionPtr session = sessions.find(key);
  774. if (session)
  775. {
  776. // Make sure the server matches.
  777. session->setServer(server);
  778. return;
  779. }
  780. // Otherwise, we'll have to create a new one.
  781. try
  782. {
  783. session = new Session(key, server);
  784. sessions.insert(*session);
  785. }
  786. catch (...)
  787. {
  788. // We don't care if this fails.
  789. }
  790. }
  791. void RadiusProxyEngine::onReceive(
  792. UDPSocket& socket,
  793. ULONG_PTR key,
  794. const SOCKADDR_IN& remoteAddress,
  795. BYTE* buffer,
  796. ULONG bufferLength
  797. ) throw ()
  798. {
  799. //////////
  800. // Set up the event struct. We'll fill in the other fields as we go along.
  801. //////////
  802. RadiusEvent event =
  803. {
  804. (RadiusPortType)key,
  805. eventNone,
  806. NULL,
  807. remoteAddress.sin_addr.s_addr,
  808. remoteAddress.sin_port,
  809. buffer,
  810. bufferLength,
  811. 0
  812. };
  813. //////////
  814. // Validate the remote address.
  815. //////////
  816. RemoteServerPtr server = groups.findServer(
  817. remoteAddress.sin_addr.s_addr
  818. );
  819. if (!server)
  820. {
  821. reportEvent(event, eventInvalidAddress);
  822. return;
  823. }
  824. // Use the server as the event context.
  825. event.context = server;
  826. //////////
  827. // Validate the packet type.
  828. //////////
  829. if (bufferLength == 0)
  830. {
  831. reportEvent(event, eventUnknownType);
  832. return;
  833. }
  834. switch (MAKELONG(key, buffer[0]))
  835. {
  836. case MAKELONG(portAuthentication, RADIUS_ACCESS_ACCEPT):
  837. reportEvent(event, eventAccessAccept);
  838. break;
  839. case MAKELONG(portAuthentication, RADIUS_ACCESS_REJECT):
  840. reportEvent(event, eventAccessReject);
  841. break;
  842. case MAKELONG(portAuthentication, RADIUS_ACCESS_CHALLENGE):
  843. reportEvent(event, eventAccessChallenge);
  844. break;
  845. case MAKELONG(portAccounting, RADIUS_ACCOUNTING_RESPONSE):
  846. reportEvent(event, eventAccountingResponse);
  847. break;
  848. default:
  849. reportEvent(event, eventUnknownType);
  850. return;
  851. }
  852. //////////
  853. // Validate that the packet is properly formatted.
  854. //////////
  855. RadiusPacket* packet;
  856. ALLOC_PACKET_FOR_BUFFER(packet, buffer, bufferLength);
  857. if (!packet)
  858. {
  859. reportEvent(event, eventMalformedPacket);
  860. return;
  861. }
  862. // Unpack the attributes.
  863. UnpackBuffer(buffer, bufferLength, *packet);
  864. //////////
  865. // Validate that we were expecting this response.
  866. //////////
  867. // Look for our Proxy-State attribute.
  868. RadiusAttribute* proxyState = FindAttribute(
  869. *packet,
  870. RADIUS_PROXY_STATE
  871. );
  872. // If we didn't find it OR it's the wrong length OR it doesn't start with
  873. // our address, then we weren't expecting this packet.
  874. if (!proxyState ||
  875. proxyState->length != 8 ||
  876. memcmp(proxyState->value, &proxyAddress, 4))
  877. {
  878. reportEvent(event, eventUnexpectedResponse);
  879. return;
  880. }
  881. // Extract the request ID.
  882. ULONG requestID = ExtractUInt32(proxyState->value + 4);
  883. // Don't send the Proxy-State back to our client.
  884. --packet->end;
  885. memmove(
  886. proxyState,
  887. proxyState + 1,
  888. (packet->end - proxyState) * sizeof(RadiusAttribute)
  889. );
  890. // Look up the request object. We don't remove it yet because we don't know
  891. // if this is an authentic response.
  892. RequestPtr request = pending.find(requestID);
  893. if (!request)
  894. {
  895. // If it's not there, we'll assume that this is a packet that's
  896. // already been reported as a timeout.
  897. reportEvent(event, eventLateResponse);
  898. return;
  899. }
  900. // Get the actual server we used for the request in case there are multiple
  901. // servers defined for the same IP address.
  902. event.context = server = &request->getServer();
  903. const RemotePort& port = request->getPort();
  904. // Validate the packet source && identifier.
  905. if (!(port.address == remoteAddress) ||
  906. request->getIdentifier() != packet->identifier)
  907. {
  908. reportEvent(event, eventUnexpectedResponse);
  909. return;
  910. }
  911. //////////
  912. // Validate that the packet is authentic.
  913. //////////
  914. AuthResult authResult = AuthenticateAndDecrypt(
  915. request->getAuthenticator(),
  916. port.secret,
  917. port.secret.length(),
  918. buffer,
  919. bufferLength,
  920. *packet
  921. );
  922. switch (authResult)
  923. {
  924. case AUTH_BAD_AUTHENTICATOR:
  925. reportEvent(event, eventBadAuthenticator);
  926. return;
  927. case AUTH_BAD_SIGNATURE:
  928. reportEvent(event, eventBadSignature);
  929. return;
  930. case AUTH_MISSING_SIGNATURE:
  931. reportEvent(event, eventMissingSignature);
  932. return;
  933. }
  934. //////////
  935. // At this point, all the tests have passed -- we have the real thing.
  936. //////////
  937. if (!pending.erase(requestID))
  938. {
  939. // It must have timed out while we were authenticating it.
  940. reportEvent(event, eventLateResponse);
  941. return;
  942. }
  943. // Update endpoint state.
  944. if (request->onReceive())
  945. {
  946. // The server just came up, so notify the client.
  947. reportEvent(event, eventServerAvailable);
  948. }
  949. // Update the round-trip time.
  950. event.data = request->getRoundTripTime();
  951. reportEvent(event, eventRoundTrip);
  952. // Set the server affinity.
  953. setServerAffinity(*packet, *server);
  954. // Take ownership of the context.
  955. PVOID context = request->getContext().takeOwnership();
  956. if (context)
  957. {
  958. // The magic moment -- we have successfully processed the response.
  959. client->onComplete(
  960. resultSuccess,
  961. context,
  962. &request->getServer(),
  963. packet->code,
  964. packet->begin,
  965. packet->end
  966. );
  967. }
  968. }
  969. void RadiusProxyEngine::onReceiveError(
  970. UDPSocket& socket,
  971. ULONG_PTR key,
  972. ULONG errorCode
  973. ) throw ()
  974. {
  975. RadiusEvent event =
  976. {
  977. (RadiusPortType)key,
  978. eventReceiveError,
  979. NULL,
  980. socket.getLocalAddress().address(),
  981. socket.getLocalAddress().port(),
  982. NULL,
  983. 0,
  984. errorCode
  985. };
  986. client->onEvent(event);
  987. }
  988. RadiusProxyEngine::Result RadiusProxyEngine::sendRequest(
  989. RadiusPacket& packet,
  990. Request* request
  991. )
  992. {
  993. // Fill in the packet identifier.
  994. packet.identifier = request->getIdentifier();
  995. // Get the info for the Signature.
  996. BOOL sign = request->getServer().sendSignature;
  997. // Format the Proxy-State attributes.
  998. BYTE proxyStateValue[8];
  999. RadiusAttribute proxyState = { RADIUS_PROXY_STATE, 8, proxyStateValue };
  1000. // First our IP address ...
  1001. memcpy(proxyStateValue, &proxyAddress, 4);
  1002. // ... and then the unique request ID.
  1003. InsertUInt32(proxyStateValue + 4, request->getRequestID());
  1004. // Allocate a buffer to hold the packet on the wire.
  1005. PBYTE buffer;
  1006. ALLOC_BUFFER_FOR_PACKET(buffer, &packet, &proxyState, sign);
  1007. if (!buffer) { return resultInvalidRequest; }
  1008. // Get the port for this request.
  1009. const RemotePort& port = request->getPort();
  1010. // Pack the buffer.
  1011. // packet.authenticator is used for CHAP when the request authenticator
  1012. // is used for the chap-challenge. It can be null
  1013. PackBuffer(
  1014. port.secret,
  1015. port.secret.length(),
  1016. packet,
  1017. &proxyState,
  1018. sign,
  1019. buffer
  1020. );
  1021. // Save the request authenticator.
  1022. request->setAuthenticator(buffer + 4);
  1023. // Determine the request type.
  1024. bool isAuth = request->isAccReq();
  1025. // Set up the event struct.
  1026. RadiusEvent event =
  1027. {
  1028. (isAuth ? portAuthentication : portAccounting),
  1029. (isAuth ? eventAccessRequest : eventAccountingRequest),
  1030. &request->getServer(),
  1031. port.address.address(),
  1032. port.address.port(),
  1033. buffer,
  1034. packet.length
  1035. };
  1036. // Get the appropriate socket.
  1037. UDPSocket& sock = isAuth ? authSock : acctSock;
  1038. // Insert the pending request before we send it to avoid a race condition.
  1039. pending.insert(*request);
  1040. // The magic moment -- we actually send the request.
  1041. Result result;
  1042. if (sock.send(port.address, buffer, packet.length))
  1043. {
  1044. // Update request state.
  1045. request->onSend();
  1046. // Set a timer to clean up if the server doesn't answer.
  1047. if (timers.setTimer(request, request->getServer().timeout, 0))
  1048. {
  1049. result = resultSuccess;
  1050. }
  1051. else
  1052. {
  1053. // If we can't set at timer we have to remove it from the pending
  1054. // requests table or else it could leak.
  1055. pending.erase(*request);
  1056. result = resultNotEnoughMemory;
  1057. }
  1058. }
  1059. else
  1060. {
  1061. // Update the event with the error data.
  1062. event.eventType = eventSendError;
  1063. event.data = GetLastError();
  1064. // If we received "Port Unreachable" ICMP packet, we'll count this as a
  1065. // timeout since it means the server is unavailable.
  1066. if (event.data == WSAECONNRESET) { request->onTimeout(); }
  1067. // Remove from the pending requests table.
  1068. pending.erase(*request);
  1069. }
  1070. // Report the event ...
  1071. reportEvent(event);
  1072. // ... and the result.
  1073. return result;
  1074. }