|
|
///////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 2000, Microsoft Corp. All rights reserved.
//
// FILE
//
// radprxy.cpp
//
// SYNOPSIS
//
// Defines the reusable RadiusProxy engine. This should have no IAS specific
// dependencies.
//
// MODIFICATION HISTORY
//
// 02/08/2000 Original version.
// 05/30/2000 Eliminate QUESTIONABLE state.
//
///////////////////////////////////////////////////////////////////////////////
#include <proxypch.h>
#include <radproxyp.h>
#include <radproxy.h>
// Avoid dependencies on ntrtl.h
extern "C" ULONG __stdcall RtlRandom(PULONG seed);
// Extract a 32-bit integer stored in network (big-endian) byte order from a
// buffer. Reads exactly four bytes starting at p.
//
// Each byte is widened to ULONG before it is shifted. Shifting the
// int-promoted BYTE directly would push a set high bit into the sign bit of a
// signed int when shifting by 24, which is undefined behavior before C++20.
ULONG ExtractUInt32(const BYTE* p) throw ()
{
   return ((ULONG)p[0] << 24) |
          ((ULONG)p[1] << 16) |
          ((ULONG)p[2] <<  8) |
          ((ULONG)p[3]      );
}
// Store a 32-bit integer into a buffer in network (big-endian) byte order,
// most significant byte first. Writes exactly four bytes starting at p.
void InsertUInt32(BYTE* p, ULONG val) throw ()
{
   p[0] = (BYTE)(val >> 24);
   p[1] = (BYTE)(val >> 16);
   p[2] = (BYTE)(val >>  8);
   p[3] = (BYTE)(val      );
}
//
// Layout of a Microsoft State attribute
//
// struct MicrosoftState
// {
// BYTE checksum[4];
// BYTE vendorID[4];
// BYTE version[2];
// BYTE serverAddress[4];
// BYTE sourceID[4];
// BYTE sessionID[4];
// };
//
// Extracts the creator's address from a State attribute, or returns
// INADDR_NONE if this isn't a valid Microsoft State attribute.
//
// A valid attribute is exactly 22 bytes long, has the bytes
// "\x00\x00\x01\x37\x00\x01" (vendorID followed by version) at offset 4, and
// its leading 4-byte checksum equals IASAdler32 over the 18 bytes that
// follow it. The address is the 4-byte serverAddress field at offset 10.
ULONG ExtractAddressFromState(const RadiusAttribute& state) throw ()
{
   // Wrong size: can't be one of ours.
   if (state.length != 22)
   {
      return INADDR_NONE;
   }

   // vendorID (4 bytes) followed by version (2 bytes).
   if (memcmp(state.value + 4, "\x00\x00\x01\x37\x00\x01", 6) != 0)
   {
      return INADDR_NONE;
   }

   // The checksum covers everything after itself.
   if (IASAdler32(state.value + 4, 18) != ExtractUInt32(state.value))
   {
      return INADDR_NONE;
   }

   return ExtractUInt32(state.value + 10);
}
// Returns true if this is an Accounting-On/Off packet, i.e. its
// Acct-Status-Type attribute holds 7 (Accounting-On) or 8 (Accounting-Off).
//
// The attribute's length is checked before its 4-byte value is read. A
// malformed, shorter attribute would otherwise cause a read past the end of
// its data. Such a packet is treated as not being a NAS state request.
bool IsNasStateRequest(const RadiusPacket& packet) throw ()
{
   const RadiusAttribute* status = FindAttribute(
                                       packet,
                                       RADIUS_ACCT_STATUS_TYPE
                                       );
   if (!status || status->length < 4)
   {
      return false;
   }

   ULONG value = ExtractUInt32(status->value);

   return value == 7 || value == 8;
}
// Constructs a remote endpoint for ipAddress:port. The shared secret is
// initialized from the raw bytes of sharedSecret; the length passed is
// strlen(sharedSecret), so the NUL terminator is not included.
// NOTE(review): nextIdentifier is not initialized here -- confirm its type
// default-initializes, or that any starting value is acceptable.
RemotePort::RemotePort( ULONG ipAddress, USHORT port, PCSTR sharedSecret ) : address(ipAddress, port), secret((const BYTE*)sharedSecret, strlen(sharedSecret)) { }
// Copy constructor: copies the address, the shared secret, and the next
// packet identifier of the source port.
RemotePort::RemotePort(const RemotePort& port) : address(port.address), secret(port.secret), nextIdentifier(port.nextIdentifier) { }
// Builds a server from its configuration. Both ports use config.ipAddress,
// each with its own UDP port and shared secret. config.maxLost is stored as
// maxEvents, the threshold that onTimeout() and onReceive() compare
// eventCount against. A new server starts out usable and not on probation,
// with eventCount and expiry both 0.
RemoteServer::RemoteServer( const RemoteServerConfig& config ) : guid(config.guid), authPort(config.ipAddress, config.authPort, config.authSecret), acctPort(config.ipAddress, config.acctPort, config.acctSecret), timeout(config.timeout), maxEvents((LONG)config.maxLost), blackout(config.blackout), priority(config.priority), weight(config.weight), sendSignature(config.sendSignature), sendAcctOnOff(config.sendAcctOnOff), usable(true), onProbation(false), eventCount(0), expiry(0) { }
bool RemoteServer::shouldBroadcast() throw () { bool broadcastable = false;
if (!onProbation && !usable) { ULONG64 now = GetSystemTime64();
lock.lock();
// Has the blackout interval expired ?
if (now > expiry) { // Yes, so set a new expiration.
expiry = now + blackout * 10000i64;
broadcastable = true; }
lock.unlock(); }
return broadcastable; }
// Updates availability state after a response with the given RADIUS code was
// received from this server. Any code other than Access-Challenge counts as
// a final answer.
//
//  - Unavailable server: it answered a broadcast, so it goes on probation
//    (success count 1 if the answer was final, otherwise 0).
//  - Probationary server: a final answer bumps the success count; reaching
//    maxEvents takes it off probation. Either way it may send again.
//  - Available server: a final answer resets the lost count.
//
// Returns true only when the server has just left probation, i.e. it went
// from unavailable to available.
bool RemoteServer::onReceive(BYTE code) throw ()
{
   const bool isFinal = (code != RADIUS_ACCESS_CHALLENGE);
   bool cameUp = false;

   lock.lock();

   if (!usable && !onProbation)
   {
      // Unavailable server answered a broadcast: start probation.
      usable = true;
      onProbation = true;
      eventCount = isFinal ? 1 : 0;
   }
   else if (onProbation)
   {
      if (isFinal)
      {
         if (++eventCount >= maxEvents)
         {
            // Enough successes: off probation with a lost count of zero.
            onProbation = false;
            eventCount = 0;
            cameUp = true;
         }

         // The outstanding request is done, so another may be sent.
         usable = true;
      }
   }
   else if (isFinal)
   {
      // Usable and not on probation: a final answer resets the lost count.
      eventCount = 0;
   }

   lock.unlock();

   return cameUp;
}
// Called when a request is sent to this server. A probationary server may
// only have one request outstanding, so it is marked unusable until that
// request completes. Servers not on probation are unaffected.
void RemoteServer::onSend() throw ()
{
   // Cheap unlocked check first; it is repeated under the lock below.
   if (!onProbation)
   {
      return;
   }

   lock.lock();

   if (onProbation)
   {
      usable = false;
   }

   lock.unlock();
}
// Updates availability state after a request to this server timed out.
//
//  - Probationary server: sudden death -- straight to unavailable, and a
//    blackout interval starts.
//  - Available server: bumps the lost count; reaching maxEvents makes it
//    unavailable and starts a blackout interval.
//  - Unavailable server: the timeout is ignored.
//
// Returns true only when an available server has just become unavailable.
bool RemoteServer::onTimeout() throw ()
{
   bool wentDown = false;

   lock.lock();

   if (onProbation)
   {
      onProbation = false;
      usable = false;
      expiry = GetSystemTime64() + blackout * 10000ui64;
   }
   else if (usable && ++eventCount >= maxEvents)
   {
      usable = false;
      expiry = GetSystemTime64() + blackout * 10000ui64;
      wentDown = true;
   }

   lock.unlock();

   return wentDown;
}
// Copies runtime state from 'target' into this server: the state of both
// ports plus the availability fields (usable, onProbation, eventCount,
// expiry). Configuration fields are not touched.
// NOTE(review): no lock is taken on either object.
void RemoteServer::copyState(const RemoteServer& target) throw ()
{
   // Ports first ...
   authPort.copyState(target.authPort);
   acctPort.copyState(target.acctPort);

   // ... then availability.
   expiry      = target.expiry;
   eventCount  = target.eventCount;
   onProbation = target.onProbation;
   usable      = target.usable;
}
// Returns true if the two servers have the same configuration. In that case
// ServerGroupManager::setServerGroups keeps the existing server object
// instead of the newly configured one.
//
// The configured lost-request threshold (maxEvents, from config.maxLost) is
// compared rather than the runtime eventCount. eventCount is state, not
// configuration. Comparing it also meant a reconfiguration that changed only
// maxLost matched the old server and the new threshold was silently dropped.
bool RemoteServer::operator==(const RemoteServer& s) const throw ()
{
   return authPort == s.authPort &&
          acctPort == s.acctPort &&
          priority == s.priority &&
          weight == s.weight &&
          timeout == s.timeout &&
          maxEvents == s.maxEvents &&
          blackout == s.blackout &&
          sendSignature == s.sendSignature &&
          sendAcctOnOff == s.sendAcctOnOff;
}
//////////
// Used for sorting servers by priority.
//////////
// qsort-style comparator: orders servers by ascending priority value.
//
// Values are compared rather than subtracted. Priorities are handled as
// ULONG elsewhere in this file, and the difference of two large values cast
// to int can overflow (undefined behavior) or report the wrong order.
int __cdecl sortServersByPriority( const RemoteServer* const* server1, const RemoteServer* const* server2 ) throw ()
{
   const ULONG p1 = (*server1)->priority;
   const ULONG p2 = (*server2)->priority;

   if (p1 < p2) { return -1; }
   if (p1 > p2) { return 1; }
   return 0;
}
// Seed for RtlRandom, shared by every ServerGroup. As a static it starts at
// zero; the first ServerGroup constructor that sees zero seeds it from the
// system time.
// NOTE(review): pickServer() updates it through RtlRandom without any lock.
// If picks can run concurrently they race on it; confirm this is acceptable.
ULONG ServerGroup::theSeed;
// Creates a named group from the servers in [first, last).
// Throws (via _com_issue_error(E_INVALIDARG)) if the range is empty.
//
// Servers are sorted by ascending priority value. The constructor also
// records:
//  - endTopPriority: the end of the leading run of servers that share the
//    top priority (used for a forced pick), and
//  - maxCandidatesSize: bytes needed for the largest number of servers
//    sharing one priority (used to size the candidate buffer in
//    getServersForRequest).
ServerGroup::ServerGroup( PCWSTR groupName, RemoteServer* const* first, RemoteServer* const* last ) : servers(first, last), name(groupName)
{
   // Empty groups are rejected.
   if (servers.empty())
   {
      _com_issue_error(E_INVALIDARG);
   }

   // Seed the shared random number generator once, from the current time.
   if (theSeed == 0)
   {
      FILETIME now;
      GetSystemTimeAsFileTime(&now);
      theSeed = now.dwHighDateTime | now.dwLowDateTime;
   }

   servers.sort(sortServersByPriority);

   // Leading run of top-priority servers.
   const ULONG best = (*servers.begin())->priority;
   endTopPriority = servers.begin();
   while (endTopPriority != servers.end() && (*endTopPriority)->priority == best)
   {
      ++endTopPriority;
   }

   // Longest run of servers sharing a priority (the list is sorted, so equal
   // priorities are adjacent).
   ULONG largestRun = 0;
   ULONG runLength = 0;
   ULONG runPriority = best;
   for (RemoteServer* const* s = begin(); s != end(); ++s)
   {
      if ((*s)->priority != runPriority)
      {
         runPriority = (*s)->priority;
         runLength = 0;
      }

      ++runLength;
      if (runLength > largestRun)
      {
         largestRun = runLength;
      }
   }

   maxCandidatesSize = largestRun * sizeof(RemoteServer*);
}
// Picks one server at random from [first, last), with probability
// proportional to each server's weight. 'avoid' is excluded from the draw
// and is never chosen unless it is the only server in the range. If the
// servers other than 'avoid' have a combined weight of zero, the first of
// them is returned.
RemoteServer* ServerGroup::pickServer( RemoteServers::iterator first, RemoteServers::iterator last, const RemoteServer* avoid ) throw ()
{
   // If the list has exactly one entry, there's nothing to do.
   if (last == first + 1) { return *first; }

   RemoteServer* const* i;

   // Compute the combined weight of all the servers we may pick.
   ULONG weight = 0;
   for (i = first; i != last; ++i)
   {
      if (*i != avoid) { weight += (*i)->weight; }
   }

   // We don't test the last server since if we make it that far we have to
   // use it anyway.
   --last;

   if (weight == 0)
   {
      // Nothing to weight by, so take the first server we aren't avoiding.
      for (i = first; i != last && *i == avoid; ++i) { }
      return *i;
   }

   // Pick a random number from [0, weight). RtlRandom returns values below
   // 2^31, so multiplying by weight and shifting right by 31 stays in range.
   ULONG offset = (ULONG)(((ULONG64)RtlRandom(&theSeed) * (ULONG64)weight) >> 31);

   // Each candidate owns a slice of [0, weight) as wide as its weight; stop at
   // the one whose slice contains the offset. The test must be strict. With
   // '>=', an offset equal to a server's weight was credited to that server
   // instead of the next, biasing the pick toward earlier servers (with two
   // weight-1 servers the second was never chosen).
   for (i = first; i != last; ++i)
   {
      if (*i != avoid)
      {
         if (offset < (*i)->weight) { break; }

         offset -= (*i)->weight;
      }
   }

   return *i;
}
// Pushes onto 'result' the Requests to send for one incoming packet. Servers
// are visited in sorted order (ascending priority value; see the
// constructor).
//
//  - Pass 1 collects the usable servers at the first priority level that has
//    one. Every unusable server whose priority value does not exceed that
//    level and whose shouldBroadcast() returns true also gets a Request
//    pushed immediately.
//  - Pass 2 runs only if pass 1 found no usable server. It collects the
//    in-progress servers at the first priority level that has one.
//  - If any candidates were collected, pickServer() chooses one of them.
//    Otherwise, if nothing was queued for broadcast, a pick is forced from
//    the top-priority servers [begin, endTopPriority).
// 'avoid' is passed through to pickServer().
//
// Candidates always share a single priority level, so the _alloca'd buffer
// of maxCandidatesSize bytes (computed in the constructor) is large enough.
// 'new Request' may throw; forwardRequest() catches std::bad_alloc.
void ServerGroup::getServersForRequest( ProxyContext* context, BYTE packetCode, const RemoteServer* avoid, RequestStack& result ) const { // List of candidates.
RemoteServer** first = (RemoteServer**)_alloca(maxCandidatesSize); RemoteServer** last = first;
// Iterate through the servers.
ULONG maxPriority = (ULONG)-1; for (RemoteServer* const* i = servers.begin(); i != servers.end(); ++i) { // If this test fails, we must have found a higher priority server that's
// usable.
if ((*i)->priority > maxPriority) { break; }
if ((*i)->isUsable()) { // Don't consider lower priority servers.
maxPriority = (*i)->priority;
// Add this to the list of candidates.
*last++ = *i; } else if ((*i)->shouldBroadcast()) { // It's not available, but it's ready for a broadcast
result.push(new Request(context, *i, packetCode)); } }
if (first == last) { // No usable servers, so look for in progress servers.
maxPriority = (ULONG)-1; for (RemoteServer* const* i = servers.begin(); i != servers.end(); ++i) { // If this test fails, we must have found a higher priority server
// that's in progress.
if ((*i)->priority > maxPriority) { break; }
if ((*i)->isInProgress()) { // Don't consider lower priority servers.
maxPriority = (*i)->priority;
// Add this to the list of candidates.
*last++ = *i; } } }
if (first != last) { // We have at least one candidate, so pick one and add it to the list.
result.push(new Request( context, pickServer(first, last, avoid), packetCode )); } else if (result.empty() && !servers.empty()) { // We have no candidates and no servers available for broadcast, so just
// force a pick from the top priority servers.
result.push(new Request( context, pickServer(servers.begin(), endTopPriority, avoid), packetCode )); } }
//////////
// Used for sorting and searching groups by name.
//////////
// qsort-style comparator: orders server groups by name (wcscmp order).
int __cdecl sortGroupsByName( const ServerGroup* const* group1, const ServerGroup* const* group2 ) throw ()
{
   const ServerGroup* lhs = *group1;
   const ServerGroup* rhs = *group2;
   return wcscmp(lhs->getName(), rhs->getName());
}

