tensor.hpp 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. #ifndef _SCTL_TENSOR_HPP_
  2. #define _SCTL_TENSOR_HPP_
#include SCTL_INCLUDE(mem_mgr.hpp)
#include SCTL_INCLUDE(math_utils.hpp)
#include SCTL_INCLUDE(common.hpp)
#include <cmath>
#include <iomanip>
#include <iostream>
  7. namespace SCTL_NAMESPACE {
  8. template <class ValueType, bool own_data, Long... Args> class Tensor {
  9. template <Long k> static constexpr Long SizeHelper() {
  10. return 1;
  11. }
  12. template <Long k, Long d, Long... dd> static constexpr Long SizeHelper() {
  13. return (k >= 0 ? d : 1) * SizeHelper<k+1, dd...>();
  14. }
  15. template <Long k> static constexpr Long DimHelper() {
  16. return 1;
  17. }
  18. template <Long k, Long d0, Long... dd> static constexpr Long DimHelper() {
  19. return k==0 ? d0 : DimHelper<k-1,dd...>();
  20. }
  21. template <typename T> static constexpr Long OrderHelper() {
  22. return 0;
  23. }
  24. template <typename T, Long d, Long... dd> static constexpr Long OrderHelper() {
  25. return 1 + OrderHelper<void, dd...>();
  26. }
  27. template <Long k, bool own_data_, Long... dd> struct RotateType {
  28. using Value = Tensor<ValueType,own_data_,dd...>;
  29. };
  30. template <bool own_data_, Long d, Long... dd> struct RotateType<0,own_data_,d,dd...> {
  31. using Value = Tensor<ValueType,own_data_,d,dd...>;
  32. };
  33. template <Long k, bool own_data_, Long d, Long... dd> struct RotateType<k,own_data_,d,dd...> {
  34. using Value = typename RotateType<k-1,own_data_,dd...,d>::Value;
  35. };
  36. public:
  37. static constexpr Long Order() {
  38. return OrderHelper<void, Args...>();
  39. }
  40. static constexpr Long Size() {
  41. return SizeHelper<0,Args...>();
  42. }
  43. template <Long k> static constexpr Long Dim() {
  44. return DimHelper<k,Args...>();
  45. }
  46. Tensor(Iterator<ValueType> src_iter = NullIterator<ValueType>()) {
  47. Init(src_iter);
  48. }
  49. Tensor(const Tensor &M) {
  50. Init((Iterator<ValueType>)M.begin());
  51. }
  52. template <bool own_data_> Tensor(const Tensor<ValueType,own_data_,Args...> &M) {
  53. Init((Iterator<ValueType>)M.begin());
  54. }
  55. Tensor &operator=(const Tensor &M) {
  56. memcopy(begin(), M.begin(), Size());
  57. return *this;
  58. }
  59. Iterator<ValueType> begin() {
  60. return own_data ? (Iterator<ValueType>)buff : iter_[0];
  61. }
  62. ConstIterator<ValueType> begin() const {
  63. return own_data ? (ConstIterator<ValueType>)buff : (ConstIterator<ValueType>)iter_[0];
  64. }
  65. Iterator<ValueType> end() {
  66. return begin() + Size();
  67. }
  68. ConstIterator<ValueType> end() const {
  69. return begin() + Size();
  70. }
  71. template <class ...PackedLong> ValueType& operator()(PackedLong... ii) {
  72. return begin()[offset<0>(ii...)];
  73. }
  74. template <class ...PackedLong> ValueType operator()(PackedLong... ii) const {
  75. return begin()[offset<0>(ii...)];
  76. }
  77. typename RotateType<1,true,Args...>::Value RotateLeft() const {
  78. typename RotateType<1,true,Args...>::Value Tr;
  79. const auto& T = *this;
  80. constexpr Long N0 = Dim<0>();
  81. constexpr Long N1 = Size() / N0;
  82. for (Long i = 0; i < N0; i++) {
  83. for (Long j = 0; j < N1; j++) {
  84. Tr.begin()[j*N0+i] = T.begin()[i*N1+j];
  85. }
  86. }
  87. return Tr;
  88. }
  89. typename RotateType<Order()-1,true,Args...>::Value RotateRight() const {
  90. typename RotateType<Order()-1,true,Args...>::Value Tr;
  91. const auto& T = *this;
  92. constexpr Long N0 = Dim<Order()-1>();
  93. constexpr Long N1 = Size() / N0;
  94. for (Long i = 0; i < N0; i++) {
  95. for (Long j = 0; j < N1; j++) {
  96. Tr.begin()[i*N1+j] = T.begin()[j*N0+i];
  97. }
  98. }
  99. return Tr;
  100. }
  101. Tensor<ValueType, true, Args...> operator*(const ValueType &s) const {
  102. Tensor<ValueType, true, Args...> M0;
  103. const auto &M1 = *this;
  104. for (Long i = 0; i < Size(); i++) {
  105. M0.begin()[i] = M1.begin()[i]*s;
  106. }
  107. return M0;
  108. }
  109. template <bool own_data_> Tensor<ValueType, true, Args...> operator+(const Tensor<ValueType, own_data_, Args...> &M2) const {
  110. Tensor<ValueType, true, Args...> M0;
  111. const auto &M1 = *this;
  112. for (Long i = 0; i < Size(); i++) {
  113. M0.begin()[i] = M1.begin()[i] + M2.begin()[i];
  114. }
  115. return M0;
  116. }
  117. template <bool own_data_> Tensor<ValueType, true, Args...> operator-(const Tensor<ValueType, own_data_, Args...> &M2) const {
  118. Tensor<ValueType, true, Args...> M0;
  119. const auto &M1 = *this;
  120. for (Long i = 0; i < Size(); i++) {
  121. M0.begin()[i] = M1.begin()[i] - M2.begin()[i];
  122. }
  123. return M0;
  124. }
  125. template <bool own_data_, Long N1, Long N2> Tensor<ValueType, true, Dim<0>(), N2> operator*(const Tensor<ValueType, own_data_, N1, N2> &M2) const {
  126. static_assert(Order() == 2, "Multiplication is only defined for tensors of order two.");
  127. static_assert(Dim<1>() == N1, "Tensor dimensions dont match for multiplication.");
  128. Tensor<ValueType, true, Dim<0>(), N2> M0;
  129. const auto &M1 = *this;
  130. for (Long i = 0; i < Dim<0>(); i++) {
  131. for (Long j = 0; j < N2; j++) {
  132. ValueType Mij = 0;
  133. for (Long k = 0; k < N1; k++) {
  134. Mij += M1(i,k)*M2(k,j);
  135. }
  136. M0(i,j) = Mij;
  137. }
  138. }
  139. return M0;
  140. }
  141. private:
  142. template <Integer k> static Long offset() {
  143. return 0;
  144. }
  145. template <Integer k, class ...PackedLong> static Long offset(Long i, PackedLong... ii) {
  146. return i * SizeHelper<-(k+1),Args...>() + offset<k+1>(ii...);
  147. }
  148. void Init(Iterator<ValueType> src_iter) {
  149. if (own_data) {
  150. if (src_iter != NullIterator<ValueType>()) {
  151. memcopy((Iterator<ValueType>)buff, src_iter, Size());
  152. }
  153. } else {
  154. if (Size()) {
  155. SCTL_UNUSED(src_iter[0]);
  156. SCTL_UNUSED(src_iter[Size()-1]);
  157. iter_[0] = Ptr2Itr<ValueType>(&src_iter[0], Size());
  158. } else {
  159. iter_[0] = NullIterator<ValueType>();
  160. }
  161. }
  162. }
  163. StaticArray<ValueType,own_data?Size():0> buff;
  164. StaticArray<Iterator<ValueType>,own_data?0:1> iter_;
  165. };
  166. template <class ValueType, bool own_data, Long N1, Long N2> std::ostream& operator<<(std::ostream &output, const Tensor<ValueType, own_data, N1, N2> &M) {
  167. std::ios::fmtflags f(std::cout.flags());
  168. output << std::fixed << std::setprecision(4) << std::setiosflags(std::ios::left);
  169. for (Long i = 0; i < N1; i++) {
  170. for (Long j = 0; j < N2; j++) {
  171. float f = ((float)M(i,j));
  172. if (sctl::fabs<float>(f) < 1e-25) f = 0;
  173. output << std::setw(10) << ((double)f) << ' ';
  174. }
  175. output << ";\n";
  176. }
  177. std::cout.flags(f);
  178. return output;
  179. }
  180. } // end namespace
  181. #endif //_SCTL_TENSOR_HPP_