Counter-Strike: Global Offensive Source Code
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

354 lines
14 KiB

  1. //====== Copyright (c), Valve Corporation, All rights reserved. =======
  2. //
  3. // Purpose: Provides access to SQL at a high level
  4. //
  5. //=============================================================================
  6. #ifndef SQLACCESS_H
  7. #define SQLACCESS_H
  8. #ifdef _WIN32
  9. #pragma once
  10. #endif
#include <string.h>
#include "gcsdk/gcsqlquery.h"
#include "tier0/memdbgon.h"
  13. namespace GCSDK
  14. {
  15. class CGCSQLQuery;
  16. class CGCSQLQueryGroup;
  17. class CColumnSet;
  18. class CRecordType;
  19. //-----------------------------------------------------------------------------
  20. // Purpose: Provides access to SQL at a high level
  21. //-----------------------------------------------------------------------------
  22. class CSQLAccess
  23. {
  24. public:
  25. CSQLAccess( ESchemaCatalog eSchemaCatalog = k_ESchemaCatalogMain );
  26. ~CSQLAccess( );
  27. bool BBeginTransaction( const char *pchName );
  28. bool BCommitTransaction( bool bAllowEmpty = false );
  29. void RollbackTransaction();
  30. bool BInTransaction( ) const { return m_bInTransaction; }
  31. bool BYieldingExecute( const char *pchName, const char *pchSQLCommand, uint32 *pcRowsAffected = NULL, bool bSpewOnError = true );
  32. bool BYieldingExecuteString( const char *pchName, const char *pchSQLCommand, CFmtStr1024 *psResult, uint32 *pcRowsAffected = NULL );
  33. bool BYieldingExecuteScalarInt( const char *pchName, const char *pchSQLCommand, int *pnResult, uint32 *pcRowsAffected = NULL );
  34. bool BYieldingExecuteScalarIntWithDefault( const char *pchName, const char *pchSQLCommand, int *pnResult, int iDefaultValue, uint32 *pcRowsAffected = NULL );
  35. bool BYieldingExecuteScalarUint32( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 *pcRowsAffected = NULL );
  36. bool BYieldingExecuteScalarUint32WithDefault( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 unDefaultValue, uint32 *pcRowsAffected = NULL );
  37. bool BYieldingWipeTable( int iTable );
  38. template <typename TReturn, typename TCast>
  39. bool BYieldingExecuteSingleResult( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, TReturn *pResult, uint32 *pcRowsAffected );
  40. template <typename TReturn, typename TCast>
  41. bool BYieldingExecuteSingleResultWithDefault( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, TReturn *pResult, TReturn defaultValue, uint32 *pcRowsAffected );
  42. // manipulating CRecordBase (i.e. CSch...) objects in the database
  43. bool BYieldingInsertRecord( CRecordBase *pRecordBase );
  44. bool BYieldingInsertWithIdentity( CRecordBase* pRecordBase ) ;
  45. bool BYieldingReadRecordWithWhereColumns( CRecordBase *pRecord, const CColumnSet & readSet, const CColumnSet & whereSet );
  46. template< typename SchClass_t >
  47. bool BYieldingReadRecordFromPK( SchClass_t *pRecord );
  48. template< typename SchClass_t>
  49. bool BYieldingReadMultipleRecordsWithWhereColumns( CUtlVector< SchClass_t > *pvecRecords, const CColumnSet & whereSet, CUtlVector< SchClass_t > *pvecUnmatchedRecords = NULL );
  50. template< typename SchClass_t>
  51. bool BYieldingReadMultipleRecordsWithWhereColumns( CUtlVector< SchClass_t > *pvecRecords, const CColumnSet & readSet, const CColumnSet & whereSet, CUtlVector< SchClass_t > *pvecUnmatchedRecords = NULL );
  52. template< typename SchClass_t>
  53. bool BYieldingReadRecordsWithWhereClause( CUtlVector< SchClass_t > *pvecRecords, const char *pchWhereClause, const CColumnSet & readSet, const char *pchTopClause = NULL );
  54. template< typename SchClass_t>
  55. bool BYieldingReadRecordsWithQuery( CUtlVector< SchClass_t > *pvecRecords, const char *sQuery, const CColumnSet & readSet );
  56. bool BYieldingUpdateRecord( const CRecordBase &record, const CColumnSet & whereColumns, const CColumnSet & updateColumns );
  57. bool BYieldingDeleteRecord( const CRecordBase & record, const CColumnSet & whereColumns );
  58. void AddRecordParameters( const CRecordBase &record, const CColumnSet & columnSet );
  59. void AddBindParam( const char *pchValue );
  60. void AddBindParam( const int16 nValue );
  61. void AddBindParam( const uint16 uValue );
  62. void AddBindParam( const int32 nValue );
  63. void AddBindParam( const uint32 uValue );
  64. void AddBindParam( const uint64 ulValue );
  65. void AddBindParam( const uint8 *ubValue, const int cubValue );
  66. void AddBindParam( const float fValue );
  67. void AddBindParam( const double dValue );
  68. void ClearParams();
  69. IGCSQLResultSetList *GetResults();
  70. uint32 GetResultSetCount();
  71. uint32 GetResultSetRowCount( uint32 unResultSet );
  72. CSQLRecord GetResultRecord( uint32 unResultSet, uint32 unRow );
  73. private:
  74. enum EReadSingleResultResult
  75. {
  76. eReadSingle_Error, // something went wrong in the DB or the data was in a format we didn't expect
  77. eReadSingle_ResultFound, // we found a single result and copied the value -- all is well!
  78. eReadSingle_UseDefault, // we didn't find any results but we specified a value in advance for this case
  79. };
  80. EReadSingleResultResult BYieldingExecuteSingleResultDataInternal( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, uint8 **pubData, uint32 *punSize, uint32 *pcRowsAffected, bool bHasDefaultValue );
  81. private:
  82. CGCSQLQuery *CurrentQuery();
  83. ESchemaCatalog m_eSchemaCatalog;
  84. CGCSQLQuery *m_pCurrentQuery;
  85. CGCSQLQueryGroup *m_pQueryGroup;
  86. bool m_bInTransaction;
  87. };
  88. #define FOR_EACH_SQL_RESULT( sqlAccess, resultSet, record ) \
  89. for( CSQLRecord record = (sqlAccess).GetResultRecord( resultSet, 0 ); record.IsValid(); record.NextRow() )
  90. //-----------------------------------------------------------------------------
  91. // Purpose: templatized version of querying for a single value
  92. //-----------------------------------------------------------------------------
  93. template <typename TReturn, typename TCast>
  94. bool CSQLAccess::BYieldingExecuteSingleResult( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, TReturn *pResult, uint32 *pcRowsAffected )
  95. {
  96. uint8 *pubData;
  97. uint32 cubData;
  98. if( CSQLAccess::BYieldingExecuteSingleResultDataInternal( pchName, pchSQLCommand, eType, &pubData, &cubData, pcRowsAffected, false ) != eReadSingle_ResultFound )
  99. return false;
  100. *pResult = *( (TCast *)pubData );
  101. return true;
  102. }
  103. //-----------------------------------------------------------------------------
  104. // Purpose: templatized version of querying for a single value
  105. //-----------------------------------------------------------------------------
  106. template <typename TReturn, typename TCast>
  107. bool CSQLAccess::BYieldingExecuteSingleResultWithDefault( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, TReturn *pResult, TReturn defaultValue, uint32 *pcRowsAffected )
  108. {
  109. uint8 *pubData;
  110. uint32 cubData;
  111. EReadSingleResultResult eResult = CSQLAccess::BYieldingExecuteSingleResultDataInternal( pchName, pchSQLCommand, eType, &pubData, &cubData, pcRowsAffected, true );
  112. if ( eResult == eReadSingle_Error )
  113. return false;
  114. if ( eResult == eReadSingle_ResultFound )
  115. {
  116. *pResult = *( (TCast *)pubData );
  117. }
  118. else
  119. {
  120. Assert( eResult == eReadSingle_UseDefault );
  121. *pResult = defaultValue;
  122. }
  123. return true;
  124. }
  125. //-----------------------------------------------------------------------------
  126. // Purpose: Reads the record with a given PK.
  127. // Input: pRecordBase - record to read
  128. // Output: true if successful, false otherwise
  129. //-----------------------------------------------------------------------------
  130. template < typename SchClass_t >
  131. bool CSQLAccess::BYieldingReadRecordFromPK( SchClass_t *pRecord )
  132. {
  133. CColumnSet csetWhere = CColumnSet::PrimaryKey< SchClass_t >();
  134. CColumnSet csetRead = CColumnSet::Inverse( csetWhere );
  135. return BYieldingReadRecordWithWhereColumns( pRecord, csetRead, csetWhere );
  136. }
  137. //-----------------------------------------------------------------------------
  138. // Purpose: Reads multiple records from the database based on the where columns
  139. // filled in for each record. If the record is not found in the database
  140. // it will be removed from pvecRecords. If pvecUnmatchedRecords is
  141. // provided, it will be populated with the unmatched records removed
  142. // from pvecRecords
  143. // Input: pvecRecords - The records to fill in from the database
  144. // whereSet - The set of columns to query on
  145. // (optional) pvecUnmatchedRecords - A vector to hold records which
  146. // are not found in the database
  147. // Output: true if successful, false otherwise
  148. //-----------------------------------------------------------------------------
  149. template< typename SchClass_t>
  150. bool CSQLAccess::BYieldingReadMultipleRecordsWithWhereColumns( CUtlVector< SchClass_t > *pvecRecords,
  151. const CColumnSet & whereSet,
  152. CUtlVector< SchClass_t > *pvecUnmatchedRecords /* = NULL */ )
  153. {
  154. CColumnSet readSet( GSchemaFull().GetSchema( SchClass_t::k_iTable ).GetRecordInfo() );
  155. readSet.MakeInverse( whereSet );
  156. return BYieldingReadMultipleRecordsWithWhereColumns( pvecRecords, readSet, whereSet, pvecUnmatchedRecords );
  157. }
  158. //-----------------------------------------------------------------------------
  159. // Purpose: Reads multiple records from the database based on the where columns
  160. // filled in for each record. If the record is not found in the database
  161. // it will be removed from pvecRecords. If pvecUnmatchedRecords is
  162. // provided, it will be populated with the unmatched records removed
  163. // from pvecRecords
  164. // Input: pvecRecords - The records to fill in from the database
  165. // readSet - The set of columns to fill in
  166. // whereSet - The set of columns to query on
  167. // (optional) pvecUnmatchedRecords - A vector to hold records which
  168. // are not found in the database
  169. // Output: true if successful, false otherwise
  170. //-----------------------------------------------------------------------------
  171. template< typename SchClass_t>
  172. bool CSQLAccess::BYieldingReadMultipleRecordsWithWhereColumns( CUtlVector< SchClass_t > *pvecRecords,
  173. const CColumnSet & readSet,
  174. const CColumnSet & whereSet,
  175. CUtlVector< SchClass_t > *pvecUnmatchedRecords /* = NULL */ )
  176. {
  177. AssertMsg( !BInTransaction(), "BYieldingReadMultipleRecordsWithWhereColumns is not supported in a transaction" );
  178. if( BInTransaction() )
  179. return false;
  180. Assert( !readSet.IsEmpty() );
  181. if ( readSet.IsEmpty() )
  182. return false;
  183. if ( pvecUnmatchedRecords )
  184. {
  185. pvecUnmatchedRecords->RemoveAll();
  186. }
  187. // Build the query we'll use for each record
  188. CFmtStr1024 sStatement, sWhere;
  189. BuildSelectStatementText( &sStatement, readSet );
  190. BuildWhereClauseText( &sWhere, whereSet );
  191. sStatement.Append( " WHERE " );
  192. sStatement.Append( sWhere );
  193. BBeginTransaction( CFmtStr1024( "BYieldingReadMultipleRecordsWithWhereColumns() - %s", sStatement.Access() ) );
  194. // Batch this query for each record
  195. FOR_EACH_VEC( *pvecRecords, i )
  196. {
  197. AddRecordParameters( pvecRecords->Element( i ), whereSet );
  198. if( !BYieldingExecute( NULL, sStatement ) )
  199. return false;
  200. }
  201. // Actually run the query
  202. if ( !BCommitTransaction() )
  203. return false;
  204. Assert( GetResultSetCount() == (uint32)pvecRecords->Count() );
  205. if ( GetResultSetCount() != (uint32)pvecRecords->Count() )
  206. return false;
  207. // Get the results. Reading backwards because if a record doesn't find a match we'll
  208. // remove it from the list
  209. FOR_EACH_VEC_BACK( *pvecRecords, i )
  210. {
  211. // make sure the types are the same
  212. IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( i );
  213. Assert( pResultSet->GetRowCount() <= 1 );
  214. if ( pResultSet->GetRowCount() > 1 )
  215. return false;
  216. if( pResultSet->GetRowCount() == 1 )
  217. {
  218. // We have a record in this set, read it in
  219. FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex )
  220. {
  221. EGCSQLType eRecordType = readSet.GetColumnInfo( nColumnIndex ).GetType();
  222. EGCSQLType eResultType = pResultSet->GetColumnType( nColumnIndex );
  223. Assert( eResultType == eRecordType );
  224. if( eRecordType != eResultType )
  225. return false;
  226. }
  227. CSQLRecord sqlRecord = GetResultRecord( i, 0 );
  228. FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex )
  229. {
  230. uint8 *pubData;
  231. uint32 cubData;
  232. DbgVerify( sqlRecord.BGetColumnData( nColumnIndex, &pubData, (int*)&cubData ) );
  233. DbgVerify( pvecRecords->Element( i ).BSetField( readSet.GetColumn( nColumnIndex ), pubData, cubData ) );
  234. }
  235. }
  236. else
  237. {
  238. // This record did not match, remove it and add it to pvecUnmatchedRecords if needed
  239. if ( pvecUnmatchedRecords )
  240. {
  241. pvecUnmatchedRecords->AddToTail( pvecRecords->Element( i ) );
  242. }
  243. pvecRecords->Remove( i );
  244. }
  245. }
  246. return true;
  247. }
  248. //-----------------------------------------------------------------------------
  249. // Purpose: Reads a list of records from the DB according to the specified where
  250. // clause
  251. // Input: pRecordBase - record to insert
  252. // Output: true if successful, false otherwise
  253. //-----------------------------------------------------------------------------
  254. template< typename SchClass_t>
  255. bool CSQLAccess::BYieldingReadRecordsWithWhereClause( CUtlVector< SchClass_t > *pvecRecords, const char *pchWhereClause, const CColumnSet & readSet, const char *pchTopClause )
  256. {
  257. AssertMsg( !BInTransaction(), "BYieldingReadRecordsWithWhereClause is not supported in a transaction" );
  258. if( BInTransaction() )
  259. return false;
  260. Assert( !readSet.IsEmpty() );
  261. CFmtStr1024 sStatement;
  262. BuildSelectStatementText( &sStatement, readSet, pchTopClause );
  263. Assert( pchWhereClause && *pchWhereClause );
  264. if( !pchWhereClause || !(*pchWhereClause) )
  265. return false;
  266. CUtlString sFullStatement = sStatement.Access();
  267. sFullStatement += " WHERE ";
  268. sFullStatement += pchWhereClause;
  269. return BYieldingReadRecordsWithQuery< SchClass_t >( pvecRecords, sFullStatement, readSet );
  270. }
  271. //-----------------------------------------------------------------------------
  272. // Purpose: Inserts a new record into the DB and reads non-insertable fields back
  273. // into the record.
  274. // Input: pRecordBase - record to insert
  275. // Output: true if successful, false otherwise
  276. //-----------------------------------------------------------------------------
  277. template< typename SchClass_t>
  278. bool CSQLAccess::BYieldingReadRecordsWithQuery( CUtlVector< SchClass_t > *pvecRecords, const char *sQuery, const CColumnSet & readSet )
  279. {
  280. AssertMsg( !BInTransaction(), "BYieldingReadRecordsWithQuery is not supported in a transaction" );
  281. if( BInTransaction() )
  282. return false;
  283. Assert(!readSet.IsEmpty() );
  284. if( !BYieldingExecute( NULL, sQuery ) )
  285. return false;
  286. Assert( GetResultSetCount() == 1 );
  287. if ( GetResultSetCount() != 1 )
  288. return false;
  289. // make sure the types are the same
  290. IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 );
  291. return CopyResultToSchVector( pResultSet, readSet, pvecRecords );
  292. }
  293. } // namespace GCSDK
  294. #include "tier0/memdbgoff.h"
  295. #endif // SQLACCESS_H