// bsearch-style comparator: compares a wide-string group name (key) with a
// group's name.
int __cdecl findGroupByName( const void* key, const ServerGroup* const* group ) throw ()
{
   PCWSTR wanted = (PCWSTR)key;
   return wcscmp(wanted, (*group)->getName());
}
//////////
// Used for sorting and searching servers by address.
//////////
// qsort-style comparator: orders servers by ascending getAddress().
int __cdecl sortServersByAddress( const RemoteServer* const* server1, const RemoteServer* const* server2 )
{
   const RemoteServer* lhs = *server1;
   const RemoteServer* rhs = *server2;

   if (lhs->getAddress() < rhs->getAddress())
   {
      return -1;
   }
   return (lhs->getAddress() > rhs->getAddress()) ? 1 : 0;
}

// bsearch-style comparator: the key is an address packed into a pointer
// (see ServerGroupManager::findServer), compared with a server's getAddress().
int __cdecl findServerByAddress( const void* key, const RemoteServer* const* server ) throw ()
{
   const ULONG_PTR wanted = (ULONG_PTR)key;

   if (wanted < (*server)->getAddress())
   {
      return -1;
   }
   return (wanted > (*server)->getAddress()) ? 1 : 0;
}
//////////
// Used for sorting and searching servers by guid.
//////////
// qsort-style comparator: orders servers by the raw bytes of their GUIDs.
int __cdecl sortServersByGUID( const RemoteServer* const* server1, const RemoteServer* const* server2 ) throw ()
{
   const RemoteServer* lhs = *server1;
   const RemoteServer* rhs = *server2;
   return memcmp(&lhs->guid, &rhs->guid, sizeof(GUID));
}

