Counter-Strike: Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

243 lines
8.4 KiB

  1. //========= Copyright © 1996-2005, Valve Corporation, All rights reserved. ============//
  2. //
  3. // Purpose: Handles all the functions for implementing remote access to the engine
  4. //
  5. //=============================================================================//
  6. #include "netadr.h"
  7. #include "sv_ipratelimit.h"
  8. #include "convar.h"
  9. #include "utlrbtree.h"
  10. #include "utlvector.h"
  11. #include "utlmap.h"
  12. #include "../gcsdk/steamextra/tier1/utlhashmaplarge.h"
  13. #include "filesystem.h"
  14. #include "sv_log.h"
  15. #include "tier1/ns_address.h"
  16. // NOTE: This has to be the last file included!
  17. #include "tier0/memdbgon.h"
  18. static ConVar sv_max_queries_sec( "sv_max_queries_sec", "10.0", FCVAR_RELEASE, "Maximum queries per second to respond to from a single IP address." );
  19. static ConVar sv_max_queries_window( "sv_max_queries_window", "30", FCVAR_RELEASE, "Window over which to average queries per second averages." );
  20. static ConVar sv_max_queries_tracked_ips_max( "sv_max_queries_tracked_ips_max", "50000", FCVAR_RELEASE, "Window over which to average queries per second averages." );
  21. static ConVar sv_max_queries_tracked_ips_prune( "sv_max_queries_tracked_ips_prune", "10", FCVAR_RELEASE, "Window over which to average queries per second averages." );
  22. static ConVar sv_max_queries_sec_global( "sv_max_queries_sec_global", "500", FCVAR_RELEASE, "Maximum queries per second to respond to from anywhere." );
  23. static ConVar sv_logblocks("sv_logblocks", "0", FCVAR_RELEASE, "If true when log when a query is blocked (can cause very large log files)");
  24. class CIPRateLimit
  25. {
  26. public:
  27. CIPRateLimit();
  28. ~CIPRateLimit();
  29. // updates an ip entry, return true if the ip is allowed, false otherwise
  30. bool CheckIP( netadr_t ip );
  31. void Reset()
  32. {
  33. m_IPTimes.RemoveAll();
  34. m_IPStorage.RemoveAll();
  35. m_iGlobalCount = 0;
  36. m_lLastTime = -1;
  37. m_lLastDistributedDetection = -1;
  38. m_lLastPersonalDetection = -1;
  39. }
  40. private:
  41. typedef int ip_t;
  42. struct iprate_val
  43. {
  44. long lastTime;
  45. int count;
  46. int32 idxiptime;
  47. };
  48. struct IpHashNoopFunctor
  49. {
  50. typedef uint32 TargetType;
  51. TargetType operator()( const ip_t &key ) const
  52. {
  53. return key;
  54. }
  55. };
  56. typedef CUtlHashMapLarge< ip_t, iprate_val, CDefEquals< ip_t >, IpHashNoopFunctor > IPStorage_t;
  57. IPStorage_t m_IPStorage;
  58. typedef CUtlMap< long, ip_t, int32, CDefLess< long > > IPTimes_t;
  59. IPTimes_t m_IPTimes;
  60. int m_iGlobalCount;
  61. long m_lLastTime;
  62. long m_lLastDistributedDetection;
  63. long m_lLastPersonalDetection;
  64. };
  65. static CIPRateLimit rateChecker;
  66. //-----------------------------------------------------------------------------
  67. // Purpose: return false if this IP exceeds rate limits
  68. //-----------------------------------------------------------------------------
  69. bool CheckConnectionLessRateLimits( const ns_address &adr )
  70. {
  71. if ( !adr.IsType< netadr_t >() )
  72. return true;
  73. // This function can be called from socket thread, mutex around it
  74. static CThreadMutex s_mtx;
  75. AUTO_LOCK( s_mtx );
  76. bool ret = rateChecker.CheckIP( adr.AsType<netadr_t>() );
  77. if ( !ret && sv_logblocks.GetBool() == true )
  78. {
  79. g_Log.Printf("Traffic from %s was blocked for exceeding rate limits\n", ns_address_render( adr ).String() );
  80. }
  81. return ret;
  82. }
  83. //-----------------------------------------------------------------------------
  84. // Purpose: Constructor
  85. //-----------------------------------------------------------------------------
  86. CIPRateLimit::CIPRateLimit()
  87. {
  88. m_iGlobalCount = 0;
  89. m_lLastTime = -1;
  90. m_lLastDistributedDetection = -1;
  91. m_lLastPersonalDetection = -1;
  92. }
  93. //-----------------------------------------------------------------------------
  94. // Purpose: Destructor
  95. //-----------------------------------------------------------------------------
  96. CIPRateLimit::~CIPRateLimit()
  97. {
  98. }
  99. //-----------------------------------------------------------------------------
  100. // Purpose: return false if this IP has exceeded limits
  101. //-----------------------------------------------------------------------------
  102. bool CIPRateLimit::CheckIP( netadr_t adr )
  103. {
  104. long curTime = (long)Plat_FloatTime();
  105. // check the per ip rate (do this first, so one person dosing doesn't add to the global max rate
  106. ip_t clientIP;
  107. memcpy( &clientIP, adr.ip, sizeof(ip_t) );
  108. int const MAX_TREE_SIZE = sv_max_queries_tracked_ips_max.GetInt();
  109. int const MAX_TREE_PRUNE = sv_max_queries_tracked_ips_prune.GetInt();
  110. // Prune some elements from the tree
  111. int numPruned = 0;
  112. for ( int32 itIPTime = m_IPTimes.FirstInorder(); ( itIPTime != m_IPTimes.InvalidIndex() ); )
  113. {
  114. int32 itIPTimeNext = m_IPTimes.NextInorder( itIPTime );
  115. ip_t ipTracked = m_IPTimes.Element( itIPTime );
  116. if ( ipTracked != clientIP )
  117. {
  118. if ( ( curTime - m_IPTimes.Key( itIPTime ) ) < sv_max_queries_window.GetFloat() )
  119. break; // need to still keep monitoring this IP address, time is in order so next ones are even more recent
  120. m_IPStorage.Remove( ipTracked );
  121. m_IPTimes.RemoveAt( itIPTime );
  122. ++ numPruned;
  123. if ( ( numPruned >= MAX_TREE_PRUNE ) && ( m_IPStorage.Count() < MAX_TREE_SIZE ) )
  124. break;
  125. }
  126. itIPTime = itIPTimeNext;
  127. }
  128. if ( m_IPStorage.Count() > MAX_TREE_SIZE )
  129. {
  130. // This looks like we are under distributed attack where we are seeing a
  131. // very large number of IP addresses in a short time period
  132. // Stop tracking individual IP addresses and turn on global rate limit
  133. Msg( "IP rate limit detected distributed packet load (%u buckets, %u global count).\n", m_IPStorage.Count(), m_iGlobalCount );
  134. Reset();
  135. m_iGlobalCount = MAX( 1, ( sv_max_queries_sec_global.GetFloat() + 1 ) * ( sv_max_queries_window.GetFloat() + 1 ) );
  136. m_lLastTime = curTime;
  137. m_lLastDistributedDetection = curTime;
  138. }
  139. // now find the entry and check if it's within our rate limits
  140. bool bPerIpLimitingPerformed = false;
  141. IPStorage_t::IndexType_t ipEntry = m_IPStorage.Find( clientIP );
  142. if ( m_IPStorage.IsValidIndex( ipEntry ) )
  143. {
  144. bPerIpLimitingPerformed = true;
  145. iprate_val &iprateval = m_IPStorage.Element( ipEntry );
  146. if ( ( curTime - iprateval.lastTime ) > sv_max_queries_window.GetFloat() )
  147. {
  148. float query_rate = static_cast< float >( iprateval.count ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero
  149. if ( query_rate > sv_max_queries_sec.GetFloat() )
  150. {
  151. if ( ( curTime - m_lLastPersonalDetection ) > sv_max_queries_window.GetFloat()/10 )
  152. {
  153. Msg( "IP rate limiting client %s sustained %u hits at %.1f pps (%u buckets, %u global count).\n", adr.ToString(), iprateval.count, query_rate, m_IPStorage.Count(), m_iGlobalCount );
  154. }
  155. }
  156. m_IPTimes.RemoveAt( iprateval.idxiptime );
  157. iprateval.idxiptime = m_IPTimes.Insert( curTime, clientIP );
  158. iprateval.lastTime = curTime;
  159. iprateval.count = 1;
  160. }
  161. else
  162. {
  163. ++ iprateval.count;
  164. float query_rate = static_cast< float >( iprateval.count ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero
  165. if ( query_rate > sv_max_queries_sec.GetFloat() )
  166. {
  167. if ( ( curTime - m_lLastPersonalDetection ) > sv_max_queries_window.GetFloat() )
  168. {
  169. m_lLastPersonalDetection = curTime;
  170. Msg( "IP rate limiting client %s at %u hits (%u buckets, %u global count).\n", adr.ToString(), iprateval.count, m_IPStorage.Count(), m_iGlobalCount );
  171. }
  172. return false;
  173. }
  174. }
  175. }
  176. // now check the global rate
  177. m_iGlobalCount++;
  178. if( (curTime - m_lLastTime) > sv_max_queries_window.GetFloat() )
  179. {
  180. float query_rate = static_cast< float >( m_iGlobalCount ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero
  181. if ( query_rate > sv_max_queries_sec_global.GetFloat() )
  182. {
  183. if ( ( curTime - m_lLastDistributedDetection ) > sv_max_queries_window.GetFloat()/10 )
  184. {
  185. Msg( "IP rate limit sustained %u distributed packets at %.1f pps (%u buckets).\n", m_iGlobalCount, query_rate, m_IPStorage.Count() );
  186. }
  187. }
  188. m_lLastTime = curTime;
  189. m_iGlobalCount = 1;
  190. }
  191. else
  192. {
  193. float query_rate = static_cast<float>( m_iGlobalCount ) / sv_max_queries_window.GetFloat(); // add one so the bottom is never zero
  194. if( query_rate > sv_max_queries_sec_global.GetFloat() )
  195. {
  196. if ( ( curTime - m_lLastDistributedDetection ) > sv_max_queries_window.GetFloat() )
  197. {
  198. m_lLastDistributedDetection = curTime;
  199. Msg( "IP rate limit under distributed packet load (%u buckets, %u global count), rejecting %s.\n", m_IPStorage.Count(), m_iGlobalCount, adr.ToString() );
  200. }
  201. return false;
  202. }
  203. }
  204. if ( !bPerIpLimitingPerformed )
  205. {
  206. iprate_val iprateval;
  207. iprateval.count = 1;
  208. iprateval.lastTime = curTime;
  209. // not found, insert this new guy
  210. iprateval.idxiptime = m_IPTimes.Insert( curTime, clientIP );
  211. m_IPStorage.Insert( clientIP, iprateval );
  212. }
  213. return true;
  214. }