#ifndef _SCTL_TENSOR_HPP_ #define _SCTL_TENSOR_HPP_ #include SCTL_INCLUDE(mem_mgr.hpp) #include SCTL_INCLUDE(common.hpp) #include namespace SCTL_NAMESPACE { template class Tensor { template static constexpr Long SizeHelper() { return 1; } template static constexpr Long SizeHelper() { return (k >= 0 ? d : 1) * SizeHelper(); } template static constexpr Long DimHelper() { return 1; } template static constexpr Long DimHelper() { return k==0 ? d0 : DimHelper(); } template static constexpr Long OrderHelper() { return 0; } template static constexpr Long OrderHelper() { return 1 + OrderHelper(); } template struct RotateType { using Value = Tensor; }; template struct RotateType<0,own_data_,d,dd...> { using Value = Tensor; }; template struct RotateType { using Value = typename RotateType::Value; }; public: static constexpr Long Order() { return OrderHelper(); } static constexpr Long Size() { return SizeHelper<0,Args...>(); } template static constexpr Long Dim() { return DimHelper(); } Tensor(Iterator src_iter = NullIterator()) { Init(src_iter); } Tensor(const Tensor &M) { Init((Iterator)M.begin()); } template Tensor(const Tensor &M) { Init((Iterator)M.begin()); } Tensor &operator=(const Tensor &M) { memcopy(begin(), M.begin(), Size()); return *this; } Iterator begin() { return own_data ? (Iterator)buff : iter_[0]; } ConstIterator begin() const { return own_data ? (ConstIterator)buff : (ConstIterator)iter_[0]; } Iterator end() { return begin() + Size(); } ConstIterator end() const { return begin() + Size(); } template ValueType& operator()(PackedLong... ii) { return begin()[offset<0>(ii...)]; } template ValueType operator()(PackedLong... ii) const { return begin()[offset<0>(ii...)]; } typename RotateType<1,true,Args...>::Value RotateLeft() const { typename RotateType<1,true,Args...>::Value Tr; const auto& T = *this; constexpr Long N0 = Dim<0>(); constexpr Long N1 = Size() / N0; for (Long i = 0; i < N0; i++) { for (Long j = 0; j < N1; j++) { Tr.begin()[j*N0+i] = T.begin()[i*N1+j]; } } return Tr; } typename RotateType::Value RotateRight() const { typename RotateType::Value Tr; const auto& T = *this; constexpr Long N0 = Dim(); constexpr Long N1 = Size() / N0; for (Long i = 0; i < N0; i++) { for (Long j = 0; j < N1; j++) { Tr.begin()[i*N1+j] = T.begin()[j*N0+i]; } } return Tr; } Tensor operator*(const ValueType &s) const { Tensor M0; const auto &M1 = *this; for (Long i = 0; i < Size(); i++) { M0.begin()[i] = M1.begin()[i]*s; } return M0; } template Tensor operator+(const Tensor &M2) const { Tensor M0; const auto &M1 = *this; for (Long i = 0; i < Size(); i++) { M0.begin()[i] = M1.begin()[i] + M2.begin()[i]; } return M0; } template Tensor operator-(const Tensor &M2) const { Tensor M0; const auto &M1 = *this; for (Long i = 0; i < Size(); i++) { M0.begin()[i] = M1.begin()[i] - M2.begin()[i]; } return M0; } template Tensor(), N2> operator*(const Tensor &M2) const { static_assert(Order() == 2, "Multiplication is only defined for tensors of order two."); static_assert(Dim<1>() == N1, "Tensor dimensions dont match for multiplication."); Tensor(), N2> M0; const auto &M1 = *this; for (Long i = 0; i < Dim<0>(); i++) { for (Long j = 0; j < N2; j++) { ValueType Mij = 0; for (Long k = 0; k < N1; k++) { Mij += M1(i,k)*M2(k,j); } M0(i,j) = Mij; } } return M0; } private: template static Long offset() { return 0; } template static Long offset(Long i, PackedLong... ii) { return i * SizeHelper<-(k+1),Args...>() + offset(ii...); } void Init(Iterator src_iter) { if (own_data) { if (src_iter != NullIterator()) { memcopy((Iterator)buff, src_iter, Size()); } } else { if (Size()) { SCTL_UNUSED(src_iter[0]); SCTL_UNUSED(src_iter[Size()-1]); iter_[0] = Ptr2Itr(&src_iter[0], Size()); } else { iter_[0] = NullIterator(); } } } StaticArray buff; StaticArray,own_data?0:1> iter_; }; template std::ostream& operator<<(std::ostream &output, const Tensor &M) { std::ios::fmtflags f(std::cout.flags()); output << std::fixed << std::setprecision(4) << std::setiosflags(std::ios::left); for (Long i = 0; i < N1; i++) { for (Long j = 0; j < N2; j++) { float f = ((float)M(i,j)); if (sctl::fabs(f) < 1e-25) f = 0; output << std::setw(10) << ((double)f) << ' '; } output << ";\n"; } std::cout.flags(f); return output; } } // end namespace #endif //_SCTL_TENSOR_HPP_