matrix.hpp 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. #ifndef _PVFMM_MATRIX_HPP_
  2. #define _PVFMM_MATRIX_HPP_
  3. #include <stdint.h>
  4. #include <cstdlib>
  5. #include <pvfmm/common.hpp>
  6. namespace pvfmm {
  // Forward declarations. Matrix's interface (RowPerm/ColPerm) names
  // Permutation before its definition further down in this header, and
  // Permutation stores Vector members; Vector is not defined in this header.
  template <typename ValueType> class Vector;
  template <typename ValueType> class Permutation;
  9. template <class ValueType> class Matrix {
  10. public:
  11. Matrix();
  12. Matrix(Long dim1, Long dim2, Iterator<ValueType> data_ = NULL, bool own_data_ = true);
  13. Matrix(const Matrix<ValueType>& M);
  14. ~Matrix();
  15. void Swap(Matrix<ValueType>& M);
  16. void ReInit(Long dim1, Long dim2, Iterator<ValueType> data_ = NULL, bool own_data_ = true);
  17. void Write(const char* fname) const;
  18. void Read(const char* fname);
  19. Long Dim(Long i) const;
  20. void SetZero();
  21. Iterator<ValueType> Begin();
  22. ConstIterator<ValueType> Begin() const;
  23. Matrix<ValueType>& operator=(const Matrix<ValueType>& M);
  24. Matrix<ValueType>& operator+=(const Matrix<ValueType>& M);
  25. Matrix<ValueType>& operator-=(const Matrix<ValueType>& M);
  26. Matrix<ValueType> operator+(const Matrix<ValueType>& M2) const;
  27. Matrix<ValueType> operator-(const Matrix<ValueType>& M2) const;
  28. ValueType& operator()(Long i, Long j);
  29. const ValueType& operator()(Long i, Long j) const;
  30. Iterator<ValueType> operator[](Long i);
  31. ConstIterator<ValueType> operator[](Long i) const;
  32. Matrix<ValueType> operator*(const Matrix<ValueType>& M) const;
  33. static void GEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  34. // cublasgemm wrapper
  35. static void CUBLASGEMM(Matrix<ValueType>& M_r, const Matrix<ValueType>& A, const Matrix<ValueType>& B, ValueType beta = 0.0);
  36. void RowPerm(const Permutation<ValueType>& P);
  37. void ColPerm(const Permutation<ValueType>& P);
  38. Matrix<ValueType> Transpose() const;
  39. static void Transpose(Matrix<ValueType>& M_r, const Matrix<ValueType>& M);
  40. // Original matrix is destroyed.
  41. void SVD(Matrix<ValueType>& tU, Matrix<ValueType>& tS, Matrix<ValueType>& tVT);
  42. // Original matrix is destroyed.
  43. Matrix<ValueType> pinv(ValueType eps = -1);
  44. private:
  45. StaticArray<Long, 2> dim;
  46. Iterator<ValueType> data_ptr;
  47. bool own_data;
  48. };
  49. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Matrix<ValueType>& M);
  /**
   * \brief Scaled permutation matrix P = [e(p1)*s1 e(p2)*s2 ... e(pn)*sn],
   * i.e. column i of P is s_i times the p_i-th unit vector, where
   * e(k) is the k-th unit vector,
   * perm := [p1 p2 ... pn] is the permutation vector, and
   * scal := [s1 s2 ... sn] is the scaling vector.
   */
  56. #define PERM_INT_T Long
  57. template <class ValueType> class Permutation {
  58. public:
  59. Permutation() {}
  60. Permutation(Long size);
  61. static Permutation<ValueType> RandPerm(Long size);
  62. Matrix<ValueType> GetMatrix() const;
  63. Long Dim() const;
  64. Permutation<ValueType> Transpose();
  65. Permutation<ValueType> operator*(const Permutation<ValueType>& P);
  66. Matrix<ValueType> operator*(const Matrix<ValueType>& M);
  67. Vector<PERM_INT_T> perm;
  68. Vector<ValueType> scal;
  69. };
  70. template <class ValueType> Matrix<ValueType> operator*(const Matrix<ValueType>& M, const Permutation<ValueType>& P);
  71. template <class ValueType> std::ostream& operator<<(std::ostream& output, const Permutation<ValueType>& P);
  72. } // end namespace
  73. #include <pvfmm/matrix.txx>
  74. #endif //_PVFMM_MATRIX_HPP_