// bsearch-style comparator: the key points at a GUID, compared byte-wise
// with a server's GUID.
int __cdecl findServerByGUID( const void* key, const RemoteServer* const* server ) throw ()
{
   const RemoteServer* candidate = *server;
   return memcmp(key, &candidate->guid, sizeof(GUID));
}
//////////
// Used for sorting accounting servers by port.
//////////
// qsort-style comparator for accounting servers: orders by accounting UDP
// port, then by IP address, both compared as raw network-order bytes.
// sin_port and sin_addr are adjacent in sockaddr_in, so this is the same
// ordering as one memcmp over the 6 bytes starting at sin_port.
int __cdecl sortServersByAcctPort( const RemoteServer* const* server1, const RemoteServer* const* server2 )
{
   const sockaddr_in& lhs = (*server1)->acctPort.address;
   const sockaddr_in& rhs = (*server2)->acctPort.address;

   int order = memcmp(&lhs.sin_port, &rhs.sin_port, sizeof(lhs.sin_port));
   if (order == 0)
   {
      order = memcmp(&lhs.sin_addr, &rhs.sin_addr, sizeof(lhs.sin_addr));
   }
   return order;
}
// Replaces the configured server groups with [first, last).
//
// For each server in the new groups, a currently configured server with the
// same GUID is looked up in byGuid. If operator== reports the same
// configuration, the existing server object goes into the new lookup
// tables. Otherwise the new server inherits the existing one's runtime state
// via copyState().
//
// The new tables are built outside the lock: groups sorted by name, servers
// sorted by address and by GUID, and the servers with sendAcctOnOff set.
// They are then swapped in under the exclusive lock. The previous tables end
// up in the local containers and are destroyed when this function returns.
// Returns false if std::bad_alloc is thrown while building.
//
// NOTE(review): on an exact match only the lookup tables receive 'existing';
// the group itself still references *j (the new object), so requests routed
// through the group use the fresh object's state. Confirm this is intended.
// NOTE(review): byGuid is read below without taking monitor. That is safe
// only if setServerGroups is never called concurrently with itself.
bool ServerGroupManager::setServerGroups( ServerGroup* const* first, ServerGroup* const* last ) throw () { bool success;
try { // Save the new server groups ...
ServerGroups newGroups(first, last);
// Sort by name.
newGroups.sort(sortGroupsByName);
// Useful iterators.
ServerGroups::iterator i; RemoteServers::iterator j;
// Count the number of servers and accounting servers.
ULONG count = 0, acctCount = 0; for (i = first; i != last; ++i) { for (j = (*i)->begin(); j != (*i)->end(); ++j) { ++count;
if ((*j)->sendAcctOnOff) { ++acctCount; } } }
// Reserve space for the servers.
RemoteServers newServers(count); RemoteServers newAcctServers(acctCount);
// Populate the servers.
for (i = first; i != last; ++i) { for (j = (*i)->begin(); j != (*i)->end(); ++j) { RemoteServer* newServer = *j;
// Does this server already exist?
RemoteServer* existing = byGuid.search( (const void*)&newServer->guid, findServerByGUID ); if (existing) { if (*existing == *newServer) { // If it's an exact match, use the existing server.
newServer = existing; } else { // Otherwise, copy the state of the existing server.
newServer->copyState(*existing); } }
newServers.push_back(newServer);
if (newServer->sendAcctOnOff) { newAcctServers.push_back(newServer); } } }
// Sort the servers by address ...
newServers.sort(sortServersByAddress);
// ... and GUID.
RemoteServers newServersByGuid(newServers); newServersByGuid.sort(sortServersByGUID);
// Everything is ready so now we grab the write lock ...
monitor.LockExclusive();
// ... and swap in the collections.
groups.swap(newGroups); byAddress.swap(newServers); byGuid.swap(newServersByGuid); acctServers.swap(newAcctServers);
monitor.Unlock();
success = true; } catch (const std::bad_alloc&) { success = false; }
return success; }
// Returns the server configured for 'address' (onReceive passes the source
// s_addr of a received datagram), or NULL if there is none.
//
// The lookup result is converted to RemoteServerPtr while the lock is still
// held. If RemoteServerPtr takes a reference on the server (TODO confirm),
// converting only after Unlock() leaves a window. In that window
// setServerGroups() can swap out byAddress and release the server before the
// reference is taken.
RemoteServerPtr ServerGroupManager::findServer( ULONG address ) const throw ()
{
   monitor.Lock();

   RemoteServerPtr server = byAddress.search(
                                (const void*)ULongToPtr(address),
                                findServerByAddress
                                );

   monitor.Unlock();

   return server;
}
// Pushes onto 'result' the Requests for the server group named 'name' (see
// ServerGroup::getServersForRequest). Pushes nothing if no group has that
// name.
//
// getServersForRequest() allocates Requests and can throw; forwardRequest()
// catches std::bad_alloc. The lock taken by monitor.Lock() is released
// before such an exception propagates. Otherwise it would stay held forever
// and the next LockExclusive() in setServerGroups() would never return.
// NOTE(review): the throw () specification conflicts with this propagation.
// Under legacy MSVC semantics it is not enforced; under standard semantics
// the exception terminates, exactly as it did before this change.
void ServerGroupManager::getServersByGroup( ProxyContext* context, BYTE packetCode, PCWSTR name, const RemoteServer* avoid, RequestStack& result ) const throw ()
{
   monitor.Lock();

   try
   {
      ServerGroup* group = groups.search(name, findGroupByName);

      if (group)
      {
         group->getServersForRequest(context, packetCode, avoid, result);
      }
   }
   catch (...)
   {
      monitor.Unlock();
      throw;
   }

   monitor.Unlock();
}
// Pushes onto 'result' one accounting Request for each server in
// acctServers, the servers that setServerGroups() found with sendAcctOnOff
// set.
//
// 'new Request' can throw; forwardRequest() catches std::bad_alloc. The lock
// taken by monitor.Lock() is released before such an exception propagates,
// so a later LockExclusive() in setServerGroups() cannot block forever.
void ServerGroupManager::getServersForAcctOnOff( ProxyContext* context, RequestStack& result ) const
{
   monitor.Lock();

   try
   {
      for (RemoteServer* const* i = acctServers.begin(); i != acctServers.end(); ++i)
      {
         result.push(new Request(context, *i, RADIUS_ACCOUNTING_REQUEST));
      }
   }
   catch (...)
   {
      monitor.Unlock();
      throw;
   }

   monitor.Unlock();
}
// The active engine: set by the RadiusProxyEngine constructor and reset to
// NULL by its destructor. onRequestAbandoned() and onRequestTimeout() reach
// the engine through this pointer rather than through 'this'.
RadiusProxyEngine* RadiusProxyEngine::theProxy;
// Constructs the engine, which reports results and events to 'source', and
// registers it as theProxy.
// The 'sessions' (server affinity by State) and 'avoid' (server avoidance by
// User-Name) tables are built with ServerBinding::hash; their fourth
// arguments are 2 and 35 minutes expressed in milliseconds.
// NOTE(review): the meaning of the other table arguments (1, 10000,
// true/false) is not visible here; confirm against the table class.
// proxyAddress is set to the first address IASGetHostByName(NULL) returns;
// sendRequest() and onReceive() use it in the Proxy-State attribute.
RadiusProxyEngine::RadiusProxyEngine(RadiusProxyClient* source) throw () : client(source), proxyAddress(INADDR_NONE), pending(Request::hash, 1), sessions(ServerBinding::hash, 1, 10000, (2 * 60 * 1000), true), avoid(ServerBinding::hash, 1, 10000, (35 * 60 * 1000), false), crypto(0) { theProxy = this;
// We don't care if this fails. The proxy will just use INADDR_NONE in its
// proxy-state attribute.
PHOSTENT he = IASGetHostByName(NULL); if (he) { if (he->h_addr_list[0]) { proxyAddress = *(PULONG)he->h_addr_list[0]; }
LocalFree(he); } }
// Shuts the engine down in this order:
//  1. close both sockets,
//  2. clear the pending-request table,
//  3. cancel all timers,
//  4. release the crypto context if one was acquired,
//  5. reset theProxy.
RadiusProxyEngine::~RadiusProxyEngine() throw () { // Block any new responses.
authSock.close(); acctSock.close();
// Clear the pending request table.
pending.clear();
// Cancel all the timers.
timers.cancelAllTimers();
// At this point all our threads should be done, but let's just make sure.
// NOTE(review): SwitchToThread() only yields the rest of this thread's time
// slice; it does not wait for other threads to finish.
SwitchToThread();
if (crypto != 0) { CryptReleaseContext(crypto, 0); }
theProxy = NULL; }
// Second-phase initialization: acquires a verify-only PROV_RSA_FULL crypto
// context. sendRequest() uses it with CryptGenRandom to generate request
// authenticators. Returns S_OK, or the Win32 error as an HRESULT.
HRESULT RadiusProxyEngine::finalConstruct() throw ()
{
   if (CryptAcquireContext(&crypto, 0, 0, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT))
   {
      return S_OK;
   }

   const DWORD lastError = GetLastError();
   return HRESULT_FROM_WIN32(lastError);
}
// Installs a new set of server groups. When at least one group is supplied,
// the authentication and accounting sockets are opened first if they aren't
// open already; failure to open either returns false without changing the
// groups. Otherwise returns the result of groups.setServerGroups().
bool RadiusProxyEngine::setServerGroups( ServerGroup* const* begin, ServerGroup* const* end ) throw ()
{
   // Sockets are only opened once there is something to proxy to.
   if (begin != end)
   {
      const bool authReady = authSock.isOpen() || authSock.open(this, portAuthentication);
      if (!authReady)
      {
         return false;
      }

      const bool acctReady = acctSock.isOpen() || acctSock.open(this, portAccounting);
      if (!acctReady)
      {
         return false;
      }
   }

   return groups.setServerGroups(begin, end);
}
// Forwards a client request to remote RADIUS servers.
//
//   context              - opaque client context. It is handed back through
//                          client->onComplete(), guarded by the ProxyContext
//                          ownership handoff (takeOwnership()).
//   serverGroup          - name of the server group to proxy to.
//   code                 - RADIUS code of the request.
//   requestAuthenticator - request authenticator for Access-Requests; may be
//                          NULL (sendRequest() then generates one).
//   begin, end           - the request's attributes.
//
// Routing:
//  - Access-Request: the server bound to the packet's State attribute if
//    there is one. Otherwise a server from the group, avoiding the server
//    recorded for the packet's User-Name.
//  - Accounting-On/Off: every server in the manager's acctServers table. The
//    client is completed with resultSuccess immediately.
//  - Other accounting requests: a server from the group.
//  - Anything else: resultInvalidRequest.
//
// The first request popped is the primary. If it fails to send, the first
// broadcast request that sends successfully becomes primary. If nothing was
// sent successfully, onComplete() is called here with retval.
// NOTE(review): packet.authenticator is only assigned for Access-Requests;
// confirm RadiusPacket's construction initializes it for other codes.
void RadiusProxyEngine::forwardRequest( PVOID context, PCWSTR serverGroup, BYTE code, const BYTE* requestAuthenticator, const RadiusAttribute* begin, const RadiusAttribute* end ) throw () { // Save the request context. We have to handle this carefully since we rely
// on the ProxyContext object to ensure that onComplete gets called exactly
// once. If we can't allocate the object, we have to handle it specially.
ProxyContextPtr ctxt(new (std::nothrow) ProxyContext(context)); if (!ctxt) { client->onComplete( resultNotEnoughMemory, context, NULL, code, NULL, NULL ); return; }
Result retval = resultUnknownServerGroup;
try { // Store the in parameters in a RadiusPacket struct.
RadiusPacket packet; packet.code = code; packet.begin = const_cast<RadiusAttribute*>(begin); packet.end = const_cast<RadiusAttribute*>(end);
// Generate the list of RADIUS requests to be sent.
RequestStack requests; switch (code) { case RADIUS_ACCESS_REQUEST: { // Is this request associated with a particular server?
RemoteServerPtr server = getServerAffinity(packet); if (server) { requests.push(new Request(ctxt, server, RADIUS_ACCESS_REQUEST)); } else { server = getServerAvoidance(packet);
groups.getServersByGroup( ctxt, code, serverGroup, server, requests ); }
// Put request authenticator in the packet. The request
// authenticator can be NULL, in which case sendRequest() generates one
// for the outgoing packet.
packet.authenticator = requestAuthenticator; break; }
case RADIUS_ACCOUNTING_REQUEST: { if (!IsNasStateRequest(packet)) { groups.getServersByGroup( ctxt, code, serverGroup, 0, requests ); } else { groups.getServersForAcctOnOff( ctxt, requests );
// NAS State requests are always reported as a success since we
// don't care if it gets to all the destinations. retval may be
// overwritten by sendRequest() below; the later takeOwnership() then
// yields NULL (ownership was already taken here), so onComplete() is
// not called a second time.
context = ctxt->takeOwnership(); if (context) { client->onComplete( resultSuccess, context, NULL, RADIUS_ACCOUNTING_RESPONSE, NULL, NULL ); } retval = resultSuccess; } break; }
default: { retval = resultInvalidRequest; } }
if (!requests.empty()) { // First we handle the primary.
RequestPtr request = requests.pop(); ctxt->setPrimaryServer(&request->getServer()); retval = sendRequest(packet, request);
// Now we broadcast.
while (!requests.empty()) { request = requests.pop(); Result result = sendRequest(packet, request); if (result == resultSuccess && retval != resultSuccess) { // This was the first request to succeed so mark it as primary.
retval = resultSuccess; ctxt->setPrimaryServer(&request->getServer()); } } } } catch (const std::bad_alloc&) { retval = resultNotEnoughMemory; }
if (retval != resultSuccess) { // If we made it here, then we didn't successfully send a packet to any
// server, so we have to report the result ourself.
context = ctxt->takeOwnership(); if (context) { client->onComplete( retval, context, ctxt->getPrimaryServer(), code, NULL, NULL ); } } }
// Completes a request that nobody took responsibility for by reporting it to
// the client as timed out. Reaches the engine through theProxy.
void RadiusProxyEngine::onRequestAbandoned( PVOID context, RemoteServer* server ) throw ()
{
   RadiusProxyEngine* const engine = theProxy;

   engine->client->onComplete(
                      resultRequestTimeout,
                      context,
                      server,
                      0,
                      NULL,
                      NULL
                      );
}
// Forwards an event to the client unchanged.
inline void RadiusProxyEngine::reportEvent( const RadiusEvent& event ) const throw ()
{
   client->onEvent(event);
}

