tensor.hpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. #ifndef _SCTL_TENSOR_HPP_
  2. #define _SCTL_TENSOR_HPP_
  #include SCTL_INCLUDE(mem_mgr.hpp)
  #include SCTL_INCLUDE(common.hpp)

  #include <cmath>
  #include <iomanip>
  #include <iostream>
  6. namespace SCTL_NAMESPACE {
  7. template <class ValueType, bool own_data, Long... Args> class Tensor {
  8. template <Long k> static constexpr Long SizeHelper() {
  9. return 1;
  10. }
  11. template <Long k, Long d, Long... dd> static constexpr Long SizeHelper() {
  12. return (k >= 0 ? d : 1) * SizeHelper<k+1, dd...>();
  13. }
  14. template <Long k> static constexpr Long DimHelper() {
  15. return 1;
  16. }
  17. template <Long k, Long d0, Long... dd> static constexpr Long DimHelper() {
  18. return k==0 ? d0 : DimHelper<k-1,dd...>();
  19. }
  20. template <typename T> static constexpr Long OrderHelper() {
  21. return 0;
  22. }
  23. template <typename T, Long d, Long... dd> static constexpr Long OrderHelper() {
  24. return 1 + OrderHelper<void, dd...>();
  25. }
  26. template <Long k, bool own_data_, Long... dd> struct RotateType {
  27. using Value = Tensor<ValueType,own_data_,dd...>;
  28. };
  29. template <bool own_data_, Long d, Long... dd> struct RotateType<0,own_data_,d,dd...> {
  30. using Value = Tensor<ValueType,own_data_,d,dd...>;
  31. };
  32. template <Long k, bool own_data_, Long d, Long... dd> struct RotateType<k,own_data_,d,dd...> {
  33. using Value = typename RotateType<k-1,own_data_,dd...,d>::Value;
  34. };
  35. public:
  36. static constexpr Long Order() {
  37. return OrderHelper<void, Args...>();
  38. }
  39. static constexpr Long Size() {
  40. return SizeHelper<0,Args...>();
  41. }
  42. template <Long k> static constexpr Long Dim() {
  43. return DimHelper<k,Args...>();
  44. }
  45. Tensor(Iterator<ValueType> src_iter = NullIterator<ValueType>()) {
  46. Init(src_iter);
  47. }
  48. Tensor(const Tensor &M) {
  49. Init((Iterator<ValueType>)M.begin());
  50. }
  51. template <bool own_data_> Tensor(const Tensor<ValueType,own_data_,Args...> &M) {
  52. Init((Iterator<ValueType>)M.begin());
  53. }
  54. Tensor &operator=(const Tensor &M) {
  55. memcopy(begin(), M.begin(), Size());
  56. return *this;
  57. }
  58. Iterator<ValueType> begin() {
  59. return own_data ? (Iterator<ValueType>)buff : iter_[0];
  60. }
  61. ConstIterator<ValueType> begin() const {
  62. return own_data ? (ConstIterator<ValueType>)buff : (ConstIterator<ValueType>)iter_[0];
  63. }
  64. Iterator<ValueType> end() {
  65. return begin() + Size();
  66. }
  67. ConstIterator<ValueType> end() const {
  68. return begin() + Size();
  69. }
  70. template <class ...PackedLong> ValueType& operator()(PackedLong... ii) {
  71. return begin()[offset<0>(ii...)];
  72. }
  73. template <class ...PackedLong> ValueType operator()(PackedLong... ii) const {
  74. return begin()[offset<0>(ii...)];
  75. }
  76. typename RotateType<1,true,Args...>::Value RotateLeft() const {
  77. typename RotateType<1,true,Args...>::Value Tr;
  78. const auto& T = *this;
  79. constexpr Long N0 = Dim<0>();
  80. constexpr Long N1 = Size() / N0;
  81. for (Long i = 0; i < N0; i++) {
  82. for (Long j = 0; j < N1; j++) {
  83. Tr.begin()[j*N0+i] = T.begin()[i*N1+j];
  84. }
  85. }
  86. return Tr;
  87. }
  88. typename RotateType<Order()-1,true,Args...>::Value RotateRight() const {
  89. typename RotateType<Order()-1,true,Args...>::Value Tr;
  90. const auto& T = *this;
  91. constexpr Long N0 = Dim<Order()-1>();
  92. constexpr Long N1 = Size() / N0;
  93. for (Long i = 0; i < N0; i++) {
  94. for (Long j = 0; j < N1; j++) {
  95. Tr.begin()[i*N1+j] = T.begin()[j*N0+i];
  96. }
  97. }
  98. return Tr;
  99. }
  100. Tensor<ValueType, true, Args...> operator*(const ValueType &s) const {
  101. Tensor<ValueType, true, Args...> M0;
  102. const auto &M1 = *this;
  103. for (Long i = 0; i < Size(); i++) {
  104. M0.begin()[i] = M1.begin()[i]*s;
  105. }
  106. return M0;
  107. }
  108. template <bool own_data_> Tensor<ValueType, true, Args...> operator+(const Tensor<ValueType, own_data_, Args...> &M2) const {
  109. Tensor<ValueType, true, Args...> M0;
  110. const auto &M1 = *this;
  111. for (Long i = 0; i < Size(); i++) {
  112. M0.begin()[i] = M1.begin()[i] + M2.begin()[i];
  113. }
  114. return M0;
  115. }
  116. template <bool own_data_> Tensor<ValueType, true, Args...> operator-(const Tensor<ValueType, own_data_, Args...> &M2) const {
  117. Tensor<ValueType, true, Args...> M0;
  118. const auto &M1 = *this;
  119. for (Long i = 0; i < Size(); i++) {
  120. M0.begin()[i] = M1.begin()[i] - M2.begin()[i];
  121. }
  122. return M0;
  123. }
  124. template <bool own_data_, Long N1, Long N2> Tensor<ValueType, true, Dim<0>(), N2> operator*(const Tensor<ValueType, own_data_, N1, N2> &M2) const {
  125. static_assert(Order() == 2, "Multiplication is only defined for tensors of order two.");
  126. static_assert(Dim<1>() == N1, "Tensor dimensions dont match for multiplication.");
  127. Tensor<ValueType, true, Dim<0>(), N2> M0;
  128. const auto &M1 = *this;
  129. for (Long i = 0; i < Dim<0>(); i++) {
  130. for (Long j = 0; j < N2; j++) {
  131. ValueType Mij = 0;
  132. for (Long k = 0; k < N1; k++) {
  133. Mij += M1(i,k)*M2(k,j);
  134. }
  135. M0(i,j) = Mij;
  136. }
  137. }
  138. return M0;
  139. }
  140. private:
  141. template <Integer k> static Long offset() {
  142. return 0;
  143. }
  144. template <Integer k, class ...PackedLong> static Long offset(Long i, PackedLong... ii) {
  145. return i * SizeHelper<-(k+1),Args...>() + offset<k+1>(ii...);
  146. }
  147. void Init(Iterator<ValueType> src_iter) {
  148. if (own_data) {
  149. if (src_iter != NullIterator<ValueType>()) {
  150. memcopy((Iterator<ValueType>)buff, src_iter, Size());
  151. }
  152. } else {
  153. if (Size()) {
  154. SCTL_UNUSED(src_iter[0]);
  155. SCTL_UNUSED(src_iter[Size()-1]);
  156. iter_[0] = Ptr2Itr<ValueType>(&src_iter[0], Size());
  157. } else {
  158. iter_[0] = NullIterator<ValueType>();
  159. }
  160. }
  161. }
  162. StaticArray<ValueType,own_data?Size():0> buff;
  163. StaticArray<Iterator<ValueType>,own_data?0:1> iter_;
  164. };
  165. template <class ValueType, bool own_data, Long N1, Long N2> std::ostream& operator<<(std::ostream &output, const Tensor<ValueType, own_data, N1, N2> &M) {
  166. std::ios::fmtflags f(std::cout.flags());
  167. output << std::fixed << std::setprecision(4) << std::setiosflags(std::ios::left);
  168. for (Long i = 0; i < N1; i++) {
  169. for (Long j = 0; j < N2; j++) {
  170. float f = ((float)M(i,j));
  171. if (sctl::fabs<float>(f) < 1e-25) f = 0;
  172. output << std::setw(10) << ((double)f) << ' ';
  173. }
  174. output << ";\n";
  175. }
  176. std::cout.flags(f);
  177. return output;
  178. }
  179. } // end namespace
  180. #endif //_SCTL_TENSOR_HPP_