#ifndef _SCTL_MATRIX_HPP_ #define _SCTL_MATRIX_HPP_ #include #include #include SCTL_INCLUDE(vector.hpp) #include SCTL_INCLUDE(common.hpp) namespace SCTL_NAMESPACE { template class Vector; template class Permutation; template class Matrix { public: typedef ValueType value_type; typedef ValueType& reference; typedef const ValueType& const_reference; typedef Iterator iterator; typedef ConstIterator const_iterator; typedef Long difference_type; typedef Long size_type; Matrix(); Matrix(Long dim1, Long dim2, Iterator data_ = nullptr, bool own_data_ = true); Matrix(const Matrix& M); ~Matrix(); void Swap(Matrix& M); void ReInit(Long dim1, Long dim2, Iterator data_ = nullptr, bool own_data_ = true); void Write(const char* fname) const; void Read(const char* fname); Long Dim(Long i) const; void SetZero(); Iterator begin(); ConstIterator begin() const; Iterator end(); ConstIterator end() const; // Matrix-Matrix operations Matrix& operator=(const Matrix& M); Matrix& operator+=(const Matrix& M); Matrix& operator-=(const Matrix& M); Matrix operator+(const Matrix& M2) const; Matrix operator-(const Matrix& M2) const; Matrix operator*(const Matrix& M) const; static void GEMM(Matrix& M_r, const Matrix& A, const Matrix& B, ValueType beta = 0.0); static void GEMM(Matrix& M_r, const Permutation& P, const Matrix& M, ValueType beta = 0.0); static void GEMM(Matrix& M_r, const Matrix& M, const Permutation& P, ValueType beta = 0.0); // cublasgemm wrapper static void CUBLASGEMM(Matrix& M_r, const Matrix& A, const Matrix& B, ValueType beta = 0.0); // Matrix-Scalar operations Matrix& operator=(ValueType s); Matrix& operator+=(ValueType s); Matrix& operator-=(ValueType s); Matrix& operator*=(ValueType s); Matrix& operator/=(ValueType s); Matrix operator+(ValueType s) const; Matrix operator-(ValueType s) const; Matrix operator*(ValueType s) const; Matrix operator/(ValueType s) const; // Element access ValueType& operator()(Long i, Long j); const ValueType& operator()(Long i, Long j) const; Iterator operator[](Long i); ConstIterator operator[](Long i) const; void RowPerm(const Permutation& P); void ColPerm(const Permutation& P); Matrix Transpose() const; static void Transpose(Matrix& M_r, const Matrix& M); // Original matrix is destroyed. void SVD(Matrix& tU, Matrix& tS, Matrix& tVT); // Original matrix is destroyed. Matrix pinv(ValueType eps = -1); private: StaticArray dim; Iterator data_ptr; bool own_data; }; template std::ostream& operator<<(std::ostream& output, const Matrix& M); template Matrix operator+(ValueType s, const Matrix& M) { return M + s; } template Matrix operator-(ValueType s, const Matrix& M) { return s + (M * -1.0); } template Matrix operator*(ValueType s, const Matrix& M) { return M * s; } /** * /brief P=[e(p1)*s1 e(p2)*s2 ... e(pn)*sn], * where e(k) is the kth unit vector, * perm := [p1 p2 ... pn] is the permutation vector, * scal := [s1 s2 ... sn] is the scaling vector. */ template class Permutation { public: Permutation() {} Permutation(Long size); static Permutation RandPerm(Long size); Matrix GetMatrix() const; Long Dim() const; Permutation Transpose(); Permutation& operator*=(ValueType s); Permutation& operator/=(ValueType s); Permutation operator*(ValueType s) const; Permutation operator/(ValueType s) const; Permutation operator*(const Permutation& P) const; Matrix operator*(const Matrix& M) const; Vector perm; Vector scal; }; template Permutation operator*(ValueType s, const Permutation& P) { return P * s; } template Matrix operator*(const Matrix& M, const Permutation& P); template std::ostream& operator<<(std::ostream& output, const Permutation& P); } // end namespace #include SCTL_INCLUDE(matrix.txx) #endif //_SCTL_MATRIX_HPP_