Team Fortress 2 Source Code as on 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

953 lines
33 KiB

  1. //========= Copyright Valve Corporation, All rights reserved. ============//
  2. //
  3. // Purpose: Provides access to SQL at a high level
  4. //
  5. //=============================================================================
  6. #include "stdafx.h"
  7. #include "gcsdk/sqlaccess/sqlaccess.h"
  8. #include "gcsdk/gcsqlquery.h"
  9. // memdbgon must be the last include file in a .cpp file!!!
  10. #include "tier0/memdbgon.h"
  11. template< typename LISTENER_FUNC >
  12. static void RunAndClearListenerList( std::vector< LISTENER_FUNC > &vecListeners )
  13. {
  14. // Let us not underestimate the ability of random listeners to re-enter everything.
  15. std::vector< LISTENER_FUNC > listenerCopy;
  16. listenerCopy.swap( vecListeners );
  17. vecListeners.clear();
  18. // Why would you consider such a thing
  19. DO_NOT_YIELD_THIS_SCOPE();
  20. for ( const auto &listener : listenerCopy )
  21. {
  22. listener();
  23. }
  24. }
  25. namespace GCSDK
  26. {
  27. //------------------------------------------------------------------------------------
  28. // Purpose: Constructor
  29. //------------------------------------------------------------------------------------
  30. CSQLAccess::CSQLAccess( ESchemaCatalog eSchemaCatalog )
  31. : m_eSchemaCatalog( eSchemaCatalog)
  32. , m_pCurrentQuery( NULL )
  33. , m_bInTransaction( false )
  34. {
  35. m_pQueryGroup = CGCSQLQueryGroup::Alloc();
  36. }
  37. //------------------------------------------------------------------------------------
  38. // Purpose: Destructor
  39. //------------------------------------------------------------------------------------
  40. CSQLAccess::~CSQLAccess( )
  41. {
  42. SAFE_RELEASE( m_pQueryGroup );
  43. Assert( !m_pCurrentQuery );
  44. SAFE_DELETE( m_pCurrentQuery );
  45. AssertMsg( !m_bInTransaction, "GCSDK::CSQLAccess object being destroyed with a transaction pending. Use BCommitTransaction or RollbackTransaction to match your BBeginTransaction call." );
  46. }
  47. //------------------------------------------------------------------------------------
  48. // Purpose: Perform a query
  49. //------------------------------------------------------------------------------------
  50. bool CSQLAccess::BYieldingExecute( const char *pchName, const char *pchSQLCommand, uint32 *pcRowsAffected, bool bSpewOnError )
  51. {
  52. if ( NULL == pchName )
  53. {
  54. pchName = pchSQLCommand;
  55. }
  56. bool bStandalone = !BInTransaction();
  57. if( bStandalone )
  58. {
  59. BBeginTransaction( pchName );
  60. }
  61. CurrentQuery()->SetCommand( pchSQLCommand );
  62. m_pQueryGroup->AddQuery( m_pCurrentQuery );
  63. m_pCurrentQuery = NULL;
  64. bool bSuccess = true;
  65. if( bStandalone )
  66. {
  67. bSuccess = BCommitTransaction();
  68. if( bSuccess && pcRowsAffected )
  69. {
  70. *pcRowsAffected = m_pQueryGroup->GetResults()->GetRowsAffected( 0 );
  71. }
  72. }
  73. return bSuccess;
  74. }
  75. //------------------------------------------------------------------------------------
  76. // Purpose: Starts a transaction
  77. //------------------------------------------------------------------------------------
  78. bool CSQLAccess::BBeginTransaction( const char *pchName )
  79. {
  80. Assert( !m_bInTransaction );
  81. if( m_bInTransaction )
  82. return false;
  83. m_pQueryGroup->Clear();
  84. m_pQueryGroup->SetName( pchName );
  85. m_bInTransaction = true;
  86. return true;
  87. }
  88. //------------------------------------------------------------------------------------
  89. // Purpose: Returns the string last passed to BBeginTransaction
  90. //------------------------------------------------------------------------------------
  91. const char *CSQLAccess::PchTransactionName( ) const
  92. {
  93. return m_pQueryGroup->PchName();
  94. }
  95. //------------------------------------------------------------------------------------
  96. // Purpose: Commits a transaction to the database
  97. //------------------------------------------------------------------------------------
  98. bool CSQLAccess::BCommitTransaction( bool bAllowEmpty )
  99. {
  100. Assert( BInTransaction() );
  101. if( !BInTransaction() )
  102. return false;
  103. if( !m_pCurrentQuery && !m_pQueryGroup->GetStatementCount() )
  104. {
  105. if( bAllowEmpty )
  106. {
  107. // No-op success
  108. m_bInTransaction = false;
  109. RunListeners_Commit();
  110. return true;
  111. }
  112. else
  113. {
  114. AssertMsg1( false, "BCommitTransaction with empty transaction at %s", m_pQueryGroup->PchName() );
  115. return false;
  116. }
  117. }
  118. AssertMsg1( !m_pCurrentQuery, "Unexecuted query present in BCommitTransaction: %s", m_pCurrentQuery->PchCommand() );
  119. if( m_pCurrentQuery )
  120. return false;
  121. m_bInTransaction = false;
  122. if( !GJobCur().BYieldingRunQuery( m_pQueryGroup, m_eSchemaCatalog ) )
  123. {
  124. // Notify listeners that the transaction did not succeed
  125. RunListeners_Rollback();
  126. return false;
  127. }
  128. // The transaction presumably did make the database, so we do not notify rollback listeners beyond here.
  129. RunListeners_Commit();
  130. if( !m_pQueryGroup->GetResults() )
  131. return false;
  132. return true;
  133. }
  134. //------------------------------------------------------------------------------------
  135. // Purpose: Rolls back a transaction and clears any queries
  136. //------------------------------------------------------------------------------------
  137. void CSQLAccess::RollbackTransaction()
  138. {
  139. bool bWasTransaction = BInTransaction();
  140. Assert( bWasTransaction );
  141. SAFE_DELETE( m_pCurrentQuery );
  142. m_bInTransaction = false;
  143. if ( bWasTransaction )
  144. {
  145. RunListeners_Rollback();
  146. }
  147. else
  148. {
  149. m_vecCommitListeners.clear();
  150. m_vecRollbackListeners.clear();
  151. }
  152. }
  153. //------------------------------------------------------------------------------------
  154. // Purpose: Adds a listener to be called synchronously should the transaction successfully commit
  155. //------------------------------------------------------------------------------------
  156. void CSQLAccess::AddCommitListener( std::function<void (void)> &&listener )
  157. {
  158. if ( !BInTransaction() )
  159. {
  160. AssertMsg( BInTransaction(), "Adding a listener to a non-transaction access, will never fire" );
  161. return;
  162. }
  163. m_vecCommitListeners.push_back( std::move( listener ) );
  164. }
  165. //------------------------------------------------------------------------------------
  166. // Purpose: Adds a listener to be called synchronously should the transaction fail or explicitly rollback
  167. //------------------------------------------------------------------------------------
  168. void CSQLAccess::AddRollbackListener( std::function<void (void)> &&listener )
  169. {
  170. if ( !BInTransaction() )
  171. {
  172. AssertMsg( BInTransaction(), "Adding a listener to a non-transaction access, will never fire" );
  173. return;
  174. }
  175. m_vecRollbackListeners.push_back( std::move( listener ) );
  176. }
  177. //------------------------------------------------------------------------------------
  178. // Purpose: Notifies listeners of successful commit.
  179. //------------------------------------------------------------------------------------
  180. void CSQLAccess::RunListeners_Commit()
  181. {
  182. RunAndClearListenerList( m_vecCommitListeners );
  183. // Clear the unused set
  184. m_vecRollbackListeners.clear();
  185. }
  186. //------------------------------------------------------------------------------------
  187. // Purpose: Notifies listeners of a implicitly or explicitly rolled back transactions and clears the listener list.
  188. //------------------------------------------------------------------------------------
  189. void CSQLAccess::RunListeners_Rollback()
  190. {
  191. RunAndClearListenerList( m_vecRollbackListeners );
  192. // Clear the unused set
  193. m_vecCommitListeners.clear();
  194. }
  195. //------------------------------------------------------------------------------------
  196. // Purpose: Perform a query that returns a single string
  197. //------------------------------------------------------------------------------------
  198. CSQLAccess::EReadSingleResultResult CSQLAccess::BYieldingExecuteSingleResultDataInternal( const char *pchName, const char *pchSQLCommand, EGCSQLType eType, uint8 **ppubData, uint32 *punSize, uint32 *pcRowsAffected, bool bHasDefaultValue )
  199. {
  200. AssertMsg( !BInTransaction(), "BYieldingExecuteSingleResultData is not supported in a transaction" );
  201. if( BInTransaction() )
  202. return eReadSingle_Error;
  203. bool bRet = BYieldingExecute( pchName, pchSQLCommand, pcRowsAffected );
  204. if ( !bRet )
  205. return eReadSingle_Error;
  206. if( m_pQueryGroup->GetResults()->GetResultSetCount() != 1 )
  207. {
  208. AssertMsg1( false, "Expected single result set, found %d", m_pQueryGroup->GetResults()->GetResultSetCount() );
  209. return eReadSingle_Error;
  210. }
  211. IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 );
  212. // If we have a default value, getting back zero rows is acceptable.
  213. if( pResultSet->GetRowCount() == 0 && bHasDefaultValue )
  214. {
  215. return eReadSingle_UseDefault;
  216. }
  217. // If we either have more than one row or no default value specified, that's an error.
  218. if( pResultSet->GetRowCount() != 1 )
  219. {
  220. AssertMsg1( false, "Expected single result, found %d", pResultSet->GetRowCount() );
  221. return eReadSingle_Error;
  222. }
  223. if( pResultSet->GetColumnCount() != 1 )
  224. {
  225. AssertMsg1( false, "Expected single column, found %d", pResultSet->GetColumnCount() );
  226. return eReadSingle_Error;
  227. }
  228. if( pResultSet->GetColumnType( 0 ) != eType )
  229. {
  230. AssertMsg2( false, "Expected column of type %s, found %s", PchNameFromEGCSQLType( eType ), PchNameFromEGCSQLType( pResultSet->GetColumnType( 0 ) ) );
  231. return eReadSingle_Error;
  232. }
  233. return pResultSet->GetData( 0, 0, ppubData, punSize )
  234. ? eReadSingle_ResultFound
  235. : eReadSingle_Error;
  236. }
  237. //------------------------------------------------------------------------------------
  238. // Purpose: Perform a query that returns a single string
  239. //------------------------------------------------------------------------------------
  240. bool CSQLAccess::BYieldingExecuteString( const char *pchName, const char *pchSQLCommand, CFmtStr1024 *psResult, uint32 *pcRowsAffected )
  241. {
  242. uint8 *pubData;
  243. uint32 cubData;
  244. if( CSQLAccess::BYieldingExecuteSingleResultDataInternal( pchName, pchSQLCommand, k_EGCSQLType_String, &pubData, &cubData, pcRowsAffected, false ) != eReadSingle_ResultFound )
  245. return false;
  246. *psResult = (char *)pubData;
  247. return true;
  248. }
  249. //------------------------------------------------------------------------------------
  250. // Purpose: Perform a query that returns a single int
  251. //------------------------------------------------------------------------------------
  252. bool CSQLAccess::BYieldingExecuteScalarInt( const char *pchName, const char *pchSQLCommand, int *pnResult, uint32 *pcRowsAffected )
  253. {
  254. return BYieldingExecuteSingleResult<int32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, pnResult, pcRowsAffected );
  255. }
  256. bool CSQLAccess::BYieldingExecuteScalarIntWithDefault( const char *pchName, const char *pchSQLCommand, int *pnResult, int iDefaultValue, uint32 *pcRowsAffected )
  257. {
  258. return BYieldingExecuteSingleResultWithDefault<int32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, pnResult, iDefaultValue, pcRowsAffected );
  259. }
  260. //------------------------------------------------------------------------------------
  261. // Purpose: Perform a query that returns a single uint32
  262. //------------------------------------------------------------------------------------
  263. bool CSQLAccess::BYieldingExecuteScalarUint32( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 *pcRowsAffected )
  264. {
  265. return BYieldingExecuteSingleResult<uint32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, punResult, pcRowsAffected );
  266. }
  267. bool CSQLAccess::BYieldingExecuteScalarUint32WithDefault( const char *pchName, const char *pchSQLCommand, uint32 *punResult, uint32 unDefaultValue, uint32 *pcRowsAffected )
  268. {
  269. return BYieldingExecuteSingleResultWithDefault<uint32, uint32>( pchName, pchSQLCommand, k_EGCSQLType_int32, punResult, unDefaultValue, pcRowsAffected );
  270. }
  271. //------------------------------------------------------------------------------------
  272. // Purpose: A bunch of pass throughs to the query itself
  273. //------------------------------------------------------------------------------------
  274. void CSQLAccess::AddBindParam( const char *pchValue )
  275. {
  276. CurrentQuery()->AddBindParam( pchValue );
  277. }
  278. void CSQLAccess::AddBindParam( const int16 nValue )
  279. {
  280. CurrentQuery()->AddBindParam( nValue );
  281. }
  282. void CSQLAccess::AddBindParam( const uint16 uValue )
  283. {
  284. CurrentQuery()->AddBindParam( uValue );
  285. }
  286. void CSQLAccess::AddBindParam( const int32 nValue )
  287. {
  288. CurrentQuery()->AddBindParam( nValue );
  289. }
  290. void CSQLAccess::AddBindParam( const uint32 uValue )
  291. {
  292. CurrentQuery()->AddBindParam( uValue );
  293. }
  294. void CSQLAccess::AddBindParam( const uint64 ulValue )
  295. {
  296. CurrentQuery()->AddBindParam( ulValue );
  297. }
  298. void CSQLAccess::AddBindParam( const uint8 *ubValue, const int cubValue )
  299. {
  300. CurrentQuery()->AddBindParam( ubValue, cubValue );
  301. }
  302. void CSQLAccess::AddBindParam( const float fValue )
  303. {
  304. CurrentQuery()->AddBindParam( fValue );
  305. }
  306. void CSQLAccess::AddBindParam( const double dValue )
  307. {
  308. CurrentQuery()->AddBindParam( dValue );
  309. }
  310. void CSQLAccess::AddBindParamRaw( EGCSQLType eType, const byte *pubData, uint32 cubData )
  311. {
  312. CurrentQuery()->AddBindParamRaw( eType, pubData, cubData );
  313. }
  314. void CSQLAccess::ClearParams()
  315. {
  316. if( m_pCurrentQuery )
  317. {
  318. delete m_pCurrentQuery;
  319. m_pCurrentQuery = NULL;
  320. }
  321. }
  322. IGCSQLResultSetList *CSQLAccess::GetResults()
  323. {
  324. return m_pQueryGroup->GetResults();
  325. }
  326. //------------------------------------------------------------------------------------
  327. // Purpose: Returns the number of result sets
  328. //------------------------------------------------------------------------------------
  329. uint32 CSQLAccess::GetResultSetCount()
  330. {
  331. if( m_pQueryGroup->GetResults() )
  332. return m_pQueryGroup->GetResults()->GetResultSetCount();
  333. else
  334. return 0;
  335. }
  336. //------------------------------------------------------------------------------------
  337. // Purpose: Returns the number of rows in a result set
  338. //------------------------------------------------------------------------------------
  339. uint32 CSQLAccess::GetResultSetRowCount( uint32 unResultSet )
  340. {
  341. if( m_pQueryGroup->GetResults() && unResultSet < m_pQueryGroup->GetResults()->GetResultSetCount() )
  342. return m_pQueryGroup->GetResults()->GetResultSet( unResultSet )->GetRowCount();
  343. else
  344. return 0;
  345. }
  346. //------------------------------------------------------------------------------------
  347. // Purpose: Returns a CSQLRecord object that represents a row in a result set
  348. //------------------------------------------------------------------------------------
  349. CSQLRecord CSQLAccess::GetResultRecord( uint32 unResultSet, uint32 unRow )
  350. {
  351. if( m_pQueryGroup->GetResults() && unResultSet < m_pQueryGroup->GetResults()->GetResultSetCount() )
  352. {
  353. IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( unResultSet );
  354. if( unRow < pResultSet->GetRowCount() )
  355. return CSQLRecord( unRow, pResultSet );
  356. }
  357. return CSQLRecord(); // if there was a problem return an empty record
  358. }
  359. //-----------------------------------------------------------------------------
  360. // Purpose: Inserts a new record into the DS
  361. // Input: pRecordBase - record to insert
  362. // Output: true if successful, false otherwise
  363. //-----------------------------------------------------------------------------
  364. bool CSQLAccess::BYieldingInsertRecord( const CRecordBase *pRecordBase )
  365. {
  366. ClearParams();
  367. const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo();
  368. int cColumns = pRecordInfo->GetNumColumns();
  369. for ( int nColumn = 0; nColumn < cColumns; nColumn++ )
  370. {
  371. const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn );
  372. if ( !columnInfo.BIsInsertable() )
  373. continue;
  374. uint8 *pubData;
  375. uint32 cubData;
  376. DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) );
  377. CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData );
  378. }
  379. uint32 nRows;
  380. const char *pchStatement = pRecordBase->GetPSchema()->GetInsertStatementText();
  381. bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows );
  382. return ( nRows == 1 || BInTransaction() ) && bRet;
  383. }
  384. //-----------------------------------------------------------------------------
  385. // Purpose: Inserts a new record into the DS if such row doesn't exist
  386. // Input: pRecordBase - record to insert
  387. // Output: true if successful, false otherwise
  388. //-----------------------------------------------------------------------------
  389. bool CSQLAccess::BYieldingInsertWhenNotMatchedOnPK( CRecordBase *pRecordBase )
  390. {
  391. ClearParams();
  392. const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo();
  393. int cColumns = pRecordInfo->GetNumColumns();
  394. for ( int nColumn = 0; nColumn < cColumns; nColumn++ )
  395. {
  396. const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn );
  397. if ( !columnInfo.BIsInsertable() )
  398. {
  399. Assert( columnInfo.BIsInsertable() );
  400. return false;
  401. }
  402. uint8 *pubData;
  403. uint32 cubData;
  404. DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) );
  405. CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData );
  406. }
  407. uint32 nRows;
  408. const char *pchStatement = pRecordBase->GetPSchema()->GetMergeStatementTextOnPKWhenNotMatchedInsert();
  409. bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows );
  410. return ( nRows == 1 || nRows == 0 || BInTransaction() ) && bRet;
  411. }
  412. //-----------------------------------------------------------------------------
  413. // Purpose: Inserts a new record into the DS if such row doesn't exist
  414. // updates an existing row if such row is matched by PK
  415. // Input: pRecordBase - record to insert
  416. // Output: true if successful, false otherwise
  417. //-----------------------------------------------------------------------------
  418. bool CSQLAccess::BYieldingInsertOrUpdateOnPK( CRecordBase *pRecordBase )
  419. {
  420. ClearParams();
  421. const CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo();
  422. int cColumns = pRecordInfo->GetNumColumns();
  423. for ( int nColumn = 0; nColumn < cColumns; nColumn++ )
  424. {
  425. const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn );
  426. if ( !columnInfo.BIsInsertable() )
  427. {
  428. Assert( columnInfo.BIsInsertable() );
  429. return false;
  430. }
  431. uint8 *pubData;
  432. uint32 cubData;
  433. DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) );
  434. CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData );
  435. }
  436. uint32 nRows;
  437. const char *pchStatement = pRecordBase->GetPSchema()->GetMergeStatementTextOnPKWhenMatchedUpdateWhenNotMatchedInsert();
  438. bool bRet = BYieldingExecute( pchStatement, pchStatement, &nRows );
  439. return ( nRows == 1 || BInTransaction() ) && bRet;
  440. }
  441. //-----------------------------------------------------------------------------
  442. // Purpose: Inserts a new record into the DB and reads non-insertable fields back
  443. // into the record.
  444. // Input: pRecordBase - record to insert
  445. // Output: true if successful, false otherwise
  446. //-----------------------------------------------------------------------------
  447. bool CSQLAccess::BYieldingInsertWithIdentity( CRecordBase* pRecordBase )
  448. {
  449. AssertMsg( !BInTransaction(), "BYieldingInsertWithIdentity is not supported in a transaction" );
  450. if( BInTransaction() )
  451. return false;
  452. ClearParams();
  453. TSQLCmdStr sStatement;
  454. CUtlVector<int> vecOutputFields;
  455. CRecordInfo *pRecordInfo = pRecordBase->GetPSchema()->GetRecordInfo();
  456. BuildInsertAndReadStatementText( &sStatement, &vecOutputFields, pRecordInfo );
  457. AssertMsg( vecOutputFields.Count() > 0, "BYieldingInsertAndReadRecord called for a record type with no non-insertable columns" );
  458. if ( vecOutputFields.Count() == 0 )
  459. return false;
  460. int cColumns = pRecordInfo->GetNumColumns();
  461. for ( int nColumn = 0; nColumn < cColumns; nColumn++ )
  462. {
  463. const CColumnInfo &columnInfo = pRecordInfo->GetColumnInfo( nColumn );
  464. if ( !columnInfo.BIsInsertable() )
  465. {
  466. continue;
  467. }
  468. uint8 *pubData;
  469. uint32 cubData;
  470. DbgVerify( pRecordBase->BGetField( nColumn, &pubData, &cubData ) );
  471. CurrentQuery()->AddBindParamRaw( columnInfo.GetType(), pubData, cubData );
  472. }
  473. bool bRet = BYieldingExecute( sStatement, sStatement );
  474. if( !bRet )
  475. return false;
  476. Assert( 1 == GetResultSetCount() );
  477. if ( 1 != GetResultSetCount() )
  478. return false;
  479. IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 );
  480. Assert( 1 == pResultSet->GetRowCount() );
  481. if ( 1 != pResultSet->GetRowCount() )
  482. return false;
  483. Assert( (uint32)vecOutputFields.Count() == pResultSet->GetColumnCount() );
  484. if ( (uint32)vecOutputFields.Count() != pResultSet->GetColumnCount() )
  485. return false;
  486. for( uint32 nColumn = 0; nColumn < pResultSet->GetColumnCount(); nColumn++ )
  487. {
  488. uint8 *pubData;
  489. uint32 cubData;
  490. DbgVerify( pResultSet->GetData( 0, nColumn, &pubData, &cubData ) );
  491. int nSchColumn = vecOutputFields[nColumn];
  492. Assert( pResultSet->GetColumnType( nColumn ) == pRecordInfo->GetColumnInfo( nSchColumn ).GetType() );
  493. DbgVerify( pRecordBase->BSetField( nSchColumn, pubData, cubData ) );
  494. }
  495. return true;
  496. }
  497. //-----------------------------------------------------------------------------
  498. // Purpose: Reads a list of records from the DB according to the specified where
  499. // clause
  500. // Input: pRecordBase - record to read
  501. // readSet - The set of columns to read
  502. // whereSet - The set of columns to query on
  503. // Output: true if successful, false otherwise
  504. //-----------------------------------------------------------------------------
  505. EResult CSQLAccess::YieldingReadRecordWithWhereColumns( CRecordBase *pRecord, const CColumnSet & readSet, const CColumnSet & whereSet, const char* pchOrderClause )
  506. {
  507. AssertMsg( !BInTransaction(), "BYieldingReadRecordWithWhereColumns is not supported in a transaction" );
  508. if( BInTransaction() )
  509. return k_EResultInvalidState;
  510. //if there is an order by clause, only take the top one, if there isn't, then validate that we have a single instance
  511. const char* pszTopClause = ( pchOrderClause ) ? "TOP (1)" : "TOP (2)";
  512. TSQLCmdStr sStatement;
  513. BuildSelectStatementText( &sStatement, readSet, pszTopClause );
  514. // if we actually have some columns for the where clause,
  515. // append a where clause.
  516. if( whereSet.GetColumnCount() )
  517. {
  518. sStatement.Append( " WHERE " );
  519. AppendWhereClauseText( &sStatement, whereSet );
  520. AddRecordParameters( *pRecord, whereSet );
  521. }
  522. //append the order by if they added one
  523. if( pchOrderClause )
  524. {
  525. sStatement.Append( " ORDER BY " );
  526. sStatement.Append( pchOrderClause );
  527. }
  528. Assert(!readSet.IsEmpty() );
  529. if( !BYieldingExecute( sStatement, sStatement ) )
  530. return k_EResultFail;
  531. if ( GetResultSetCount() != 1 )
  532. {
  533. AssertMsg( GetResultSetCount() == 1, "Unexpected number of result sets returned from select statement" );
  534. return k_EResultFail;
  535. }
  536. // make sure the types are the same
  537. IGCSQLResultSet *pResultSet = m_pQueryGroup->GetResults()->GetResultSet( 0 );
  538. if ( pResultSet->GetRowCount() == 0 )
  539. return k_EResultNoMatch;
  540. //note that since we only take the top one when there is an order by clause, we don't need to handle that case down here, only if top 2 is selected
  541. if( pResultSet->GetRowCount() != 1 )
  542. {
  543. // Make sure we aren't failing because there are multiple matching records.
  544. // That is probably a misuse of the API or some unexpected condition.
  545. AssertMsg1( false, "BYieldingReadRecordWithWhereColumns from %s failing because multiple records match WHERE clause", readSet.GetRecordInfo()->GetName() );
  546. return k_EResultLimitExceeded;
  547. }
  548. FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex )
  549. {
  550. EGCSQLType eRecordType = readSet.GetColumnInfo( nColumnIndex ).GetType();
  551. EGCSQLType eResultType = pResultSet->GetColumnType( nColumnIndex );
  552. AssertMsg2( eResultType == eRecordType, "Column %d type mismatch in %s", nColumnIndex, readSet.GetRecordInfo()->GetName() );
  553. if( eRecordType != eResultType )
  554. return k_EResultInvalidParam;
  555. }
  556. CSQLRecord sqlRecord = GetResultRecord( 0, 0 );
  557. FOR_EACH_COLUMN_IN_SET( readSet, nColumnIndex )
  558. {
  559. uint8 *pubData;
  560. uint32 cubData;
  561. DbgVerify( sqlRecord.BGetColumnData( nColumnIndex, &pubData, (int*)&cubData ) );
  562. DbgVerify( pRecord->BSetField( readSet.GetColumn( nColumnIndex), pubData, cubData ) );
  563. }
  564. return k_EResultOK;
  565. }
  566. //-----------------------------------------------------------------------------
  567. // Purpose: Updates a record in the DB
  568. // Input: record - data source for columns to match against (whereColumns) and
  569. // columns to assign (updateColumns)
  570. // whereColumns - columns to match against
  571. // updateColumns - columns to update
  572. // Output: true if successful, false otherwise
  573. //-----------------------------------------------------------------------------
  574. bool CSQLAccess::BYieldingUpdateRecord( const CRecordBase & record, const CColumnSet & whereColumns, const CColumnSet & updateColumns, const CSQLOutputParams *pOptionalOutputParams /* = NULL */ )
  575. {
  576. return BYieldingUpdateRecords( record, whereColumns, record, updateColumns, pOptionalOutputParams );
  577. }
  578. //-----------------------------------------------------------------------------
  579. // Purpose:
  580. //-----------------------------------------------------------------------------
  581. bool CSQLAccess::BYieldingUpdateRecords( const CRecordBase & whereRecord, const CColumnSet & whereColumns, const CRecordBase & updateRecord, const CColumnSet & updateColumns, const CSQLOutputParams *pOptionalOutputParams /* = NULL */ )
  582. {
  583. ClearParams();
  584. Assert( whereColumns.GetRecordInfo() == updateColumns.GetRecordInfo() );
  585. if ( whereColumns.GetRecordInfo() != updateColumns.GetRecordInfo() )
  586. return false;
  587. Assert( whereColumns.GetRecordInfo() == whereRecord.GetPSchema()->GetRecordInfo() );
  588. if ( whereColumns.GetRecordInfo() != whereRecord.GetPSchema()->GetRecordInfo() )
  589. return false;
  590. Assert( whereColumns.GetRecordInfo() == updateRecord.GetPSchema()->GetRecordInfo() );
  591. if ( whereColumns.GetRecordInfo() != updateRecord.GetPSchema()->GetRecordInfo() )
  592. return false;
  593. AssertMsg( !updateColumns.IsEmpty(), "Someone is calling BYieldingUpdateRecord with no columns to update." );
  594. if ( updateColumns.IsEmpty() )
  595. return false;
  596. // add the columns we're updating as bound params
  597. TSQLCmdStr sStatement;
  598. BuildUpdateStatementText( &sStatement, updateColumns );
  599. AddRecordParameters( updateRecord, updateColumns );
  600. // did the users specify an OUTPUT block?
  601. if ( pOptionalOutputParams )
  602. {
  603. TSQLCmdStr sOutput;
  604. BuildOutputClauseText( &sOutput, pOptionalOutputParams->GetColumnSet() );
  605. sStatement.Append( sOutput );
  606. AddRecordParameters( pOptionalOutputParams->GetRecord(), pOptionalOutputParams->GetColumnSet() );
  607. }
  608. if ( !whereColumns.IsEmpty() )
  609. {
  610. sStatement.Append( " WHERE " );
  611. AppendWhereClauseText( &sStatement, whereColumns );
  612. // add the columns we're querying on as bound params
  613. AddRecordParameters( whereRecord, whereColumns );
  614. }
  615. return BYieldingExecute( sStatement, sStatement );
  616. }
  617. //-----------------------------------------------------------------------------
  618. // Purpose: Deletes this record's row in the table
  619. // Input: record - record to delete
  620. // whereColumns - columns to use when searching for this record
  621. //-----------------------------------------------------------------------------
  622. bool CSQLAccess::BYieldingDeleteRecords( const CRecordBase & record, const CColumnSet & whereColumns )
  623. {
  624. Assert( whereColumns.GetRecordInfo() == record.GetPSchema()->GetRecordInfo() );
  625. if ( whereColumns.GetRecordInfo() != record.GetPSchema()->GetRecordInfo() )
  626. return false;
  627. ClearParams();
  628. AddRecordParameters( record, whereColumns );
  629. TSQLCmdStr sStatement;
  630. BuildDeleteStatementText( &sStatement, record.GetPRecordInfo() );
  631. sStatement.Append( " WHERE " );
  632. AppendWhereClauseText( &sStatement, whereColumns );
  633. uint32 unRowsAffected;
  634. if( !BYieldingExecute( sStatement, sStatement, &unRowsAffected ) )
  635. return false;
  636. return unRowsAffected > 0 || BInTransaction();
  637. }
  638. //--------------------------------------------------------------------------------------------------------------------------------
  639. // CSQLUpdateOrInsert
  640. //--------------------------------------------------------------------------------------------------------------------------------
  641. CSQLUpdateOrInsert::CSQLUpdateOrInsert( const char* pszName, int nTable, const CColumnSet & whereColumns, const CColumnSet & updateColumns, const char* pszWhereClause, const char* pszUpdateClause )
  642. {
  643. const CRecordInfo* pRecordInfo = GSchemaFull().GetSchema( nTable ).GetRecordInfo();
  644. //how many columns do we have
  645. const int nNumColumns = pRecordInfo->GetNumColumns();
  646. TSQLCmdStr sStatement;
  647. sStatement = "MERGE INTO ";
  648. sStatement.Append( GSchemaFull().GetDefaultSchemaNameForCatalog( pRecordInfo->GetESchemaCatalog() ) );
  649. sStatement.Append( '.' );
  650. sStatement.Append( pRecordInfo->GetName() );
  651. sStatement.Append( " WITH(HOLDLOCK) AS D USING(VALUES(" );
  652. sStatement.AppendFormat( "%.*s", GetInsertArgStringChars( nNumColumns ), GetInsertArgString() );
  653. sStatement.Append( "))AS S(" );
  654. //add each column that we are adding the values for, along with the parameter from the structure
  655. for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ )
  656. {
  657. const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn );
  658. if( nCurrColumn != 0 )
  659. sStatement.Append( ',' );
  660. sStatement.Append( colInfo.GetName() );
  661. }
  662. //our where clause
  663. sStatement.Append( ")ON " );
  664. if( pszWhereClause )
  665. {
  666. sStatement.Append( pszWhereClause );
  667. }
  668. else
  669. {
  670. FOR_EACH_COLUMN_IN_SET( whereColumns, nCurrColumn )
  671. {
  672. const char* pszColName = pRecordInfo->GetColumnInfo( whereColumns.GetColumn( nCurrColumn ) ).GetName();
  673. if( nCurrColumn > 0 )
  674. sStatement.Append( " AND " );
  675. sStatement.AppendFormat( "D.%s=S.%s", pszColName, pszColName );
  676. }
  677. }
  678. //our update clause (if they have provided fields that they want to update)
  679. if( pszUpdateClause || !updateColumns.IsEmpty() )
  680. {
  681. sStatement.Append( " WHEN MATCHED THEN UPDATE SET " );
  682. if( pszUpdateClause )
  683. {
  684. sStatement.Append( pszUpdateClause );
  685. }
  686. else
  687. {
  688. FOR_EACH_COLUMN_IN_SET( updateColumns, nCurrColumn )
  689. {
  690. const char* pszColName = pRecordInfo->GetColumnInfo( updateColumns.GetColumn( nCurrColumn ) ).GetName();
  691. if( nCurrColumn > 0 )
  692. sStatement.Append( ',' );
  693. sStatement.AppendFormat( "%s=S.%s", pszColName, pszColName );
  694. }
  695. }
  696. }
  697. //our insert clause
  698. sStatement.Append( " WHEN NOT MATCHED THEN INSERT(" );
  699. bool bFirstColumn = true;
  700. for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ )
  701. {
  702. const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn );
  703. if( !colInfo.BIsInsertable() )
  704. continue;
  705. if( !bFirstColumn )
  706. sStatement.Append( ',' );
  707. bFirstColumn = false;
  708. sStatement.Append( colInfo.GetName() );
  709. }
  710. sStatement.Append( ")VALUES(" );
  711. bFirstColumn = true;
  712. for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ )
  713. {
  714. const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn );
  715. if( !colInfo.BIsInsertable() )
  716. continue;
  717. if( !bFirstColumn )
  718. sStatement.Append( ',' );
  719. bFirstColumn = false;
  720. sStatement.AppendFormat( "S.%s", colInfo.GetName() );
  721. }
  722. sStatement.Append( ");" );
  723. //save our results so we can execute it in the future
  724. m_nTable = nTable;
  725. m_sName = pszName;
  726. m_sQuery = sStatement;
  727. }
  728. bool CSQLUpdateOrInsert::BYieldingExecute( CSQLAccess& sqlAccess, const CRecordBase& record, uint32 *out_punRowsAffected /* = NULL */ ) const
  729. {
  730. AssertMsg2( record.GetITable() == m_nTable, "Error: Merge was compiled for table %s, but was attempted to be executed against %s", GSchemaFull().GetSchema( m_nTable ).GetRecordInfo()->GetName(), record.GetPRecordInfo()->GetName() );
  731. const CRecordInfo* pRecordInfo = record.GetPRecordInfo();
  732. //how many columns do we have
  733. const int nNumColumns = pRecordInfo->GetNumColumns();
  734. sqlAccess.ClearParams();
  735. for( int nCurrColumn = 0; nCurrColumn < nNumColumns; nCurrColumn++ )
  736. {
  737. const CColumnInfo& colInfo = pRecordInfo->GetColumnInfo( nCurrColumn );
  738. uint8 *pubData;
  739. uint32 cubData;
  740. DbgVerify( record.BGetField( nCurrColumn, &pubData, &cubData ) );
  741. sqlAccess.AddBindParamRaw( colInfo.GetType(), pubData, cubData );
  742. }
  743. return sqlAccess.BYieldingExecute( m_sName, m_sQuery, out_punRowsAffected );
  744. }
  745. //-----------------------------------------------------------------------------
  746. // Purpose: Adds bind parameters to the list based on a set of fields in a record
  747. // Input: record - record to insert
  748. // columnSet - The set of columns to add as params
  749. //-----------------------------------------------------------------------------
  750. void CSQLAccess::AddRecordParameters( const CRecordBase &record, const CColumnSet & columnSet )
  751. {
  752. Assert( record.GetPSchema()->GetRecordInfo() == columnSet.GetRecordInfo() );
  753. if ( record.GetPSchema()->GetRecordInfo() != columnSet.GetRecordInfo() )
  754. return;
  755. FOR_EACH_COLUMN_IN_SET( columnSet, nColumnIndex )
  756. {
  757. const CColumnInfo &columnInfo = columnSet.GetColumnInfo( nColumnIndex );
  758. uint8 *pubData;
  759. uint32 cubData;
  760. DbgVerify( record.BGetField( columnSet.GetColumn( nColumnIndex ), &pubData, &cubData ) );
  761. EGCSQLType eType = columnInfo.GetType();
  762. CurrentQuery()->AddBindParamRaw( eType, pubData, cubData );
  763. }
  764. }
  765. //-----------------------------------------------------------------------------
  766. // Purpose: Deletes all records from a table
  767. // Input: iTable - table to wipe
  768. // Output: true if the operation was successful
  769. // Note: PERFORMANCE WARNING: this is slow on big tables, not intended for use
  770. // in production
  771. //-----------------------------------------------------------------------------
  772. bool CSQLAccess::BYieldingWipeTable( int iTable )
  773. {
  774. // make a wipe operation
  775. CRecordInfo *pRecordInfo = GSchemaFull().GetSchema( iTable ).GetRecordInfo();
  776. CUtlString buf;
  777. buf.Format( "DELETE FROM %s", pRecordInfo->GetName() );
  778. return BYieldingExecute( buf.String(), buf.String() );
  779. }
  780. //-----------------------------------------------------------------------------
  781. // Purpose: Returns the current query to add stuff to, creating it if there isn't
  782. // already a current query
  783. //-----------------------------------------------------------------------------
  784. CGCSQLQuery *CSQLAccess::CurrentQuery()
  785. {
  786. if( m_pCurrentQuery )
  787. return m_pCurrentQuery;
  788. m_pCurrentQuery = new CGCSQLQuery();
  789. return m_pCurrentQuery;
  790. }
  791. } // namespace GCSDK