//========= Copyright © 1996-2005, Valve Corporation, All rights reserved. ============// // // Purpose: Handles all the functions for implementing remote access to the engine // //=============================================================================// #include "netadr.h" #include "sv_ipratelimit.h" #include "convar.h" #include "utlrbtree.h" #include "utlvector.h" #include "utlmap.h" #include "../gcsdk/steamextra/tier1/utlhashmaplarge.h" #include "filesystem.h" #include "sv_log.h" #include "tier1/ns_address.h" // NOTE: This has to be the last file included! #include "tier0/memdbgon.h" static ConVar sv_max_queries_sec( "sv_max_queries_sec", "10.0", FCVAR_RELEASE, "Maximum queries per second to respond to from a single IP address." ); static ConVar sv_max_queries_window( "sv_max_queries_window", "30", FCVAR_RELEASE, "Window over which to average queries per second averages." ); static ConVar sv_max_queries_tracked_ips_max( "sv_max_queries_tracked_ips_max", "50000", FCVAR_RELEASE, "Window over which to average queries per second averages." ); static ConVar sv_max_queries_tracked_ips_prune( "sv_max_queries_tracked_ips_prune", "10", FCVAR_RELEASE, "Window over which to average queries per second averages." ); static ConVar sv_max_queries_sec_global( "sv_max_queries_sec_global", "500", FCVAR_RELEASE, "Maximum queries per second to respond to from anywhere." ); static ConVar sv_logblocks("sv_logblocks", "0", FCVAR_RELEASE, "If true when log when a query is blocked (can cause very large log files)"); class CIPRateLimit { public: CIPRateLimit(); ~CIPRateLimit(); // updates an ip entry, return true if the ip is allowed, false otherwise bool CheckIP( netadr_t ip ); void Reset() { m_IPTimes.RemoveAll(); m_IPStorage.RemoveAll(); m_iGlobalCount = 0; m_lLastTime = -1; m_lLastDistributedDetection = -1; m_lLastPersonalDetection = -1; } private: typedef int ip_t; struct iprate_val { long lastTime; int count; int32 idxiptime; }; struct IpHashNoopFunctor { typedef uint32 TargetType; TargetType operator()( const ip_t &key ) const { return key; } }; typedef CUtlHashMapLarge< ip_t, iprate_val, CDefEquals< ip_t >, IpHashNoopFunctor > IPStorage_t; IPStorage_t m_IPStorage; typedef CUtlMap< long, ip_t, int32, CDefLess< long > > IPTimes_t; IPTimes_t m_IPTimes; int m_iGlobalCount; long m_lLastTime; long m_lLastDistributedDetection; long m_lLastPersonalDetection; }; static CIPRateLimit rateChecker; //----------------------------------------------------------------------------- // Purpose: return false if this IP exceeds rate limits //----------------------------------------------------------------------------- bool CheckConnectionLessRateLimits( const ns_address &adr ) { if ( !adr.IsType< netadr_t >() ) return true; // This function can be called from socket thread, mutex around it static CThreadMutex s_mtx; AUTO_LOCK( s_mtx ); bool ret = rateChecker.CheckIP( adr.AsType() ); if ( !ret && sv_logblocks.GetBool() == true ) { g_Log.Printf("Traffic from %s was blocked for exceeding rate limits\n", ns_address_render( adr ).String() ); } return ret; } //----------------------------------------------------------------------------- // Purpose: Constructor //----------------------------------------------------------------------------- CIPRateLimit::CIPRateLimit() { m_iGlobalCount = 0; m_lLastTime = -1; m_lLastDistributedDetection = -1; m_lLastPersonalDetection = -1; } //----------------------------------------------------------------------------- // Purpose: Destructor //----------------------------------------------------------------------------- CIPRateLimit::~CIPRateLimit() { } //----------------------------------------------------------------------------- // Purpose: return false if this IP has exceeded limits //----------------------------------------------------------------------------- bool CIPRateLimit::CheckIP( netadr_t adr ) { long curTime = (long)Plat_FloatTime(); // check the per ip rate (do this first, so one person dosing doesn't add to the global max rate ip_t clientIP; memcpy( &clientIP, adr.ip, sizeof(ip_t) ); int const MAX_TREE_SIZE = sv_max_queries_tracked_ips_max.GetInt(); int const MAX_TREE_PRUNE = sv_max_queries_tracked_ips_prune.GetInt(); // Prune some elements from the tree int numPruned = 0; for ( int32 itIPTime = m_IPTimes.FirstInorder(); ( itIPTime != m_IPTimes.InvalidIndex() ); ) { int32 itIPTimeNext = m_IPTimes.NextInorder( itIPTime ); ip_t ipTracked = m_IPTimes.Element( itIPTime ); if ( ipTracked != clientIP ) { if ( ( curTime - m_IPTimes.Key( itIPTime ) ) < sv_max_queries_window.GetFloat() ) break; // need to still keep monitoring this IP address, time is in order so next ones are even more recent m_IPStorage.Remove( ipTracked ); m_IPTimes.RemoveAt( itIPTime ); ++ numPruned; if ( ( numPruned >= MAX_TREE_PRUNE ) && ( m_IPStorage.Count() < MAX_TREE_SIZE ) ) break; } itIPTime = itIPTimeNext; } if ( m_IPStorage.Count() > MAX_TREE_SIZE ) { // This looks like we are under distributed attack where we are seeing a // very large number of IP addresses in a short time period // Stop tracking individual IP addresses and turn on global rate limit Msg( "IP rate limit detected distributed packet load (%u buckets, %u global count).\n", m_IPStorage.Count(), m_iGlobalCount ); Reset(); m_iGlobalCount = MAX( 1, ( sv_max_queries_sec_global.GetFloat() + 1 ) * ( sv_max_queries_window.GetFloat() + 1 ) ); m_lLastTime = curTime; m_lLastDistributedDetection = curTime; } // now find the entry and check if it's within our rate limits bool bPerIpLimitingPerformed = false; IPStorage_t::IndexType_t ipEntry = m_IPStorage.Find( clientIP ); if ( m_IPStorage.IsValidIndex( ipEntry ) ) { bPerIpLimitingPerformed = true; iprate_val &iprateval = m_IPStorage.Element( ipEntry ); if ( ( curTime - iprateval.lastTime ) > sv_max_queries_window.GetFloat() ) { float query_rate = static_cast< float >( iprateval.count ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero if ( query_rate > sv_max_queries_sec.GetFloat() ) { if ( ( curTime - m_lLastPersonalDetection ) > sv_max_queries_window.GetFloat()/10 ) { Msg( "IP rate limiting client %s sustained %u hits at %.1f pps (%u buckets, %u global count).\n", adr.ToString(), iprateval.count, query_rate, m_IPStorage.Count(), m_iGlobalCount ); } } m_IPTimes.RemoveAt( iprateval.idxiptime ); iprateval.idxiptime = m_IPTimes.Insert( curTime, clientIP ); iprateval.lastTime = curTime; iprateval.count = 1; } else { ++ iprateval.count; float query_rate = static_cast< float >( iprateval.count ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero if ( query_rate > sv_max_queries_sec.GetFloat() ) { if ( ( curTime - m_lLastPersonalDetection ) > sv_max_queries_window.GetFloat() ) { m_lLastPersonalDetection = curTime; Msg( "IP rate limiting client %s at %u hits (%u buckets, %u global count).\n", adr.ToString(), iprateval.count, m_IPStorage.Count(), m_iGlobalCount ); } return false; } } } // now check the global rate m_iGlobalCount++; if( (curTime - m_lLastTime) > sv_max_queries_window.GetFloat() ) { float query_rate = static_cast< float >( m_iGlobalCount ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero if ( query_rate > sv_max_queries_sec_global.GetFloat() ) { if ( ( curTime - m_lLastDistributedDetection ) > sv_max_queries_window.GetFloat()/10 ) { Msg( "IP rate limit sustained %u distributed packets at %.1f pps (%u buckets).\n", m_iGlobalCount, query_rate, m_IPStorage.Count() ); } } m_lLastTime = curTime; m_iGlobalCount = 1; } else { float query_rate = static_cast( m_iGlobalCount ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero if( query_rate > sv_max_queries_sec_global.GetFloat() ) { if ( ( curTime - m_lLastDistributedDetection ) > sv_max_queries_window.GetFloat() ) { m_lLastDistributedDetection = curTime; Msg( "IP rate limit under distributed packet load (%u buckets, %u global count), rejecting %s.\n", m_IPStorage.Count(), m_iGlobalCount, adr.ToString() ); } return false; } } if ( !bPerIpLimitingPerformed ) { iprate_val iprateval; iprateval.count = 1; iprateval.lastTime = curTime; // not found, insert this new guy iprateval.idxiptime = m_IPTimes.Insert( curTime, clientIP ); m_IPStorage.Insert( clientIP, iprateval ); } return true; }