matrix.hpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. #ifndef _SCTL_MATRIX_HPP_
  2. #define _SCTL_MATRIX_HPP_
  3. #include <cstdint>
  4. #include <cstdlib>
  5. #include SCTL_INCLUDE(vector.hpp)
  6. #include SCTL_INCLUDE(common.hpp)
  7. namespace SCTL_NAMESPACE {
  8. template <class ValueType> class Vector;
  9. template <class ValueType> class Permutation;
  10. template <class ValueType> class Matrix {
  11. public:
  12. Matrix();
  13. Matrix(Long dim1, Long dim2, Iterator<ValueType> data_ = NULL, bool own_data_ = true);
  14. Matrix(const Matrix<ValueType>& M);
  15. ~Matrix();
  16. void Swap(Matrix<ValueType>& M);
  17. void ReInit(Long dim1, Long dim2, Iterator<ValueType> data_ = NULL, bool own_data_ = true);
  18. void Write(const char* fname) const;
  19. void Read(const char* fname);
  20. Long Dim(Long i) const;
  21. void SetZero();
  22. Iterator<ValueType> begin();
  23. ConstIterator<ValueType> begin() const;
  24. Iterator<ValueType> end();
  25. ConstIterator<ValueType> end() const;
  26. // Matrix-Matrix operations
  27. Matrix<ValueType>& operator=(const Matrix<ValueType>& M);
  28. Matrix<ValueType>& operator+=(const Matrix<ValueType>& M);
  29. Matrix<ValueType>& operator-=(const Matrix<ValueType>& M);
  30. Matrix<ValueType> operator+(const Matrix<ValueType>& M2) const;
  31. Matrix<ValueType> operator-(const Matrix<ValueType>& M2) const;
  32. Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
  33. static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  34. static void GEMM(Matrix<ValueType>& M_r, const Permutation<ValueType>& P, const Matrix<ValueType>& M, ValueType beta = 0.0);
  35. static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& M, const Permutation<ValueType>& P, ValueType beta = 0.0);
  36. // cublasgemm wrapper
  37. static void CUBLASGEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  38. // Matrix-Scalar operations
  39. Matrix<ValueType>& operator=(ValueType s);
  40. Matrix<ValueType>& operator+=(ValueType s);
  41. Matrix<ValueType>& operator-=(ValueType s);
  42. Matrix<ValueType>& operator*=(ValueType s);
  43. Matrix<ValueType>& operator/=(ValueType s);
  44. Matrix<ValueType> operator+(ValueType s) const;
  45. Matrix<ValueType> operator-(ValueType s) const;
  46. Matrix<ValueType> operator*(ValueType s) const;
  47. Matrix<ValueType> operator/(ValueType s) const;
  48. // Element access
  49. ValueType& operator()(Long i, Long j);
  50. const ValueType& operator()(Long i, Long j) const;
  51. Iterator<ValueType> operator[](Long i);
  52. ConstIterator<ValueType> operator[](Long i) const;
  53. void RowPerm(const Permutation<ValueType>& P);
  54. void ColPerm(const Permutation<ValueType>& P);
  55. Matrix<ValueType> Transpose() const;
  56. static void Transpose(Matrix<ValueType>& M_r, const Matrix<ValueType>& M);
  57. // Original matrix is destroyed.
  58. void SVD(Matrix<ValueType>& tU, Matrix<ValueType>& tS, Matrix<ValueType>& tVT);
  59. // Original matrix is destroyed.
  60. Matrix<ValueType> pinv(ValueType eps = -1);
  61. private:
  62. StaticArray<Long, 2> dim;
  63. Iterator<ValueType> data_ptr;
  64. bool own_data;
  65. };
  66. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Matrix<ValueType>& M);
  67. template <class ValueType> Matrix<ValueType> operator+(ValueType s, const Matrix<ValueType>& M) { return M + s; }
  68. template <class ValueType> Matrix<ValueType> operator-(ValueType s, const Matrix<ValueType>& M) { return s + (M * -1.0); }
  69. template <class ValueType> Matrix<ValueType> operator*(ValueType s, const Matrix<ValueType>& M) { return M * s; }
  70. /**
  71. * /brief P=[e(p1)*s1 e(p2)*s2 ... e(pn)*sn],
  72. * where e(k) is the kth unit vector,
  73. * perm := [p1 p2 ... pn] is the permutation vector,
  74. * scal := [s1 s2 ... sn] is the scaling vector.
  75. */
  76. template <class ValueType> class Permutation {
  77. public:
  78. Permutation() {}
  79. Permutation(Long size);
  80. static Permutation<ValueType> RandPerm(Long size);
  81. Matrix<ValueType> GetMatrix() const;
  82. Long Dim() const;
  83. Permutation<ValueType> Transpose();
  84. Permutation<ValueType>& operator*=(ValueType s);
  85. Permutation<ValueType>& operator/=(ValueType s);
  86. Permutation<ValueType> operator*(ValueType s) const;
  87. Permutation<ValueType> operator/(ValueType s) const;
  88. Permutation<ValueType> operator*(const Permutation<ValueType>& P) const;
  89. Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
  90. Vector<Long> perm;
  91. Vector<ValueType> scal;
  92. };
  93. template <class ValueType> Permutation<ValueType> operator*(ValueType s, const Permutation<ValueType>& P) { return P * s; }
  94. template <class ValueType> Matrix<ValueType> operator*(const Matrix<ValueType>& M, const Permutation<ValueType>& P);
  95. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Permutation<ValueType>& P);
  96. } // end namespace
  97. #include SCTL_INCLUDE(matrix.txx)
  98. #endif //_SCTL_MATRIX_HPP_