// Stamps 'type' into the event, then forwards it to the client. The event is
// modified in place, so the caller sees the new eventType afterwards.
inline void RadiusProxyEngine::reportEvent( RadiusEvent& event, RadiusEventType type ) const throw ()
{
   event.eventType = type;
   reportEvent(event);
}
// Handles a request that received no response in time (presumably invoked
// by the timer set in sendRequest()). Reaches the engine through theProxy.
//
// If the request is still in the pending table, this function:
//  - removes it from the table,
//  - records its server as one to avoid for the request's User-Name,
//  - reports eventTimeout, and
//  - reports eventServerUnavailable if request->onTimeout() says the server
//    has just become unavailable.
void RadiusProxyEngine::onRequestTimeout( Request* request ) throw () { // Erase the pending request. If it's not there, that's okay; it means that
// we received a response, but weren't able to cancel the timer in time.
if (theProxy->pending.erase(request->getRequestID())) { // Avoid this server next time.
theProxy->setServerAvoidance(*request);
RadiusEvent event = { request->getPortType(), eventTimeout, &request->getServer(), request->getPort().address.address(), request->getPort().address.port() };
// Report the protocol event.
theProxy->reportEvent(event);
// Update request state.
if (request->onTimeout()) { // The server was just marked unavailable, so notify the client.
theProxy->reportEvent(event, eventServerUnavailable); } } }
// Returns the server bound (by setServerAffinity) to the value of the
// packet's State attribute, or NULL if there is no State attribute or no
// binding for it.
RemoteServerPtr RadiusProxyEngine::getServerAffinity( const RadiusPacket& packet ) throw ()
{
   const RadiusAttribute* state = FindAttribute(packet, RADIUS_STATE);
   if (state == NULL)
   {
      return NULL;
   }

   RadiusRawOctets key = { state->value, state->length };
   ServerBindingPtr binding = sessions.find(key);
   if (!binding)
   {
      return NULL;
   }

   return &binding->getServer();
}
void RadiusProxyEngine::setServerAffinity( const RadiusPacket& packet, RemoteServer& server ) throw () { // Is this an Access-Challenge ?
if (packet.code != RADIUS_ACCESS_CHALLENGE) { return; }
// Find the State attribute.
const RadiusAttribute* state = FindAttribute(packet, RADIUS_STATE); if (!state) { return; }
// Do we already have an entry for this State value.
RadiusRawOctets key = { state->value, state->length }; ServerBindingPtr session = sessions.find(key); if (session) { // Make sure the server matches.
session->setServer(server); return; }
// Otherwise, we'll have to create a new one.
try { session = new ServerBinding(key, server); sessions.insert(*session); } catch (const std::bad_alloc&) { // We don't care if this fails.
} }
void RadiusProxyEngine::clearServerAvoidance( const RadiusPacket& packet, RemoteServer& server ) throw () { // Is this packet authoritative?
if ((packet.code == RADIUS_ACCESS_ACCEPT) || (packet.code == RADIUS_ACCESS_REJECT)) { // Find the User-Name attribute.
const RadiusAttribute* attr = FindAttribute(packet, RADIUS_USER_NAME); if (attr != 0) { // Map it to a server.
RadiusRawOctets key = { attr->value, attr->length }; ServerBindingPtr avoidance = avoid.find(key); if (avoidance && (avoidance->getServer() == server)) { avoid.erase(key); } } } }
// Returns the server recorded (by setServerAvoidance) as one to avoid for
// the packet's User-Name, or NULL if there is no User-Name or no entry.
RemoteServerPtr RadiusProxyEngine::getServerAvoidance( const RadiusPacket& packet ) throw ()
{
   const RadiusAttribute* userName = FindAttribute(packet, RADIUS_USER_NAME);
   if (userName == NULL)
   {
      return NULL;
   }

   RadiusRawOctets key = { userName->value, userName->length };
   ServerBindingPtr binding = avoid.find(key);
   if (!binding)
   {
      return NULL;
   }

   return &binding->getServer();
}
void RadiusProxyEngine::setServerAvoidance(const Request& request) throw () { if ((request.getCode() != RADIUS_ACCESS_REQUEST) || (request.getUserName().len == 0)) { return; }
// Do we already have an entry for this User-Name value.
ServerBindingPtr avoidance = avoid.find(request.getUserName()); if (avoidance) { // Make sure the server matches.
avoidance->setServer(request.getServer()); return; }
// Otherwise, we'll have to create a new one.
try { avoidance = new ServerBinding( request.getUserName(), request.getServer() ); avoid.insert(*avoidance); } catch (const std::bad_alloc&) { // We don't care if this fails.
} }
// Processes a datagram received on the authentication or accounting socket.
// 'key' is the socket's RadiusPortType.
//
// The response is validated step by step; each failure reports an event and
// returns:
//   1. unknown source address           -> eventInvalidAddress
//   2. empty or wrong type for the port -> eventUnknownType
//   3. malformed packet                 -> eventMalformedPacket
//   4. missing or foreign Proxy-State   -> eventUnexpectedResponse
//   5. unknown request ID               -> eventLateResponse
//   6. wrong source or identifier       -> eventUnexpectedResponse
//   7. failed authentication            -> eventBadAuthenticator,
//                                          eventBadSignature,
//                                          eventMissingSignature
// A valid response type is reported as an event before steps 3-7 run.
//
// On success, the function:
//  - removes the request from the pending table,
//  - updates the server's state (eventServerAvailable if it just came up),
//  - reports the round-trip time,
//  - updates server affinity and avoidance, and
//  - completes the client's request with the response attributes, with our
//    Proxy-State removed.
void RadiusProxyEngine::onReceive( UDPSocket& socket, ULONG_PTR key, const SOCKADDR_IN& remoteAddress, BYTE* buffer, ULONG bufferLength ) throw () { //////////
// Set up the event struct. We'll fill in the other fields as we go along.
//////////
RadiusEvent event = { (RadiusPortType)key, eventNone, NULL, remoteAddress.sin_addr.s_addr, remoteAddress.sin_port, buffer, bufferLength, 0 };
//////////
// Validate the remote address.
//////////
RemoteServerPtr server = groups.findServer( remoteAddress.sin_addr.s_addr ); if (!server) { reportEvent(event, eventInvalidAddress); return; }
// Use the server as the event context.
event.context = server;
//////////
// Validate the packet type.
//////////
if (bufferLength == 0) { reportEvent(event, eventUnknownType); return; }
// Only response codes that match the port they arrived on are accepted.
switch (MAKELONG(key, buffer[0])) { case MAKELONG(portAuthentication, RADIUS_ACCESS_ACCEPT): reportEvent(event, eventAccessAccept); break;
case MAKELONG(portAuthentication, RADIUS_ACCESS_REJECT): reportEvent(event, eventAccessReject); break;
case MAKELONG(portAuthentication, RADIUS_ACCESS_CHALLENGE): reportEvent(event, eventAccessChallenge); break;
case MAKELONG(portAccounting, RADIUS_ACCOUNTING_RESPONSE): reportEvent(event, eventAccountingResponse); break;
default: reportEvent(event, eventUnknownType); return; }
//////////
// Validate that the packet is properly formatted.
//////////
RadiusPacket* packet; ALLOC_PACKET_FOR_BUFFER(packet, buffer, bufferLength); if (!packet) { reportEvent(event, eventMalformedPacket); return; }
// Unpack the attributes.
UnpackBuffer(buffer, bufferLength, *packet);
//////////
// Validate that we were expecting this response.
//////////
// Look for our Proxy-State attribute.
// NOTE(review): if FindAttribute returns the first Proxy-State and the
// forwarded request already carried an upstream Proxy-State ahead of ours,
// this would find the upstream one; confirm attribute ordering in PackBuffer.
RadiusAttribute* proxyState = FindAttribute( *packet, RADIUS_PROXY_STATE );
// If we didn't find it OR it's the wrong length OR it doesn't start with
// our address, then we weren't expecting this packet.
if (!proxyState || proxyState->length != 8 || memcmp(proxyState->value, &proxyAddress, 4)) { reportEvent(event, eventUnexpectedResponse); return; }
// Extract the request ID.
ULONG requestID = ExtractUInt32(proxyState->value + 4);
// Don't send the Proxy-State back to our client: close the gap in the
// attribute array and shrink it by one.
--packet->end; memmove( proxyState, proxyState + 1, (packet->end - proxyState) * sizeof(RadiusAttribute) );
// Look up the request object. We don't remove it yet because we don't know
// if this is an authentic response.
RequestPtr request = pending.find(requestID); if (!request) { // If it's not there, we'll assume that this is a packet that's
// already been reported as a timeout.
reportEvent(event, eventLateResponse); return; }
// Get the actual server we used for the request in case there are multiple
// servers defined for the same IP address.
event.context = server = &request->getServer();
const RemotePort& port = request->getPort();
// Validate the packet source && identifier.
if (!(port.address == remoteAddress) || request->getIdentifier() != packet->identifier) { reportEvent(event, eventUnexpectedResponse); return; }
//////////
// Validate that the packet is authentic.
//////////
AuthResult authResult = AuthenticateAndDecrypt( request->getAuthenticator(), port.secret, port.secret.length(), buffer, bufferLength, *packet ); switch (authResult) { case AUTH_BAD_AUTHENTICATOR: reportEvent(event, eventBadAuthenticator); return;
case AUTH_BAD_SIGNATURE: reportEvent(event, eventBadSignature); return;
case AUTH_MISSING_SIGNATURE: reportEvent(event, eventMissingSignature); return; }
//////////
// At this point, all the tests have passed -- we have the real thing.
//////////
if (!pending.erase(requestID)) { // It must have timed out while we were authenticating it.
reportEvent(event, eventLateResponse); return; }
// Update endpoint state.
if (request->onReceive(packet->code)) { // The server just came up, so notify the client.
reportEvent(event, eventServerAvailable); }
// Report the round-trip time.
event.data = request->getRoundTripTime(); reportEvent(event, eventRoundTrip);
// Set the server affinity and clear the server avoidance.
setServerAffinity(*packet, *server); clearServerAvoidance(*packet, *server);
// Take ownership of the context.
PVOID context = request->getContext().takeOwnership(); if (context) { // The magic moment -- we have successfully processed the response.
client->onComplete( resultSuccess, context, &request->getServer(), packet->code, packet->begin, packet->end ); } }
// Reports a receive failure on one of the sockets as eventReceiveError. The
// event carries the socket's local address and port, no packet data, and the
// error code.
void RadiusProxyEngine::onReceiveError( UDPSocket& socket, ULONG_PTR key, ULONG errorCode ) throw ()
{
   RadiusEvent event =
   {
      (RadiusPortType)key,
      eventReceiveError,
      NULL,
      socket.getLocalAddress().address(),
      socket.getLocalAddress().port(),
      NULL,
      0,
      errorCode
   };

   reportEvent(event);
}
// Sends 'packet' on behalf of 'request' and reports the send as an event.
//
// The packet gets the request's identifier and our Proxy-State attribute (our
// IP address followed by the request ID), and is signed if the server's
// sendSignature is set. For Access-Requests without an authenticator, a
// random one is generated. The request is inserted into the pending table
// before sending, and a timeout timer is armed after a successful send.
//
// Returns:
//   resultSuccess         - sent and timer armed.
//   resultInvalidRequest  - the wire buffer could not be allocated.
//   resultCryptoError     - CryptGenRandom failed.
//   resultNotEnoughMemory - sent, but the timer could not be armed (the
//                           request is removed from the pending table).
//   resultRequestTimeout  - the send itself failed.
//
// 'packet' is restored on return: if an authenticator was generated here,
// packet.authenticator is reset to 0. forwardRequest() sends the same packet
// to several servers. Leaving it pointed at this function's local array gave
// the next call a dangling pointer, and that call then skipped generating a
// fresh authenticator.
RadiusProxyEngine::Result RadiusProxyEngine::sendRequest( RadiusPacket& packet, Request* request ) throw ()
{
   // Fill in the packet identifier.
   packet.identifier = request->getIdentifier();

   // Should the packet carry a signature?
   BOOL sign = request->getServer().sendSignature;

   // Format the Proxy-State attribute: first our IP address ...
   BYTE proxyStateValue[8];
   RadiusAttribute proxyState = { RADIUS_PROXY_STATE, 8, proxyStateValue };
   memcpy(proxyStateValue, &proxyAddress, 4);
   // ... and then the unique request ID.
   InsertUInt32(proxyStateValue + 4, request->getRequestID());

   // Allocate a buffer to hold the packet on the wire.
   PBYTE buffer;
   ALLOC_BUFFER_FOR_PACKET(buffer, &packet, &proxyState, sign);
   if (!buffer) { return resultInvalidRequest; }

   // Get the port for this request.
   const RemotePort& port = request->getPort();

   // Generate the request authenticator if necessary.
   BYTE requestAuthenticator[16];
   bool generatedAuthenticator = false;
   if ((packet.code == RADIUS_ACCESS_REQUEST) && (packet.authenticator == 0))
   {
      if (!CryptGenRandom(
              crypto,
              sizeof(requestAuthenticator),
              requestAuthenticator
              ))
      {
         return resultCryptoError;
      }

      packet.authenticator = requestAuthenticator;
      generatedAuthenticator = true;
   }

   // Pack the buffer. packet.authenticator is used for CHAP when the request
   // authenticator is used for the chap-challenge. It can be null.
   PackBuffer(
       port.secret,
       port.secret.length(),
       packet,
       &proxyState,
       sign,
       buffer
       );

   // Save the request authenticator and packet.
   request->setAuthenticator(buffer + 4);
   request->setPacket(packet);

   // requestAuthenticator dies with this frame; don't leave the caller's
   // packet pointing at it.
   if (generatedAuthenticator) { packet.authenticator = 0; }

   // Determine the request type.
   bool isAuth = request->isAccReq();

   // Set up the event struct.
   RadiusEvent event =
   {
      (isAuth ? portAuthentication : portAccounting),
      (isAuth ? eventAccessRequest : eventAccountingRequest),
      &request->getServer(),
      port.address.address(),
      port.address.port(),
      buffer,
      packet.length
   };

   // Get the appropriate socket.
   UDPSocket& sock = isAuth ? authSock : acctSock;

   // Insert the pending request before we send it to avoid a race condition.
   pending.insert(*request);

   // The magic moment -- we actually send the request.
   Result result;
   if (sock.send(port.address, buffer, packet.length))
   {
      // Update request state.
      request->onSend();

      // Set a timer to clean up if the server doesn't answer.
      if (timers.setTimer(request, request->getServer().timeout, 0))
      {
         result = resultSuccess;
      }
      else
      {
         // If we can't set a timer we have to remove it from the pending
         // requests table or else it could leak.
         pending.erase(*request);
         result = resultNotEnoughMemory;
      }
   }
   else
   {
      // Update the event with the error data.
      event.eventType = eventSendError;
      event.data = GetLastError();

      // If we received "Port Unreachable" ICMP packet, we'll count this as a
      // timeout since it means the server is unavailable.
      if (event.data == WSAECONNRESET) { request->onTimeout(); }

      // Remove from the pending requests table.
      pending.erase(*request);

      // The request was never delivered. 'result' used to be returned
      // uninitialized on this path; report it the way an undelivered request
      // is reported elsewhere (onRequestAbandoned).
      // NOTE(review): if Result has a dedicated send-error value, use it here.
      result = resultRequestTimeout;
   }

   // Report the event ...
   reportEvent(event);

   // ... and the result.
   return result;
}
|