matrix.hpp 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. #ifndef _SCTL_MATRIX_HPP_
  2. #define _SCTL_MATRIX_HPP_
  3. #include <cstdint>
  4. #include <cstdlib>
  5. #include <sctl/common.hpp>
  6. #include SCTL_INCLUDE(vector.hpp)
  7. #include SCTL_INCLUDE(mem_mgr.hpp)
  8. namespace SCTL_NAMESPACE {
  9. template <class ValueType> class Vector;
  10. template <class ValueType> class Permutation;
  11. template <class ValueType> class Matrix {
  12. public:
  13. typedef ValueType value_type;
  14. typedef ValueType& reference;
  15. typedef const ValueType& const_reference;
  16. typedef Iterator<ValueType> iterator;
  17. typedef ConstIterator<ValueType> const_iterator;
  18. typedef Long difference_type;
  19. typedef Long size_type;
  20. Matrix();
  21. Matrix(Long dim1, Long dim2, Iterator<ValueType> data_ = NullIterator<ValueType>(), bool own_data_ = true);
  22. Matrix(const Matrix<ValueType>& M);
  23. ~Matrix();
  24. void Swap(Matrix<ValueType>& M);
  25. void ReInit(Long dim1, Long dim2, Iterator<ValueType> data_ = NullIterator<ValueType>(), bool own_data_ = true);
  26. void Write(const char* fname) const;
  27. void Read(const char* fname);
  28. Long Dim(Long i) const;
  29. void SetZero();
  30. Iterator<ValueType> begin();
  31. ConstIterator<ValueType> begin() const;
  32. Iterator<ValueType> end();
  33. ConstIterator<ValueType> end() const;
  34. // Matrix-Matrix operations
  35. Matrix<ValueType>& operator=(const Matrix<ValueType>& M);
  36. Matrix<ValueType>& operator+=(const Matrix<ValueType>& M);
  37. Matrix<ValueType>& operator-=(const Matrix<ValueType>& M);
  38. Matrix<ValueType> operator+(const Matrix<ValueType>& M2) const;
  39. Matrix<ValueType> operator-(const Matrix<ValueType>& M2) const;
  40. Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
  41. static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  42. static void GEMM(Matrix<ValueType>& M_r, const Permutation<ValueType>& P, const Matrix<ValueType>& M, ValueType beta = 0.0);
  43. static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& M, const Permutation<ValueType>& P, ValueType beta = 0.0);
  44. // cublasgemm wrapper
  45. static void CUBLASGEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  46. // Matrix-Scalar operations
  47. Matrix<ValueType>& operator=(ValueType s);
  48. Matrix<ValueType>& operator+=(ValueType s);
  49. Matrix<ValueType>& operator-=(ValueType s);
  50. Matrix<ValueType>& operator*=(ValueType s);
  51. Matrix<ValueType>& operator/=(ValueType s);
  52. Matrix<ValueType> operator+(ValueType s) const;
  53. Matrix<ValueType> operator-(ValueType s) const;
  54. Matrix<ValueType> operator*(ValueType s) const;
  55. Matrix<ValueType> operator/(ValueType s) const;
  56. // Element access
  57. ValueType& operator()(Long i, Long j);
  58. const ValueType& operator()(Long i, Long j) const;
  59. Iterator<ValueType> operator[](Long i);
  60. ConstIterator<ValueType> operator[](Long i) const;
  61. void RowPerm(const Permutation<ValueType>& P);
  62. void ColPerm(const Permutation<ValueType>& P);
  63. Matrix<ValueType> Transpose() const;
  64. static void Transpose(Matrix<ValueType>& M_r, const Matrix<ValueType>& M);
  65. // Original matrix is destroyed.
  66. void SVD(Matrix<ValueType>& tU, Matrix<ValueType>& tS, Matrix<ValueType>& tVT);
  67. // Original matrix is destroyed.
  68. Matrix<ValueType> pinv(ValueType eps = -1);
  69. private:
  70. void Init(Long dim1, Long dim2, Iterator<ValueType> data_ = NullIterator<ValueType>(), bool own_data_ = true);
  71. StaticArray<Long, 2> dim;
  72. Iterator<ValueType> data_ptr;
  73. bool own_data;
  74. };
  75. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Matrix<ValueType>& M);
  76. template <class ValueType> Matrix<ValueType> operator+(ValueType s, const Matrix<ValueType>& M) { return M + s; }
  77. template <class ValueType> Matrix<ValueType> operator-(ValueType s, const Matrix<ValueType>& M) { return s + (M * -1.0); }
  78. template <class ValueType> Matrix<ValueType> operator*(ValueType s, const Matrix<ValueType>& M) { return M * s; }
  79. /**
  80. * /brief P=[e(p1)*s1 e(p2)*s2 ... e(pn)*sn],
  81. * where e(k) is the kth unit vector,
  82. * perm := [p1 p2 ... pn] is the permutation vector,
  83. * scal := [s1 s2 ... sn] is the scaling vector.
  84. */
  85. template <class ValueType> class Permutation {
  86. public:
  87. Permutation() {}
  88. explicit Permutation(Long size);
  89. static Permutation<ValueType> RandPerm(Long size);
  90. Matrix<ValueType> GetMatrix() const;
  91. Long Dim() const;
  92. Permutation<ValueType> Transpose();
  93. Permutation<ValueType>& operator*=(ValueType s);
  94. Permutation<ValueType>& operator/=(ValueType s);
  95. Permutation<ValueType> operator*(ValueType s) const;
  96. Permutation<ValueType> operator/(ValueType s) const;
  97. Permutation<ValueType> operator*(const Permutation<ValueType>& P) const;
  98. Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
  99. Vector<Long> perm;
  100. Vector<ValueType> scal;
  101. };
  102. template <class ValueType> Permutation<ValueType> operator*(ValueType s, const Permutation<ValueType>& P) { return P * s; }
  103. template <class ValueType> Matrix<ValueType> operator*(const Matrix<ValueType>& M, const Permutation<ValueType>& P);
  104. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Permutation<ValueType>& P);
  105. } // end namespace
  106. #include SCTL_INCLUDE(matrix.txx)
  107. #endif //_SCTL_MATRIX_HPP_