Team Fortress 2 Source Code as of 22/4/2020
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

385 lines
12 KiB

  1. //========= Copyright Valve Corporation, All rights reserved. ============//
  2. //
  3. // Purpose:
  4. //
  5. // A set of generic, template-based matrix functions.
  6. //===========================================================================//
  7. #ifndef MATRIXMATH_H
  8. #define MATRIXMATH_H
  9. #include <stdarg.h>
  10. // The operations in this file can perform basic matrix operations on matrices represented
  11. // using any class that supports the necessary operations:
  12. //
// .Element( row, col ) - return the element at a given matrix position
  14. // .SetElement( row, col, val ) - modify an element
  15. // .Width(), .Height() - get dimensions
// .SetDimensions( nrows, ncols ) - set a matrix to the given size, leaving its contents uninitialized
  17. //
  18. // Generally, vectors can be used with these functions by using N x 1 matrices to represent them.
  19. // Matrices are addressed as row, column, and indices are 0-based
  20. //
  21. //
  22. // Note that the template versions of these routines are defined for generality - it is expected
  23. // that template specialization is used for common high performance cases.
  24. namespace MatrixMath
  25. {
  26. /// M *= flScaleValue
  27. template<class MATRIXCLASS>
  28. void ScaleMatrix( MATRIXCLASS &matrix, float flScaleValue )
  29. {
  30. for( int i = 0; i < matrix.Height(); i++ )
  31. {
  32. for( int j = 0; j < matrix.Width(); j++ )
  33. {
  34. matrix.SetElement( i, j, flScaleValue * matrix.Element( i, j ) );
  35. }
  36. }
  37. }
  38. /// AppendElementToMatrix - same as setting the element, except only works when all calls
  39. /// happen in top to bottom left to right order, end you have to call FinishedAppending when
  40. /// done. For normal matrix classes this is not different then SetElement, but for
  41. /// CSparseMatrix, it is an accelerated way to fill a matrix from scratch.
  42. template<class MATRIXCLASS>
  43. FORCEINLINE void AppendElement( MATRIXCLASS &matrix, int nRow, int nCol, float flValue )
  44. {
  45. matrix.SetElement( nRow, nCol, flValue ); // default implementation
  46. }
  47. template<class MATRIXCLASS>
  48. FORCEINLINE void FinishedAppending( MATRIXCLASS &matrix ) {} // default implementation
  49. /// M += fl
  50. template<class MATRIXCLASS>
  51. void AddToMatrix( MATRIXCLASS &matrix, float flAddend )
  52. {
  53. for( int i = 0; i < matrix.Height(); i++ )
  54. {
  55. for( int j = 0; j < matrix.Width(); j++ )
  56. {
  57. matrix.SetElement( i, j, flAddend + matrix.Element( i, j ) );
  58. }
  59. }
  60. }
  61. /// transpose
  62. template<class MATRIXCLASSIN, class MATRIXCLASSOUT>
  63. void TransposeMatrix( MATRIXCLASSIN const &matrixIn, MATRIXCLASSOUT *pMatrixOut )
  64. {
  65. pMatrixOut->SetDimensions( matrixIn.Width(), matrixIn.Height() );
  66. for( int i = 0; i < pMatrixOut->Height(); i++ )
  67. {
  68. for( int j = 0; j < pMatrixOut->Width(); j++ )
  69. {
  70. AppendElement( *pMatrixOut, i, j, matrixIn.Element( j, i ) );
  71. }
  72. }
  73. FinishedAppending( *pMatrixOut );
  74. }
  75. /// copy
  76. template<class MATRIXCLASSIN, class MATRIXCLASSOUT>
  77. void CopyMatrix( MATRIXCLASSIN const &matrixIn, MATRIXCLASSOUT *pMatrixOut )
  78. {
  79. pMatrixOut->SetDimensions( matrixIn.Height(), matrixIn.Width() );
  80. for( int i = 0; i < matrixIn.Height(); i++ )
  81. {
  82. for( int j = 0; j < matrixIn.Width(); j++ )
  83. {
  84. AppendElement( *pMatrixOut, i, j, matrixIn.Element( i, j ) );
  85. }
  86. }
  87. FinishedAppending( *pMatrixOut );
  88. }
  89. /// M+=M
  90. template<class MATRIXCLASSIN, class MATRIXCLASSOUT>
  91. void AddMatrixToMatrix( MATRIXCLASSIN const &matrixIn, MATRIXCLASSOUT *pMatrixOut )
  92. {
  93. for( int i = 0; i < matrixIn.Height(); i++ )
  94. {
  95. for( int j = 0; j < matrixIn.Width(); j++ )
  96. {
  97. pMatrixOut->SetElement( i, j, pMatrixOut->Element( i, j ) + matrixIn.Element( i, j ) );
  98. }
  99. }
  100. }
  101. // M += scale * M
  102. template<class MATRIXCLASSIN, class MATRIXCLASSOUT>
  103. void AddScaledMatrixToMatrix( float flScale, MATRIXCLASSIN const &matrixIn, MATRIXCLASSOUT *pMatrixOut )
  104. {
  105. for( int i = 0; i < matrixIn.Height(); i++ )
  106. {
  107. for( int j = 0; j < matrixIn.Width(); j++ )
  108. {
  109. pMatrixOut->SetElement( i, j, pMatrixOut->Element( i, j ) + flScale * matrixIn.Element( i, j ) );
  110. }
  111. }
  112. }
  113. // simple way to initialize a matrix with constants from code.
  114. template<class MATRIXCLASSOUT>
  115. void SetMatrixToIdentity( MATRIXCLASSOUT *pMatrixOut, float flDiagonalValue = 1.0 )
  116. {
  117. for( int i = 0; i < pMatrixOut->Height(); i++ )
  118. {
  119. for( int j = 0; j < pMatrixOut->Width(); j++ )
  120. {
  121. AppendElement( *pMatrixOut, i, j, ( i == j ) ? flDiagonalValue : 0 );
  122. }
  123. }
  124. FinishedAppending( *pMatrixOut );
  125. }
  126. //// simple way to initialize a matrix with constants from code
  127. template<class MATRIXCLASSOUT>
  128. void SetMatrixValues( MATRIXCLASSOUT *pMatrix, int nRows, int nCols, ... )
  129. {
  130. va_list argPtr;
  131. va_start( argPtr, nCols );
  132. pMatrix->SetDimensions( nRows, nCols );
  133. for( int nRow = 0; nRow < nRows; nRow++ )
  134. {
  135. for( int nCol = 0; nCol < nCols; nCol++ )
  136. {
  137. double flNewValue = va_arg( argPtr, double );
  138. pMatrix->SetElement( nRow, nCol, flNewValue );
  139. }
  140. }
  141. va_end( argPtr );
  142. }
  143. /// row and colum accessors. treat a row or a column as a column vector
  144. template<class MATRIXTYPE> class MatrixRowAccessor
  145. {
  146. public:
  147. FORCEINLINE MatrixRowAccessor( MATRIXTYPE const &matrix, int nRow )
  148. {
  149. m_pMatrix = &matrix;
  150. m_nRow = nRow;
  151. }
  152. FORCEINLINE float Element( int nRow, int nCol ) const
  153. {
  154. Assert( nCol == 0 );
  155. return m_pMatrix->Element( m_nRow, nRow );
  156. }
  157. FORCEINLINE int Width( void ) const { return 1; };
  158. FORCEINLINE int Height( void ) const { return m_pMatrix->Width(); }
  159. private:
  160. MATRIXTYPE const *m_pMatrix;
  161. int m_nRow;
  162. };
  163. template<class MATRIXTYPE> class MatrixColumnAccessor
  164. {
  165. public:
  166. FORCEINLINE MatrixColumnAccessor( MATRIXTYPE const &matrix, int nColumn )
  167. {
  168. m_pMatrix = &matrix;
  169. m_nColumn = nColumn;
  170. }
  171. FORCEINLINE float Element( int nRow, int nColumn ) const
  172. {
  173. Assert( nColumn == 0 );
  174. return m_pMatrix->Element( nRow, m_nColumn );
  175. }
  176. FORCEINLINE int Width( void ) const { return 1; }
  177. FORCEINLINE int Height( void ) const { return m_pMatrix->Height(); }
  178. private:
  179. MATRIXTYPE const *m_pMatrix;
  180. int m_nColumn;
  181. };
  182. /// this translator acts as a proxy for the transposed matrix
  183. template<class MATRIXTYPE> class MatrixTransposeAccessor
  184. {
  185. public:
  186. FORCEINLINE MatrixTransposeAccessor( MATRIXTYPE const & matrix )
  187. {
  188. m_pMatrix = &matrix;
  189. }
  190. FORCEINLINE float Element( int nRow, int nColumn ) const
  191. {
  192. return m_pMatrix->Element( nColumn, nRow );
  193. }
  194. FORCEINLINE int Width( void ) const { return m_pMatrix->Height(); }
  195. FORCEINLINE int Height( void ) const { return m_pMatrix->Width(); }
  196. private:
  197. MATRIXTYPE const *m_pMatrix;
  198. };
  199. /// this tranpose returns a wrapper around it's argument, allowing things like AddMatrixToMatrix( Transpose( matA ), &matB ) without an extra copy
  200. template<class MATRIXCLASSIN>
  201. MatrixTransposeAccessor<MATRIXCLASSIN> TransposeMatrix( MATRIXCLASSIN const &matrixIn )
  202. {
  203. return MatrixTransposeAccessor<MATRIXCLASSIN>( matrixIn );
  204. }
  205. /// retrieve rows and columns
  206. template<class MATRIXTYPE>
  207. FORCEINLINE MatrixColumnAccessor<MATRIXTYPE> MatrixColumn( MATRIXTYPE const &matrix, int nColumn )
  208. {
  209. return MatrixColumnAccessor<MATRIXTYPE>( matrix, nColumn );
  210. }
  211. template<class MATRIXTYPE>
  212. FORCEINLINE MatrixRowAccessor<MATRIXTYPE> MatrixRow( MATRIXTYPE const &matrix, int nRow )
  213. {
  214. return MatrixRowAccessor<MATRIXTYPE>( matrix, nRow );
  215. }
  216. //// dot product between vectors (or rows and/or columns via accessors)
  217. template<class MATRIXACCESSORATYPE, class MATRIXACCESSORBTYPE >
  218. float InnerProduct( MATRIXACCESSORATYPE const &vecA, MATRIXACCESSORBTYPE const &vecB )
  219. {
  220. Assert( vecA.Width() == 1 );
  221. Assert( vecB.Width() == 1 );
  222. Assert( vecA.Height() == vecB.Height() );
  223. double flResult = 0;
  224. for( int i = 0; i < vecA.Height(); i++ )
  225. {
  226. flResult += vecA.Element( i, 0 ) * vecB.Element( i, 0 );
  227. }
  228. return flResult;
  229. }
  230. /// matrix x matrix multiplication
  231. template<class MATRIXATYPE, class MATRIXBTYPE, class MATRIXOUTTYPE>
  232. void MatrixMultiply( MATRIXATYPE const &matA, MATRIXBTYPE const &matB, MATRIXOUTTYPE *pMatrixOut )
  233. {
  234. Assert( matA.Width() == matB.Height() );
  235. pMatrixOut->SetDimensions( matA.Height(), matB.Width() );
  236. for( int i = 0; i < matA.Height(); i++ )
  237. {
  238. for( int j = 0; j < matB.Width(); j++ )
  239. {
  240. pMatrixOut->SetElement( i, j, InnerProduct( MatrixRow( matA, i ), MatrixColumn( matB, j ) ) );
  241. }
  242. }
  243. }
  244. /// solve Ax=B via the conjugate graident method. Code and naming conventions based on the
  245. /// wikipedia article.
  246. template<class ATYPE, class XTYPE, class BTYPE>
  247. void ConjugateGradient( ATYPE const &matA, BTYPE const &vecB, XTYPE &vecX, float flTolerance = 1.0e-20 )
  248. {
  249. XTYPE vecR;
  250. vecR.SetDimensions( vecX.Height(), 1 );
  251. MatrixMultiply( matA, vecX, &vecR );
  252. ScaleMatrix( vecR, -1 );
  253. AddMatrixToMatrix( vecB, &vecR );
  254. XTYPE vecP;
  255. CopyMatrix( vecR, &vecP );
  256. float flRsOld = InnerProduct( vecR, vecR );
  257. for( int nIter = 0; nIter < 100; nIter++ )
  258. {
  259. XTYPE vecAp;
  260. MatrixMultiply( matA, vecP, &vecAp );
  261. float flDivisor = InnerProduct( vecAp, vecP );
  262. float flAlpha = flRsOld / flDivisor;
  263. AddScaledMatrixToMatrix( flAlpha, vecP, &vecX );
  264. AddScaledMatrixToMatrix( -flAlpha, vecAp, &vecR );
  265. float flRsNew = InnerProduct( vecR, vecR );
  266. if ( flRsNew < flTolerance )
  267. {
  268. break;
  269. }
  270. ScaleMatrix( vecP, flRsNew / flRsOld );
  271. AddMatrixToMatrix( vecR, &vecP );
  272. flRsOld = flRsNew;
  273. }
  274. }
  275. /// solve (A'*A) x=B via the conjugate gradient method. Code and naming conventions based on
  276. /// the wikipedia article. Same as Conjugate gradient but allows passing in two matrices whose
  277. /// product is used as the A matrix (in order to preserve sparsity)
  278. template<class ATYPE, class APRIMETYPE, class XTYPE, class BTYPE>
  279. void ConjugateGradient( ATYPE const &matA, APRIMETYPE const &matAPrime, BTYPE const &vecB, XTYPE &vecX, float flTolerance = 1.0e-20 )
  280. {
  281. XTYPE vecR1;
  282. vecR1.SetDimensions( vecX.Height(), 1 );
  283. MatrixMultiply( matA, vecX, &vecR1 );
  284. XTYPE vecR;
  285. vecR.SetDimensions( vecR1.Height(), 1 );
  286. MatrixMultiply( matAPrime, vecR1, &vecR );
  287. ScaleMatrix( vecR, -1 );
  288. AddMatrixToMatrix( vecB, &vecR );
  289. XTYPE vecP;
  290. CopyMatrix( vecR, &vecP );
  291. float flRsOld = InnerProduct( vecR, vecR );
  292. for( int nIter = 0; nIter < 100; nIter++ )
  293. {
  294. XTYPE vecAp1;
  295. MatrixMultiply( matA, vecP, &vecAp1 );
  296. XTYPE vecAp;
  297. MatrixMultiply( matAPrime, vecAp1, &vecAp );
  298. float flDivisor = InnerProduct( vecAp, vecP );
  299. float flAlpha = flRsOld / flDivisor;
  300. AddScaledMatrixToMatrix( flAlpha, vecP, &vecX );
  301. AddScaledMatrixToMatrix( -flAlpha, vecAp, &vecR );
  302. float flRsNew = InnerProduct( vecR, vecR );
  303. if ( flRsNew < flTolerance )
  304. {
  305. break;
  306. }
  307. ScaleMatrix( vecP, flRsNew / flRsOld );
  308. AddMatrixToMatrix( vecR, &vecP );
  309. flRsOld = flRsNew;
  310. }
  311. }
  312. template<class ATYPE, class XTYPE, class BTYPE>
  313. void LeastSquaresFit( ATYPE const &matA, BTYPE const &vecB, XTYPE &vecX )
  314. {
  315. // now, generate the normal equations
  316. BTYPE vecBeta;
  317. MatrixMath::MatrixMultiply( MatrixMath::TransposeMatrix( matA ), vecB, &vecBeta );
  318. vecX.SetDimensions( matA.Width(), 1 );
  319. MatrixMath::SetMatrixToIdentity( &vecX );
  320. ATYPE matATransposed;
  321. TransposeMatrix( matA, &matATransposed );
  322. ConjugateGradient( matA, matATransposed, vecBeta, vecX, 1.0e-20 );
  323. }
  324. };
  325. /// a simple fixed-size matrix class
  326. template<int NUMROWS, int NUMCOLS> class CFixedMatrix
  327. {
  328. public:
  329. FORCEINLINE int Width( void ) const { return NUMCOLS; }
  330. FORCEINLINE int Height( void ) const { return NUMROWS; }
  331. FORCEINLINE float Element( int nRow, int nCol ) const { return m_flValues[nRow][nCol]; }
  332. FORCEINLINE void SetElement( int nRow, int nCol, float flValue ) { m_flValues[nRow][nCol] = flValue; }
  333. FORCEINLINE void SetDimensions( int nNumRows, int nNumCols ) { Assert( ( nNumRows == NUMROWS ) && ( nNumCols == NUMCOLS ) ); }
  334. private:
  335. float m_flValues[NUMROWS][NUMCOLS];
  336. };
  337. #endif //matrixmath_h