Leaked source code of windows server 2003
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

1468 lines
40 KiB

///////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 2000, Microsoft Corp. All rights reserved.
//
// FILE
//
// radprxy.cpp
//
// SYNOPSIS
//
// Defines the reusable RadiusProxy engine. This should have no IAS specific
// dependencies.
//
// MODIFICATION HISTORY
//
// 02/08/2000 Original version.
// 05/30/2000 Eliminate QUESTIONABLE state.
//
///////////////////////////////////////////////////////////////////////////////
#include <proxypch.h>
#include <radproxyp.h>
#include <radproxy.h>
// Avoid dependencies on ntrtl.h
extern "C" ULONG __stdcall RtlRandom(PULONG seed);
// Extract a 32-bit integer from a buffer.
ULONG ExtractUInt32(const BYTE* p) throw ()
{
return (ULONG)(p[0] << 24) | (ULONG)(p[1] << 16) |
(ULONG)(p[2] << 8) | (ULONG)(p[3] );
}
// Insert a 32-bit integer into a buffer.
void InsertUInt32(BYTE* p, ULONG val) throw ()
{
*p++ = (BYTE)(val >> 24);
*p++ = (BYTE)(val >> 16);
*p++ = (BYTE)(val >> 8);
*p = (BYTE)(val );
}
//
// Layout of a Microsoft State attribute
//
// struct MicrosoftState
// {
// BYTE checksum[4];
// BYTE vendorID[4];
// BYTE version[2];
// BYTE serverAddress[4];
// BYTE sourceID[4];
// BYTE sessionID[4];
// };
//
// Extracts the creators address from a State attribute or INADDR_NONE if this
// isn't a valid Microsoft State attributes.
ULONG ExtractAddressFromState(const RadiusAttribute& state) throw ()
{
if (state.length == 22 &&
!memcmp(state.value + 4, "\x00\x00\x01\x37\x00\x01", 6) &&
IASAdler32(state.value + 4, 18) == ExtractUInt32(state.value))
{
return ExtractUInt32(state.value + 10);
}
return INADDR_NONE;
}
// Returns true if this is an Accounting-On/Off packet.
bool IsNasStateRequest(const RadiusPacket& packet) throw ()
{
const RadiusAttribute* status = FindAttribute(
packet,
RADIUS_ACCT_STATUS_TYPE
);
if (!status) { return false; }
ULONG value = ExtractUInt32(status->value);
return value == 7 || value == 8;
}
RemotePort::RemotePort(
ULONG ipAddress,
USHORT port,
PCSTR sharedSecret
)
: address(ipAddress, port),
secret((const BYTE*)sharedSecret, strlen(sharedSecret))
{
}
RemotePort::RemotePort(const RemotePort& port)
: address(port.address),
secret(port.secret),
nextIdentifier(port.nextIdentifier)
{
}
RemoteServer::RemoteServer(
const RemoteServerConfig& config
)
: guid(config.guid),
authPort(config.ipAddress, config.authPort, config.authSecret),
acctPort(config.ipAddress, config.acctPort, config.acctSecret),
timeout(config.timeout),
maxEvents((LONG)config.maxLost),
blackout(config.blackout),
priority(config.priority),
weight(config.weight),
sendSignature(config.sendSignature),
sendAcctOnOff(config.sendAcctOnOff),
usable(true),
onProbation(false),
eventCount(0),
expiry(0)
{
}
bool RemoteServer::shouldBroadcast() throw ()
{
bool broadcastable = false;
if (!onProbation && !usable)
{
ULONG64 now = GetSystemTime64();
lock.lock();
// Has the blackout interval expired ?
if (now > expiry)
{
// Yes, so set a new expiration.
expiry = now + blackout * 10000i64;
broadcastable = true;
}
lock.unlock();
}
return broadcastable;
}
bool RemoteServer::onReceive(BYTE code) throw ()
{
const bool authoritative = (code != RADIUS_ACCESS_CHALLENGE);
// Did the server transition from unavailable to available?
bool downToUp = false;
lock.lock();
if (onProbation)
{
if (authoritative)
{
// Bump the success count.
if (++eventCount >= maxEvents)
{
// We're off probation w/ a lost count of zero.
onProbation = false;
eventCount = 0;
downToUp = true;
}
// We successfully finished a request, so we can send another.
usable = true;
}
}
else if (usable)
{
if (authoritative)
{
// An authoritative response resets the lost count.
eventCount = 0;
}
}
else
{
// An unavailable server has responded to a broadcast, so put it on
// probation. Set the success count accordingly.
usable = true;
onProbation = true;
eventCount = authoritative ? 1 : 0;
}
lock.unlock();
return downToUp;
}
void RemoteServer::onSend() throw ()
{
if (onProbation)
{
lock.lock();
if (onProbation)
{
// Probationary servers can only send one request at a time.
usable = false;
}
lock.unlock();
}
}
bool RemoteServer::onTimeout() throw ()
{
// Did the server transition from available to unavailable?
bool upToDown = false;
lock.lock();
if (onProbation)
{
// Sudden death for probationary servers. Move it straight to
// unavailable.
usable = false;
onProbation = false;
expiry = GetSystemTime64() + blackout * 10000ui64;
}
else if (usable)
{
// Bump the lost count.
if (++eventCount >= maxEvents)
{
// Server is now unavailable.
usable = false;
expiry = GetSystemTime64() + blackout * 10000ui64;
upToDown = true;
}
}
else
{
// If the server is already unavailable, ignore the timeout.
}
lock.unlock();
return upToDown;
}
void RemoteServer::copyState(const RemoteServer& target) throw ()
{
// Synchronize the ports.
authPort.copyState(target.authPort);
acctPort.copyState(target.acctPort);
// Synchronize server availability.
usable = target.usable;
onProbation = target.onProbation;
eventCount = target.eventCount;
expiry = target.expiry;
}
bool RemoteServer::operator==(const RemoteServer& s) const throw ()
{
return authPort == s.authPort &&
acctPort == s.acctPort &&
priority == s.priority &&
weight == s.weight &&
timeout == s.timeout &&
eventCount == s.eventCount &&
blackout == s.blackout &&
sendSignature == s.sendSignature &&
sendAcctOnOff == s.sendAcctOnOff;
}
//////////
// Used for sorting servers by priority.
//////////
int __cdecl sortServersByPriority(
const RemoteServer* const* server1,
const RemoteServer* const* server2
) throw ()
{
return (int)(*server1)->priority - (int)(*server2)->priority;
}
ULONG ServerGroup::theSeed;
ServerGroup::ServerGroup(
PCWSTR groupName,
RemoteServer* const* first,
RemoteServer* const* last
)
: servers(first, last),
name(groupName)
{
// We don't allow empty groups.
if (servers.empty()) { _com_issue_error(E_INVALIDARG); }
if (theSeed == 0)
{
FILETIME ft;
GetSystemTimeAsFileTime(&ft);
theSeed = ft.dwLowDateTime | ft.dwHighDateTime;
}
// Sort by priority.
servers.sort(sortServersByPriority);
// Find the end of the top priority servers. This will be useful when doing
// a forced pick.
ULONG topPriority = (*servers.begin())->priority;
for (endTopPriority = servers.begin();
endTopPriority != servers.end();
++endTopPriority)
{
if ((*endTopPriority)->priority != topPriority) { break; }
}
// Find the max number of servers at any priority level. This will be useful
// when allocating a buffer to hold the candidates.
ULONG maxCount = 0, count = 0, priority = (*servers.begin())->priority;
for (RemoteServer* const* i = begin(); i != end(); ++i)
{
if ((*i)->priority != priority)
{
priority = (*i)->priority;
count = 0;
}
if (++count > maxCount) { maxCount = count; }
}
maxCandidatesSize = maxCount * sizeof(RemoteServer*);
}
RemoteServer* ServerGroup::pickServer(
RemoteServers::iterator first,
RemoteServers::iterator last,
const RemoteServer* avoid
) throw ()
{
// If the list has exactly one entry, there's nothing to do.
if (last == first + 1) { return *first; }
RemoteServer* const* i;
// Compute the combined weight off all the servers.
ULONG weight = 0;
for (i = first; i != last; ++i)
{
if (*i != avoid)
{
weight += (*i)->weight;
}
}
// Pick a random number from [0, weight)
ULONG offset = (ULONG)
(((ULONG64)RtlRandom(&theSeed) * (ULONG64)weight) >> 31);
// We don't test the last server since if we make it that far we have to use
// it anyway.
--last;
// Iterate through the candidates until we reach the offset.
for (i = first; i != last; ++i)
{
if (*i != avoid)
{
if ((*i)->weight >= offset) { break; }
offset -= (*i)->weight;
}
}
return *i;
}
void ServerGroup::getServersForRequest(
ProxyContext* context,
BYTE packetCode,
const RemoteServer* avoid,
RequestStack& result
) const
{
// List of candidates.
RemoteServer** first = (RemoteServer**)_alloca(maxCandidatesSize);
RemoteServer** last = first;
// Iterate through the servers.
ULONG maxPriority = (ULONG)-1;
for (RemoteServer* const* i = servers.begin(); i != servers.end(); ++i)
{
// If this test fails, we must have found a higher priority server that's
// usable.
if ((*i)->priority > maxPriority) { break; }
if ((*i)->isUsable())
{
// Don't consider lower priority servers.
maxPriority = (*i)->priority;
// Add this to the list of candidates.
*last++ = *i;
}
else if ((*i)->shouldBroadcast())
{
// It's not available, but it's ready for a broadcast
result.push(new Request(context, *i, packetCode));
}
}
if (first == last)
{
// No usable servers, so look for in progress servers.
maxPriority = (ULONG)-1;
for (RemoteServer* const* i = servers.begin(); i != servers.end(); ++i)
{
// If this test fails, we must have found a higher priority server
// that's in progress.
if ((*i)->priority > maxPriority) { break; }
if ((*i)->isInProgress())
{
// Don't consider lower priority servers.
maxPriority = (*i)->priority;
// Add this to the list of candidates.
*last++ = *i;
}
}
}
if (first != last)
{
// We have at least one candidate, so pick one and add it to the list.
result.push(new Request(
context,
pickServer(first, last, avoid),
packetCode
));
}
else if (result.empty() && !servers.empty())
{
// We have no candidates and no servers available for broadcast, so just
// force a pick from the top priority servers.
result.push(new Request(
context,
pickServer(servers.begin(), endTopPriority, avoid),
packetCode
));
}
}
//////////
// Used for sorting and searching groups by name.
//////////
int __cdecl sortGroupsByName(
const ServerGroup* const* group1,
const ServerGroup* const* group2
) throw ()
{
return wcscmp((*group1)->getName(), (*group2)->getName());
}
int __cdecl findGroupByName(
const void* key,
const ServerGroup* const* group
) throw ()
{
return wcscmp((PCWSTR)key, (*group)->getName());
}
//////////
// Used for sorting and searching servers by address.
//////////
int __cdecl sortServersByAddress(
const RemoteServer* const* server1,
const RemoteServer* const* server2
)
{
if ((*server1)->getAddress() < (*server2)->getAddress()) { return -1; }
if ((*server1)->getAddress() > (*server2)->getAddress()) { return 1; }
return 0;
}
int __cdecl findServerByAddress(
const void* key,
const RemoteServer* const* server
) throw ()
{
if ((ULONG_PTR)key < (*server)->getAddress()) { return -1; }
if ((ULONG_PTR)key > (*server)->getAddress()) { return 1; }
return 0;
}
//////////
// Used for sorting and searching servers by guid.
//////////
int __cdecl sortServersByGUID(
const RemoteServer* const* server1,
const RemoteServer* const* server2
) throw ()
{
return memcmp(&(*server1)->guid, &(*server2)->guid, sizeof(GUID));
}
int __cdecl findServerByGUID(
const void* key,
const RemoteServer* const* server
) throw ()
{
return memcmp(key, &(*server)->guid, sizeof(GUID));
}
//////////
// Used for sorting accounting servers by port.
//////////
int __cdecl sortServersByAcctPort(
const RemoteServer* const* server1,
const RemoteServer* const* server2
)
{
const sockaddr_in& a1 = (*server1)->acctPort.address;
const sockaddr_in& a2 = (*server2)->acctPort.address;
return memcmp(&a1.sin_port, &a2.sin_port, 6);
}
bool ServerGroupManager::setServerGroups(
ServerGroup* const* first,
ServerGroup* const* last
) throw ()
{
bool success;
try
{
// Save the new server groups ...
ServerGroups newGroups(first, last);
// Sort by name.
newGroups.sort(sortGroupsByName);
// Useful iterators.
ServerGroups::iterator i;
RemoteServers::iterator j;
// Count the number of servers and accounting servers.
ULONG count = 0, acctCount = 0;
for (i = first; i != last; ++i)
{
for (j = (*i)->begin(); j != (*i)->end(); ++j)
{
++count;
if ((*j)->sendAcctOnOff) { ++acctCount; }
}
}
// Reserve space for the servers.
RemoteServers newServers(count);
RemoteServers newAcctServers(acctCount);
// Populate the servers.
for (i = first; i != last; ++i)
{
for (j = (*i)->begin(); j != (*i)->end(); ++j)
{
RemoteServer* newServer = *j;
// Does this server already exist?
RemoteServer* existing = byGuid.search(
(const void*)&newServer->guid,
findServerByGUID
);
if (existing)
{
if (*existing == *newServer)
{
// If it's an exact match, use the existing server.
newServer = existing;
}
else
{
// Otherwise, copy the state of the existing server.
newServer->copyState(*existing);
}
}
newServers.push_back(newServer);
if (newServer->sendAcctOnOff)
{
newAcctServers.push_back(newServer);
}
}
}
// Sort the servers by address ...
newServers.sort(sortServersByAddress);
// ... and GUID.
RemoteServers newServersByGuid(newServers);
newServersByGuid.sort(sortServersByGUID);
// Everything is ready so now we grab the write lock ...
monitor.LockExclusive();
// ... and swap in the collections.
groups.swap(newGroups);
byAddress.swap(newServers);
byGuid.swap(newServersByGuid);
acctServers.swap(newAcctServers);
monitor.Unlock();
success = true;
}
catch (const std::bad_alloc&)
{
success = false;
}
return success;
}
RemoteServerPtr ServerGroupManager::findServer(
ULONG address
) const throw ()
{
monitor.Lock();
RemoteServer* server = byAddress.search(
(const void*)ULongToPtr(address),
findServerByAddress
);
monitor.Unlock();
return server;
}
void ServerGroupManager::getServersByGroup(
ProxyContext* context,
BYTE packetCode,
PCWSTR name,
const RemoteServer* avoid,
RequestStack& result
) const throw ()
{
monitor.Lock();
ServerGroup* group = groups.search(name, findGroupByName);
if (group)
{
group->getServersForRequest(context, packetCode, avoid, result);
}
monitor.Unlock();
}
void ServerGroupManager::getServersForAcctOnOff(
ProxyContext* context,
RequestStack& result
) const
{
monitor.Lock();
for (RemoteServer* const* i = acctServers.begin();
i != acctServers.end();
++i)
{
result.push(new Request(context, *i, RADIUS_ACCOUNTING_REQUEST));
}
monitor.Unlock();
}
RadiusProxyEngine* RadiusProxyEngine::theProxy;
RadiusProxyEngine::RadiusProxyEngine(RadiusProxyClient* source) throw ()
: client(source),
proxyAddress(INADDR_NONE),
pending(Request::hash, 1),
sessions(ServerBinding::hash, 1, 10000, (2 * 60 * 1000), true),
avoid(ServerBinding::hash, 1, 10000, (35 * 60 * 1000), false),
crypto(0)
{
theProxy = this;
// We don't care if this fails. The proxy will just use INADDR_NONE in it's
// proxy-state attribute.
PHOSTENT he = IASGetHostByName(NULL);
if (he)
{
if (he->h_addr_list[0])
{
proxyAddress = *(PULONG)he->h_addr_list[0];
}
LocalFree(he);
}
}
RadiusProxyEngine::~RadiusProxyEngine() throw ()
{
// Block any new reponses.
authSock.close();
acctSock.close();
// Clear the pending request table.
pending.clear();
// Cancel all the timers.
timers.cancelAllTimers();
// At this point all our threads should be done, but let's just make sure.
SwitchToThread();
if (crypto != 0)
{
CryptReleaseContext(crypto, 0);
}
theProxy = NULL;
}
HRESULT RadiusProxyEngine::finalConstruct() throw ()
{
HRESULT hr = S_OK;
if (!CryptAcquireContext(
&crypto,
0,
0,
PROV_RSA_FULL,
CRYPT_VERIFYCONTEXT
))
{
DWORD error = GetLastError();
hr = HRESULT_FROM_WIN32(error);
}
return hr;
}
bool RadiusProxyEngine::setServerGroups(
ServerGroup* const* begin,
ServerGroup* const* end
) throw ()
{
// We don't open the sockets unless we actually have some server groups
// configured. This is just to be a good corporate citizen.
if (begin != end)
{
if ((!authSock.isOpen() && !authSock.open(this, portAuthentication)) ||
(!acctSock.isOpen() && !acctSock.open(this, portAccounting)))
{
return false;
}
}
return groups.setServerGroups(begin, end);
}
void RadiusProxyEngine::forwardRequest(
PVOID context,
PCWSTR serverGroup,
BYTE code,
const BYTE* requestAuthenticator,
const RadiusAttribute* begin,
const RadiusAttribute* end
) throw ()
{
// Save the request context. We have to handle this carefully since we rely
// on the ProxyContext object to ensure that onComplete gets called exactly
// one. If we can't allocate the object, we have to handle it specially.
ProxyContextPtr ctxt(new (std::nothrow) ProxyContext(context));
if (!ctxt)
{
client->onComplete(
resultNotEnoughMemory,
context,
NULL,
code,
NULL,
NULL
);
return;
}
Result retval = resultUnknownServerGroup;
try
{
// Store the in parameters in a RadiusPacket struct.
RadiusPacket packet;
packet.code = code;
packet.begin = const_cast<RadiusAttribute*>(begin);
packet.end = const_cast<RadiusAttribute*>(end);
// Generate the list of RADIUS requests to be sent.
RequestStack requests;
switch (code)
{
case RADIUS_ACCESS_REQUEST:
{
// Is this request associated with a particular server?
RemoteServerPtr server = getServerAffinity(packet);
if (server)
{
requests.push(new Request(ctxt, server, RADIUS_ACCESS_REQUEST));
}
else
{
server = getServerAvoidance(packet);
groups.getServersByGroup(
ctxt,
code,
serverGroup,
server,
requests
);
}
// Put request authenticator in the packet. The request
// authenticator can be NULL. The authenticator will not be
// changed.
packet.authenticator = requestAuthenticator;
break;
}
case RADIUS_ACCOUNTING_REQUEST:
{
if (!IsNasStateRequest(packet))
{
groups.getServersByGroup(
ctxt,
code,
serverGroup,
0,
requests
);
}
else
{
groups.getServersForAcctOnOff(
ctxt,
requests
);
// NAS State requests are always reported as a success since we
// don't care if it gets to all the destinations.
context = ctxt->takeOwnership();
if (context)
{
client->onComplete(
resultSuccess,
context,
NULL,
RADIUS_ACCOUNTING_RESPONSE,
NULL,
NULL
);
}
retval = resultSuccess;
}
break;
}
default:
{
retval = resultInvalidRequest;
}
}
if (!requests.empty())
{
// First we handle the primary.
RequestPtr request = requests.pop();
ctxt->setPrimaryServer(&request->getServer());
retval = sendRequest(packet, request);
// Now we broadcast.
while (!requests.empty())
{
request = requests.pop();
Result result = sendRequest(packet, request);
if (result == resultSuccess && retval != resultSuccess)
{
// This was the first request to succeed so mark it as primary.
retval = resultSuccess;
ctxt->setPrimaryServer(&request->getServer());
}
}
}
}
catch (const std::bad_alloc&)
{
retval = resultNotEnoughMemory;
}
if (retval != resultSuccess)
{
// If we made it here, then we didn't successfully send a packet to any
// server, so we have to report the result ourself.
context = ctxt->takeOwnership();
if (context)
{
client->onComplete(
retval,
context,
ctxt->getPrimaryServer(),
code,
NULL,
NULL
);
}
}
}
void RadiusProxyEngine::onRequestAbandoned(
PVOID context,
RemoteServer* server
) throw ()
{
// Nobody took responsibility for the request, so we time it out.
theProxy->client->onComplete(
resultRequestTimeout,
context,
server,
0,
NULL,
NULL
);
}
inline void RadiusProxyEngine::reportEvent(
const RadiusEvent& event
) const throw ()
{
client->onEvent(event);
}
inline void RadiusProxyEngine::reportEvent(
RadiusEvent& event,
RadiusEventType type
) const throw ()
{
event.eventType = type;
client->onEvent(event);
}
void RadiusProxyEngine::onRequestTimeout(
Request* request
) throw ()
{
// Erase the pending request. If it's not there, that's okay; it means that
// we received a response, but weren't able to cancel the timer in time.
if (theProxy->pending.erase(request->getRequestID()))
{
// Avoid this server next time.
theProxy->setServerAvoidance(*request);
RadiusEvent event =
{
request->getPortType(),
eventTimeout,
&request->getServer(),
request->getPort().address.address(),
request->getPort().address.port()
};
// Report the protocol event.
theProxy->reportEvent(event);
// Update request state.
if (request->onTimeout())
{
// The server was just marked unavailable, so notify the client.
theProxy->reportEvent(event, eventServerUnavailable);
}
}
}
RemoteServerPtr RadiusProxyEngine::getServerAffinity(
const RadiusPacket& packet
) throw ()
{
// Find the State attribute.
const RadiusAttribute* attr = FindAttribute(packet, RADIUS_STATE);
if (!attr) { return NULL; }
// Map it to a session.
RadiusRawOctets key = { attr->value, attr->length };
ServerBindingPtr session = sessions.find(key);
if (!session) { return NULL; }
return &session->getServer();
}
void RadiusProxyEngine::setServerAffinity(
const RadiusPacket& packet,
RemoteServer& server
) throw ()
{
// Is this an Access-Challenge ?
if (packet.code != RADIUS_ACCESS_CHALLENGE) { return; }
// Find the State attribute.
const RadiusAttribute* state = FindAttribute(packet, RADIUS_STATE);
if (!state) { return; }
// Do we already have an entry for this State value.
RadiusRawOctets key = { state->value, state->length };
ServerBindingPtr session = sessions.find(key);
if (session)
{
// Make sure the server matches.
session->setServer(server);
return;
}
// Otherwise, we'll have to create a new one.
try
{
session = new ServerBinding(key, server);
sessions.insert(*session);
}
catch (const std::bad_alloc&)
{
// We don't care if this fails.
}
}
void RadiusProxyEngine::clearServerAvoidance(
const RadiusPacket& packet,
RemoteServer& server
) throw ()
{
// Is this packet authoritative?
if ((packet.code == RADIUS_ACCESS_ACCEPT) ||
(packet.code == RADIUS_ACCESS_REJECT))
{
// Find the User-Name attribute.
const RadiusAttribute* attr = FindAttribute(packet, RADIUS_USER_NAME);
if (attr != 0)
{
// Map it to a server.
RadiusRawOctets key = { attr->value, attr->length };
ServerBindingPtr avoidance = avoid.find(key);
if (avoidance && (avoidance->getServer() == server))
{
avoid.erase(key);
}
}
}
}
RemoteServerPtr RadiusProxyEngine::getServerAvoidance(
const RadiusPacket& packet
) throw ()
{
// Find the User-Name attribute.
const RadiusAttribute* attr = FindAttribute(packet, RADIUS_USER_NAME);
if (!attr) { return NULL; }
// Map it to a server.
RadiusRawOctets key = { attr->value, attr->length };
ServerBindingPtr avoidance = avoid.find(key);
if (!avoidance) { return NULL; }
return &avoidance->getServer();
}
void RadiusProxyEngine::setServerAvoidance(const Request& request) throw ()
{
if ((request.getCode() != RADIUS_ACCESS_REQUEST) ||
(request.getUserName().len == 0))
{
return;
}
// Do we already have an entry for this User-Name value.
ServerBindingPtr avoidance = avoid.find(request.getUserName());
if (avoidance)
{
// Make sure the server matches.
avoidance->setServer(request.getServer());
return;
}
// Otherwise, we'll have to create a new one.
try
{
avoidance = new ServerBinding(
request.getUserName(),
request.getServer()
);
avoid.insert(*avoidance);
}
catch (const std::bad_alloc&)
{
// We don't care if this fails.
}
}
void RadiusProxyEngine::onReceive(
UDPSocket& socket,
ULONG_PTR key,
const SOCKADDR_IN& remoteAddress,
BYTE* buffer,
ULONG bufferLength
) throw ()
{
//////////
// Set up the event struct. We'll fill in the other fields as we go along.
//////////
RadiusEvent event =
{
(RadiusPortType)key,
eventNone,
NULL,
remoteAddress.sin_addr.s_addr,
remoteAddress.sin_port,
buffer,
bufferLength,
0
};
//////////
// Validate the remote address.
//////////
RemoteServerPtr server = groups.findServer(
remoteAddress.sin_addr.s_addr
);
if (!server)
{
reportEvent(event, eventInvalidAddress);
return;
}
// Use the server as the event context.
event.context = server;
//////////
// Validate the packet type.
//////////
if (bufferLength == 0)
{
reportEvent(event, eventUnknownType);
return;
}
switch (MAKELONG(key, buffer[0]))
{
case MAKELONG(portAuthentication, RADIUS_ACCESS_ACCEPT):
reportEvent(event, eventAccessAccept);
break;
case MAKELONG(portAuthentication, RADIUS_ACCESS_REJECT):
reportEvent(event, eventAccessReject);
break;
case MAKELONG(portAuthentication, RADIUS_ACCESS_CHALLENGE):
reportEvent(event, eventAccessChallenge);
break;
case MAKELONG(portAccounting, RADIUS_ACCOUNTING_RESPONSE):
reportEvent(event, eventAccountingResponse);
break;
default:
reportEvent(event, eventUnknownType);
return;
}
//////////
// Validate that the packet is properly formatted.
//////////
RadiusPacket* packet;
ALLOC_PACKET_FOR_BUFFER(packet, buffer, bufferLength);
if (!packet)
{
reportEvent(event, eventMalformedPacket);
return;
}
// Unpack the attributes.
UnpackBuffer(buffer, bufferLength, *packet);
//////////
// Validate that we were expecting this response.
//////////
// Look for our Proxy-State attribute.
RadiusAttribute* proxyState = FindAttribute(
*packet,
RADIUS_PROXY_STATE
);
// If we didn't find it OR it's the wrong length OR it doesn't start with
// our address, then we weren't expecting this packet.
if (!proxyState ||
proxyState->length != 8 ||
memcmp(proxyState->value, &proxyAddress, 4))
{
reportEvent(event, eventUnexpectedResponse);
return;
}
// Extract the request ID.
ULONG requestID = ExtractUInt32(proxyState->value + 4);
// Don't send the Proxy-State back to our client.
--packet->end;
memmove(
proxyState,
proxyState + 1,
(packet->end - proxyState) * sizeof(RadiusAttribute)
);
// Look up the request object. We don't remove it yet because we don't know
// if this is an authentic response.
RequestPtr request = pending.find(requestID);
if (!request)
{
// If it's not there, we'll assume that this is a packet that's
// already been reported as a timeout.
reportEvent(event, eventLateResponse);
return;
}
// Get the actual server we used for the request in case there are multiple
// servers defined for the same IP address.
event.context = server = &request->getServer();
const RemotePort& port = request->getPort();
// Validate the packet source && identifier.
if (!(port.address == remoteAddress) ||
request->getIdentifier() != packet->identifier)
{
reportEvent(event, eventUnexpectedResponse);
return;
}
//////////
// Validate that the packet is authentic.
//////////
AuthResult authResult = AuthenticateAndDecrypt(
request->getAuthenticator(),
port.secret,
port.secret.length(),
buffer,
bufferLength,
*packet
);
switch (authResult)
{
case AUTH_BAD_AUTHENTICATOR:
reportEvent(event, eventBadAuthenticator);
return;
case AUTH_BAD_SIGNATURE:
reportEvent(event, eventBadSignature);
return;
case AUTH_MISSING_SIGNATURE:
reportEvent(event, eventMissingSignature);
return;
}
//////////
// At this point, all the tests have passed -- we have the real thing.
//////////
if (!pending.erase(requestID))
{
// It must have timed out while we were authenticating it.
reportEvent(event, eventLateResponse);
return;
}
// Update endpoint state.
if (request->onReceive(packet->code))
{
// The server just came up, so notify the client.
reportEvent(event, eventServerAvailable);
}
// Report the round-trip time.
event.data = request->getRoundTripTime();
reportEvent(event, eventRoundTrip);
// Set the server affinity and clear the server avoidance.
setServerAffinity(*packet, *server);
clearServerAvoidance(*packet, *server);
// Take ownership of the context.
PVOID context = request->getContext().takeOwnership();
if (context)
{
// The magic moment -- we have successfully processed the response.
client->onComplete(
resultSuccess,
context,
&request->getServer(),
packet->code,
packet->begin,
packet->end
);
}
}
void RadiusProxyEngine::onReceiveError(
UDPSocket& socket,
ULONG_PTR key,
ULONG errorCode
) throw ()
{
RadiusEvent event =
{
(RadiusPortType)key,
eventReceiveError,
NULL,
socket.getLocalAddress().address(),
socket.getLocalAddress().port(),
NULL,
0,
errorCode
};
client->onEvent(event);
}
RadiusProxyEngine::Result RadiusProxyEngine::sendRequest(
RadiusPacket& packet,
Request* request
) throw ()
{
// Fill in the packet identifier.
packet.identifier = request->getIdentifier();
// Get the info for the Signature.
BOOL sign = request->getServer().sendSignature;
// Format the Proxy-State attributes.
BYTE proxyStateValue[8];
RadiusAttribute proxyState = { RADIUS_PROXY_STATE, 8, proxyStateValue };
// First our IP address ...
memcpy(proxyStateValue, &proxyAddress, 4);
// ... and then the unique request ID.
InsertUInt32(proxyStateValue + 4, request->getRequestID());
// Allocate a buffer to hold the packet on the wire.
PBYTE buffer;
ALLOC_BUFFER_FOR_PACKET(buffer, &packet, &proxyState, sign);
if (!buffer) { return resultInvalidRequest; }
// Get the port for this request.
const RemotePort& port = request->getPort();
// Generate the request authenticator if necessary.
BYTE requestAuthenticator[16];
if ((packet.code == RADIUS_ACCESS_REQUEST) &&
(packet.authenticator == 0))
{
if (!CryptGenRandom(
crypto,
sizeof(requestAuthenticator),
requestAuthenticator
))
{
return resultCryptoError;
}
packet.authenticator = requestAuthenticator;
}
// Pack the buffer. packet.authenticator is used for CHAP when the request
// authenticator is used for the chap-challenge. It can be null
PackBuffer(
port.secret,
port.secret.length(),
packet,
&proxyState,
sign,
buffer
);
// Save the request authenticator and packet.
request->setAuthenticator(buffer + 4);
request->setPacket(packet);
// Determine the request type.
bool isAuth = request->isAccReq();
// Set up the event struct.
RadiusEvent event =
{
(isAuth ? portAuthentication : portAccounting),
(isAuth ? eventAccessRequest : eventAccountingRequest),
&request->getServer(),
port.address.address(),
port.address.port(),
buffer,
packet.length
};
// Get the appropriate socket.
UDPSocket& sock = isAuth ? authSock : acctSock;
// Insert the pending request before we send it to avoid a race condition.
pending.insert(*request);
// The magic moment -- we actually send the request.
Result result;
if (sock.send(port.address, buffer, packet.length))
{
// Update request state.
request->onSend();
// Set a timer to clean up if the server doesn't answer.
if (timers.setTimer(request, request->getServer().timeout, 0))
{
result = resultSuccess;
}
else
{
// If we can't set at timer we have to remove it from the pending
// requests table or else it could leak.
pending.erase(*request);
result = resultNotEnoughMemory;
}
}
else
{
// Update the event with the error data.
event.eventType = eventSendError;
event.data = GetLastError();
// If we received "Port Unreachable" ICMP packet, we'll count this as a
// timeout since it means the server is unavailable.
if (event.data == WSAECONNRESET) { request->onTimeout(); }
// Remove from the pending requests table.
pending.erase(*request);
}
// Report the event ...
reportEvent(event);
// ... and the result.
return result;
}