//===== Copyright © 1996-2011, Valve Corporation, All rights reserved. ======// // // Purpose: // // A class allowing storage of a sparse NxN matirx as an array of sparse rows //===========================================================================// #ifndef SPARSEMATRIX_H #define SPARSEMATRIX_H #include "tier1/utlvector.h" /// CSparseMatrix is a matrix which compresses each row individually, not storing the zeros. NOte, /// that while you can randomly set any element in a CSparseMatrix, you really want to do it top to /// bottom or you will have bad perf as data is moved around to insert new elements. class CSparseMatrix { public: struct NonZeroValueDescriptor_t { int m_nColumnNumber; float m_flValue; }; struct RowDescriptor_t { int m_nNonZeroCount; // number of non-zero elements in the row int m_nDataIndex; // index of NonZeroValueDescriptor_t for the first non-zero value }; int m_nNumRows; int m_nNumCols; CUtlVector m_rowDescriptors; CUtlVector m_entries; int m_nHighestRowAppendedTo; void AdjustAllRowIndicesAfter( int nStartRow, int nDelta ); public: FORCEINLINE float Element( int nRow, int nCol ) const; void SetElement( int nRow, int nCol, float flValue ); void SetDimensions( int nNumRows, int nNumCols ); void AppendElement( int nRow, int nCol, float flValue ); void FinishedAppending( void ); FORCEINLINE int Height( void ) const { return m_nNumRows; } FORCEINLINE int Width( void ) const { return m_nNumCols; } }; FORCEINLINE float CSparseMatrix::Element( int nRow, int nCol ) const { Assert( nCol < m_nNumCols ); int nCount = m_rowDescriptors[nRow].m_nNonZeroCount; if ( nCount ) { NonZeroValueDescriptor_t const *pValue = &(m_entries[m_rowDescriptors[nRow].m_nDataIndex]); do { int nIdx = pValue->m_nColumnNumber; if ( nIdx == nCol ) { return pValue->m_flValue; } if ( nIdx > nCol ) { break; } pValue++; } while( --nCount ); } return 0; } // type-specific overrides of matrixmath template for special case sparse routines namespace MatrixMath { /// sparse * dense matrix x matrix multiplication template void MatrixMultiply( CSparseMatrix const &matA, BTYPE const &matB, OUTTYPE *pMatrixOut ) { Assert( matA.Width() == matB.Height() ); pMatrixOut->SetDimensions( matA.Height(), matB.Width() ); for( int i = 0; i < matA.Height(); i++ ) { for( int j = 0; j < matB.Width(); j++ ) { // compute inner product efficiently because of sparsity int nCnt = matA.m_rowDescriptors[i].m_nNonZeroCount; int nDataIdx = matA.m_rowDescriptors[i].m_nDataIndex; float flDot = 0.0; for( int nIdx = 0; nIdx < nCnt; nIdx++ ) { float flAValue = matA.m_entries[nIdx + nDataIdx].m_flValue; int nCol = matA.m_entries[nIdx + nDataIdx].m_nColumnNumber; flDot += flAValue * matB.Element( nCol, j ); } pMatrixOut->SetElement( i, j, flDot ); } } } FORCEINLINE void AppendElement( CSparseMatrix &matrix, int nRow, int nCol, float flValue ) { matrix.AppendElement( nRow, nCol, flValue ); // default implementation } FORCEINLINE void FinishedAppending( CSparseMatrix &matrix ) { matrix.FinishedAppending(); } }; #endif // SPARSEMATRIX_H