Leaked source code of windows server 2003
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

529 lines
14 KiB

///////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 2000, Microsoft Corp. All rights reserved.
//
// FILE
//
// radproxy.h
//
// SYNOPSIS
//
// Declares the interface into the reusable RadiusProxy engine. This should
// have no IAS specific dependencies.
//
// MODIFICATION HISTORY
//
// 02/08/2000 Original version.
// 05/30/2000 Eliminate QUESTIONABLE state.
//
///////////////////////////////////////////////////////////////////////////////
#ifndef RADPROXY_H
#define RADPROXY_H
#if _MSC_VER >= 1000
#pragma once
#endif
#include <iascache.h>
#include <iasobj.h>
#include <radshare.h>
#include <timerq.h>
#include <udpsock.h>
#include <wincrypt.h>
struct RadiusAttribute;
struct RadiusPacket;
class Request;
class ServerBinding;
///////////////////////////////////////////////////////////////////////////////
//
// CLASS
//
// RemotePort
//
// DESCRIPTION
//
// Describes a remote endpoint for a RADIUS conversation.
//
///////////////////////////////////////////////////////////////////////////////
class RemotePort
{
public:
// Read-only properties.
const InternetAddress address;
const RadiusOctets secret;
RemotePort(
ULONG ipAddress,
USHORT port,
PCSTR sharedSecret
);
RemotePort(const RemotePort& port);
// Returns a packet identifier to use when sending a request to this port.
BYTE getIdentifier() throw ()
{ return (BYTE)++nextIdentifier; }
// Synchronizes this with the state of 'port', i.e., use the same next
// identifier.
void copyState(const RemotePort& port) throw ()
{ nextIdentifier = port.nextIdentifier; }
bool operator==(const RemotePort& p) const throw ()
{ return address == p.address && secret == p.secret; }
private:
Count nextIdentifier;
// Not implemented.
RemotePort& operator=(RemotePort&);
};
///////////////////////////////////////////////////////////////////////////////
//
// struct
//
// RemoteServerConfig
//
// DESCRIPTION
//
// Plain ol' data holding all the configuration associated with a
// RemoteServer. This spares clients from having to call a monster
// contructor when creating a remote server.
//
///////////////////////////////////////////////////////////////////////////////
struct RemoteServerConfig
{
GUID guid;
ULONG ipAddress;
USHORT authPort;
USHORT acctPort;
PCSTR authSecret;
PCSTR acctSecret;
ULONG priority;
ULONG weight;
ULONG timeout;
ULONG maxLost;
ULONG blackout;
bool sendSignature;
bool sendAcctOnOff;
};
///////////////////////////////////////////////////////////////////////////////
//
// CLASS
//
// RemoteServer
//
// DESCRIPTION
//
// Describes a remote RADIUS server and maintains the state of that server.
//
///////////////////////////////////////////////////////////////////////////////
class RemoteServer
{
public:
DECLARE_REFERENCE_COUNT();
// Unique ID for this server.
const GUID guid;
// Authentication and accounting ports.
RemotePort authPort;
RemotePort acctPort;
// Read-only properties for load-balancing and failover.
const ULONG priority;
const ULONG weight;
// Read-only properties for determining server state.
const ULONG timeout;
const LONG maxEvents;
const ULONG blackout;
// Should we always send a Signature attribute?
const bool sendSignature;
// Should we formard Accounting-On/Off requests?
const bool sendAcctOnOff;
RemoteServer(const RemoteServerConfig& config);
// Returns the servers IP address.
ULONG getAddress() const throw ()
{ return authPort.address.sin_addr.s_addr; }
// Returns 'true' if the server has a probationary request pending.
bool isInProgress() const throw ()
{ return onProbation && !usable; }
// Returns 'true' if the server is available for use.
bool isUsable() const throw ()
{ return usable; }
// Returns 'true' if the server should receive a broadcast.
bool shouldBroadcast() throw ();
// Notifies the RemoteServer that a valid packet has been received. Returns
// true if this triggers a state change.
bool onReceive(BYTE code) throw ();
// Notifies the RemoteServer that a packet has been sent.
void onSend() throw ();
// Notfies the RemoteServer that a request has timed out. Returns true if
// this triggers a state change.
bool onTimeout() throw ();
// Synchronize the state of this server with target.
void copyState(const RemoteServer& target) throw ();
bool operator==(const RemoteServer& s) const throw ();
protected:
// This is virtual so that RemoteServer can server as a base class.
virtual ~RemoteServer() throw () { }
private:
CriticalSection lock;
bool usable; // true if the server is available.
bool onProbation; // true if the server is on probation.
long eventCount; // Number of packets lost/found
ULONG64 expiry; // Time when blackout interval expires.
// Not implemented.
RemoteServer& operator=(RemoteServer&);
};
typedef ObjectPointer<RemoteServer> RemoteServerPtr;
typedef ObjectVector<RemoteServer> RemoteServers;
class RequestStack;
class ProxyContext;
///////////////////////////////////////////////////////////////////////////////
//
// CLASS
//
// ServerGroup
//
// DESCRIPTION
//
// Load balances requests among a group of RemoteServers.
//
///////////////////////////////////////////////////////////////////////////////
class ServerGroup
{
public:
DECLARE_REFERENCE_COUNT();
ServerGroup(
PCWSTR groupName,
RemoteServer* const* first,
RemoteServer* const* last
);
// Returns the number of servers in the group.
ULONG size() const throw ()
{ return servers.size(); }
bool isEmpty() const throw ()
{ return servers.empty(); }
// Name used to identify the group.
PCWSTR getName() const throw ()
{ return name; }
// Returns a collection of servers that should receive the request.
void getServersForRequest(
ProxyContext* context,
BYTE packetCode,
const RemoteServer* avoid,
RequestStack& result
) const;
// Methods for iterating the servers in the group.
RemoteServers::iterator begin() const throw ()
{ return servers.begin(); }
RemoteServers::iterator end() const throw ()
{ return servers.end(); }
private:
~ServerGroup() throw () { }
// Pick a server from the list. The list must not be empty, and all the
// servers must have the same priority. If 'avoid' is not null and there is
// more than one server in the list, the indicated server won't be picked.
static RemoteServer* pickServer(
RemoteServers::iterator first,
RemoteServers::iterator last,
const RemoteServer* avoid = 0
) throw ();
// Array of servers in priority order.
RemoteServers servers;
// End of top priority level in array.
RemoteServers::iterator endTopPriority;
// Maximum number of bytes required to hold the server candidates.
ULONG maxCandidatesSize;
RadiusString name;
static ULONG theSeed;
// Not implemented.
ServerGroup(const ServerGroup&);
ServerGroup& operator=(const ServerGroup&);
};
typedef ObjectPointer<ServerGroup> ServerGroupPtr;
typedef ObjectVector<ServerGroup> ServerGroups;
///////////////////////////////////////////////////////////////////////////////
//
// CLASS
//
// ServerGroupManager
//
// DESCRIPTION
//
// Manages a collection of ServerGroups.
//
///////////////////////////////////////////////////////////////////////////////
class ServerGroupManager
{
public:
ServerGroupManager() throw () { }
// Set the server groups to be managed.
bool setServerGroups(
ServerGroups::iterator begin,
ServerGroups::iterator end
) throw ();
// Returns a server with the given IP address.
RemoteServerPtr findServer(
ULONG address
) const throw ();
void getServersByGroup(
ProxyContext* context,
BYTE packetCode,
PCWSTR name,
const RemoteServer* avoid,
RequestStack& result
) const;
void getServersForAcctOnOff(
ProxyContext* context,
RequestStack& result
) const;
private:
// Synchronize access.
mutable RWLock monitor;
// Server groups being managed sorted by name.
ServerGroups groups;
// All servers sorted by guid.
RemoteServers byAddress;
// All servers sorted by guid.
RemoteServers byGuid;
// Servers to receive Accounting-On/Off requests.
RemoteServers acctServers;
// Not implemented.
ServerGroupManager(const ServerGroupManager&);
ServerGroupManager& operator=(const ServerGroupManager&);
};
class RadiusProxyClient;
///////////////////////////////////////////////////////////////////////////////
//
// CLASS
//
// RadiusProxyEngine
//
// DESCRIPTION
//
// Implements a RADIUS proxy.
//
///////////////////////////////////////////////////////////////////////////////
class RadiusProxyEngine : PacketReceiver
{
public:
// Final result of processing a request.
enum Result
{
resultSuccess,
resultNotEnoughMemory,
resultUnknownServerGroup,
resultUnknownServer,
resultInvalidRequest,
resultSendError,
resultRequestTimeout,
resultCryptoError
};
RadiusProxyEngine(RadiusProxyClient* source);
~RadiusProxyEngine() throw ();
HRESULT finalConstruct() throw ();
// Set the server groups to be used by the proxy.
bool setServerGroups(
ServerGroup* const* begin,
ServerGroup* const* end
) throw ();
// Forward a request to the given server group.
void forwardRequest(
PVOID context,
PCWSTR serverGroup,
BYTE code,
const BYTE* requestAuthenticator,
const RadiusAttribute* begin,
const RadiusAttribute* end
) throw ();
// Callback when a request context has been abandoned.
static void onRequestAbandoned(
PVOID context,
RemoteServer* server
) throw ();
// Callback when a request has timed out.
static void onRequestTimeout(
Request* request
) throw ();
private:
// Methods for associating a stateful authentication session with a
// particular server.
RemoteServerPtr getServerAffinity(
const RadiusPacket& packet
) throw ();
void setServerAffinity(
const RadiusPacket& packet,
RemoteServer& server
) throw ();
// Methods for associating a bad server with a User-Name.
void clearServerAvoidance(
const RadiusPacket& packet,
RemoteServer& server
) throw ();
RemoteServerPtr getServerAvoidance(
const RadiusPacket& packet
) throw ();
void setServerAvoidance(const Request& request) throw ();
// PacketReceiver callbacks.
virtual void onReceive(
UDPSocket& socket,
ULONG_PTR key,
const SOCKADDR_IN& remoteAddress,
BYTE* buffer,
ULONG bufferLength
) throw ();
virtual void onReceiveError(
UDPSocket& socket,
ULONG_PTR key,
ULONG errorCode
) throw ();
// Forward a request to an individual RemoteServer.
Result sendRequest(
RadiusPacket& packet,
Request* request
) throw ();
// Report an event to the client.
void reportEvent(
const RadiusEvent& event
) const throw ();
void reportEvent(
RadiusEvent& event,
RadiusEventType type
) const throw ();
// Callback when a timer has expired.
static VOID NTAPI onTimerExpiry(PVOID context, BOOLEAN flag) throw ();
// The object supplying us with requests.
RadiusProxyClient* client;
// The local address of the proxy. Used when forming Proxy-State.
ULONG proxyAddress;
// UDP sockets used for network I/O.
UDPSocket authSock;
UDPSocket acctSock;
// Server groups used for processing groups.
ServerGroupManager groups;
// Table of pending requests.
HashTable< LONG, Request > pending;
// Queue of pending requests.
TimerQueue timers;
// Table of current authentication sessions.
Cache< RadiusRawOctets, ServerBinding > sessions;
// Table of servers to avoid for a given User-Name.
Cache< RadiusRawOctets, ServerBinding > avoid;
// Used for generating request authenticators.
HCRYPTPROV crypto;
// Global pointer to the RadiusProxyEngine. This is a hack, but it saves me
// from having to give every Request and Context object a back pointer.
static RadiusProxyEngine* theProxy;
// Not implemented.
RadiusProxyEngine(const RadiusProxyEngine&);
RadiusProxyEngine& operator=(const RadiusProxyEngine&);
};
///////////////////////////////////////////////////////////////////////////////
//
// CLASS
//
// RadiusProxyClient
//
// DESCRIPTION
//
// Abstract base class for clients of the RadiusProxy engine.
//
///////////////////////////////////////////////////////////////////////////////
class __declspec(novtable) RadiusProxyClient
{
public:
// Invoked to report one of the above events.
virtual void onEvent(
const RadiusEvent& event
) throw () = 0;
// Invoked exactly once for each call to RadiusProxyEngine::forwardRequest.
virtual void onComplete(
RadiusProxyEngine::Result result,
PVOID context,
RemoteServer* server,
BYTE code,
const RadiusAttribute* begin,
const RadiusAttribute* end
) throw () = 0;
};
#endif // RADPROXY_H