matrix.hpp 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #ifndef _SCTL_MATRIX_HPP_
  2. #define _SCTL_MATRIX_HPP_
  3. #include <cstdint>
  4. #include <cstdlib>
  5. #include SCTL_INCLUDE(vector.hpp)
  6. #include SCTL_INCLUDE(common.hpp)
  7. namespace SCTL_NAMESPACE {
  8. template <class ValueType> class Vector;
  9. template <class ValueType> class Permutation;
  10. template <class ValueType> class Matrix {
  11. public:
  12. typedef ValueType value_type;
  13. typedef ValueType& reference;
  14. typedef const ValueType& const_reference;
  15. typedef Iterator<ValueType> iterator;
  16. typedef ConstIterator<ValueType> const_iterator;
  17. typedef Long difference_type;
  18. typedef Long size_type;
  19. Matrix();
  20. Matrix(Long dim1, Long dim2, Iterator<ValueType> data_ = NullIterator<ValueType>(), bool own_data_ = true);
  21. Matrix(const Matrix<ValueType>& M);
  22. ~Matrix();
  23. void Swap(Matrix<ValueType>& M);
  24. void ReInit(Long dim1, Long dim2, Iterator<ValueType> data_ = NullIterator<ValueType>(), bool own_data_ = true);
  25. void Write(const char* fname) const;
  26. void Read(const char* fname);
  27. Long Dim(Long i) const;
  28. void SetZero();
  29. Iterator<ValueType> begin();
  30. ConstIterator<ValueType> begin() const;
  31. Iterator<ValueType> end();
  32. ConstIterator<ValueType> end() const;
  33. // Matrix-Matrix operations
  34. Matrix<ValueType>& operator=(const Matrix<ValueType>& M);
  35. Matrix<ValueType>& operator+=(const Matrix<ValueType>& M);
  36. Matrix<ValueType>& operator-=(const Matrix<ValueType>& M);
  37. Matrix<ValueType> operator+(const Matrix<ValueType>& M2) const;
  38. Matrix<ValueType> operator-(const Matrix<ValueType>& M2) const;
  39. Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
  40. static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  41. static void GEMM(Matrix<ValueType>& M_r, const Permutation<ValueType>& P, const Matrix<ValueType>& M, ValueType beta = 0.0);
  42. static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& M, const Permutation<ValueType>& P, ValueType beta = 0.0);
  43. // cublasgemm wrapper
  44. static void CUBLASGEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  45. // Matrix-Scalar operations
  46. Matrix<ValueType>& operator=(ValueType s);
  47. Matrix<ValueType>& operator+=(ValueType s);
  48. Matrix<ValueType>& operator-=(ValueType s);
  49. Matrix<ValueType>& operator*=(ValueType s);
  50. Matrix<ValueType>& operator/=(ValueType s);
  51. Matrix<ValueType> operator+(ValueType s) const;
  52. Matrix<ValueType> operator-(ValueType s) const;
  53. Matrix<ValueType> operator*(ValueType s) const;
  54. Matrix<ValueType> operator/(ValueType s) const;
  55. // Element access
  56. ValueType& operator()(Long i, Long j);
  57. const ValueType& operator()(Long i, Long j) const;
  58. Iterator<ValueType> operator[](Long i);
  59. ConstIterator<ValueType> operator[](Long i) const;
  60. void RowPerm(const Permutation<ValueType>& P);
  61. void ColPerm(const Permutation<ValueType>& P);
  62. Matrix<ValueType> Transpose() const;
  63. static void Transpose(Matrix<ValueType>& M_r, const Matrix<ValueType>& M);
  64. // Original matrix is destroyed.
  65. void SVD(Matrix<ValueType>& tU, Matrix<ValueType>& tS, Matrix<ValueType>& tVT);
  66. // Original matrix is destroyed.
  67. Matrix<ValueType> pinv(ValueType eps = -1);
  68. private:
  69. void Init(Long dim1, Long dim2, Iterator<ValueType> data_ = NullIterator<ValueType>(), bool own_data_ = true);
  70. StaticArray<Long, 2> dim;
  71. Iterator<ValueType> data_ptr;
  72. bool own_data;
  73. };
  74. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Matrix<ValueType>& M);
  75. template <class ValueType> Matrix<ValueType> operator+(ValueType s, const Matrix<ValueType>& M) { return M + s; }
  76. template <class ValueType> Matrix<ValueType> operator-(ValueType s, const Matrix<ValueType>& M) { return s + (M * -1.0); }
  77. template <class ValueType> Matrix<ValueType> operator*(ValueType s, const Matrix<ValueType>& M) { return M * s; }
  78. /**
  79. * /brief P=[e(p1)*s1 e(p2)*s2 ... e(pn)*sn],
  80. * where e(k) is the kth unit vector,
  81. * perm := [p1 p2 ... pn] is the permutation vector,
  82. * scal := [s1 s2 ... sn] is the scaling vector.
  83. */
  84. template <class ValueType> class Permutation {
  85. public:
  86. Permutation() {}
  87. Permutation(Long size);
  88. static Permutation<ValueType> RandPerm(Long size);
  89. Matrix<ValueType> GetMatrix() const;
  90. Long Dim() const;
  91. Permutation<ValueType> Transpose();
  92. Permutation<ValueType>& operator*=(ValueType s);
  93. Permutation<ValueType>& operator/=(ValueType s);
  94. Permutation<ValueType> operator*(ValueType s) const;
  95. Permutation<ValueType> operator/(ValueType s) const;
  96. Permutation<ValueType> operator*(const Permutation<ValueType>& P) const;
  97. Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
  98. Vector<Long> perm;
  99. Vector<ValueType> scal;
  100. };
  101. template <class ValueType> Permutation<ValueType> operator*(ValueType s, const Permutation<ValueType>& P) { return P * s; }
  102. template <class ValueType> Matrix<ValueType> operator*(const Matrix<ValueType>& M, const Permutation<ValueType>& P);
  103. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Permutation<ValueType>& P);
  104. } // end namespace
  105. #include SCTL_INCLUDE(matrix.txx)
  106. #endif //_SCTL_MATRIX_HPP_