#ifndef _PVFMM_MATRIX_HPP_ #define _PVFMM_MATRIX_HPP_ #include #include #include namespace pvfmm { template class Vector; template class Permutation; template class Matrix { public: Matrix(); Matrix(Long dim1, Long dim2, Iterator data_ = NULL, bool own_data_ = true); Matrix(const Matrix& M); ~Matrix(); void Swap(Matrix& M); void ReInit(Long dim1, Long dim2, Iterator data_ = NULL, bool own_data_ = true); void Write(const char* fname) const; void Read(const char* fname); Long Dim(Long i) const; void SetZero(); Iterator Begin(); ConstIterator Begin() const; Matrix& operator=(const Matrix& M); Matrix& operator+=(const Matrix& M); Matrix& operator-=(const Matrix& M); Matrix operator+(const Matrix& M2) const; Matrix operator-(const Matrix& M2) const; ValueType& operator()(Long i, Long j); const ValueType& operator()(Long i, Long j) const; Iterator operator[](Long i); ConstIterator operator[](Long i) const; Matrix operator*(const Matrix& M) const; static void GEMM(Matrix& M_r, const Matrix& A, const Matrix& B, ValueType beta = 0.0); // cublasgemm wrapper static void CUBLASGEMM(Matrix& M_r, const Matrix& A, const Matrix& B, ValueType beta = 0.0); void RowPerm(const Permutation& P); void ColPerm(const Permutation& P); Matrix Transpose() const; static void Transpose(Matrix& M_r, const Matrix& M); // Original matrix is destroyed. void SVD(Matrix& tU, Matrix& tS, Matrix& tVT); // Original matrix is destroyed. Matrix pinv(ValueType eps = -1); private: StaticArray dim; Iterator data_ptr; bool own_data; }; template std::ostream& operator<<(std::ostream& output, const Matrix& M); /** * /brief P=[e(p1)*s1 e(p2)*s2 ... e(pn)*sn], * where e(k) is the kth unit vector, * perm := [p1 p2 ... pn] is the permutation vector, * scal := [s1 s2 ... sn] is the scaling vector. */ #define PERM_INT_T Long template class Permutation { public: Permutation() {} Permutation(Long size); static Permutation RandPerm(Long size); Matrix GetMatrix() const; Long Dim() const; Permutation Transpose(); Permutation operator*(const Permutation& P); Matrix operator*(const Matrix& M); Vector perm; Vector scal; }; template Matrix operator*(const Matrix& M, const Permutation& P); template std::ostream& operator<<(std::ostream& output, const Permutation& P); } // end namespace #include #endif //_PVFMM_MATRIX_HPP_