tensor.hpp 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. #ifndef _SCTL_TENSOR_HPP_
  2. #define _SCTL_TENSOR_HPP_
  3. #include <sctl/common.hpp>
  4. #include SCTL_INCLUDE(mem_mgr.hpp)
  5. #include SCTL_INCLUDE(math_utils.hpp)
  6. #include <iostream>
  7. #include <iomanip>
  8. namespace SCTL_NAMESPACE {
  9. template <class ValueType, bool own_data, Long... Args> class Tensor {
  10. template <Long k> static constexpr Long SizeHelper() {
  11. return 1;
  12. }
  13. template <Long k, Long d, Long... dd> static constexpr Long SizeHelper() {
  14. return (k >= 0 ? d : 1) * SizeHelper<k+1, dd...>();
  15. }
  16. template <Long k> static constexpr Long DimHelper() {
  17. return 1;
  18. }
  19. template <Long k, Long d0, Long... dd> static constexpr Long DimHelper() {
  20. return k==0 ? d0 : DimHelper<k-1,dd...>();
  21. }
  22. template <typename T> static constexpr Long OrderHelper() {
  23. return 0;
  24. }
  25. template <typename T, Long d, Long... dd> static constexpr Long OrderHelper() {
  26. return 1 + OrderHelper<void, dd...>();
  27. }
  28. template <Long k, bool own_data_, Long... dd> struct RotateType {
  29. using Value = Tensor<ValueType,own_data_,dd...>;
  30. };
  31. template <bool own_data_, Long d, Long... dd> struct RotateType<0,own_data_,d,dd...> {
  32. using Value = Tensor<ValueType,own_data_,d,dd...>;
  33. };
  34. template <Long k, bool own_data_, Long d, Long... dd> struct RotateType<k,own_data_,d,dd...> {
  35. using Value = typename RotateType<k-1,own_data_,dd...,d>::Value;
  36. };
  37. public:
  38. static constexpr Long Order() {
  39. return OrderHelper<void, Args...>();
  40. }
  41. static constexpr Long Size() {
  42. return SizeHelper<0,Args...>();
  43. }
  44. template <Long k> static constexpr Long Dim() {
  45. return DimHelper<k,Args...>();
  46. }
  47. Tensor(Iterator<ValueType> src_iter = NullIterator<ValueType>()) {
  48. Init(src_iter);
  49. }
  50. Tensor(const Tensor &M) {
  51. Init((Iterator<ValueType>)M.begin());
  52. }
  53. explicit Tensor(const ValueType& v) {
  54. static_assert(own_data || Size() == 0, "Memory pointer must be provided to initialize Tensor types with own_data=false");
  55. Init(NullIterator<ValueType>());
  56. for (auto& x : *this) x = v;
  57. }
  58. template <bool own_data_> Tensor(const Tensor<ValueType,own_data_,Args...> &M) {
  59. static_assert(Size() <= M.Size(), "Initializer must be at least the size of the object.");
  60. Init((Iterator<ValueType>)M.begin());
  61. }
  62. Tensor &operator=(const Tensor &M) {
  63. memcopy(begin(), M.begin(), Size());
  64. return *this;
  65. }
  66. Tensor &operator=(const ValueType& v) {
  67. for (auto& x : *this) x = v;
  68. return *this;
  69. }
  70. Iterator<ValueType> begin() {
  71. return own_data ? (Iterator<ValueType>)buff : iter_[0];
  72. }
  73. ConstIterator<ValueType> begin() const {
  74. return own_data ? (ConstIterator<ValueType>)buff : (ConstIterator<ValueType>)iter_[0];
  75. }
  76. Iterator<ValueType> end() {
  77. return begin() + Size();
  78. }
  79. ConstIterator<ValueType> end() const {
  80. return begin() + Size();
  81. }
  82. template <class ...PackedLong> ValueType& operator()(PackedLong... ii) {
  83. return begin()[offset<0>(ii...)];
  84. }
  85. template <class ...PackedLong> ValueType operator()(PackedLong... ii) const {
  86. return begin()[offset<0>(ii...)];
  87. }
  88. typename RotateType<1,true,Args...>::Value RotateLeft() const {
  89. typename RotateType<1,true,Args...>::Value Tr;
  90. const auto& T = *this;
  91. constexpr Long N0 = Dim<0>();
  92. constexpr Long N1 = Size() / N0;
  93. for (Long i = 0; i < N0; i++) {
  94. for (Long j = 0; j < N1; j++) {
  95. Tr.begin()[j*N0+i] = T.begin()[i*N1+j];
  96. }
  97. }
  98. return Tr;
  99. }
  100. typename RotateType<Order()-1,true,Args...>::Value RotateRight() const {
  101. typename RotateType<Order()-1,true,Args...>::Value Tr;
  102. const auto& T = *this;
  103. constexpr Long N0 = Dim<Order()-1>();
  104. constexpr Long N1 = Size() / N0;
  105. for (Long i = 0; i < N0; i++) {
  106. for (Long j = 0; j < N1; j++) {
  107. Tr.begin()[i*N1+j] = T.begin()[j*N0+i];
  108. }
  109. }
  110. return Tr;
  111. }
  112. Tensor<ValueType, true, Args...> operator-() const {
  113. Tensor<ValueType, true, Args...> M0;
  114. const auto &M1 = *this;
  115. for (Long i = 0; i < Size(); i++) {
  116. M0.begin()[i] = -M1.begin()[i];
  117. }
  118. return M0;
  119. }
  120. Tensor<ValueType, true, Args...> operator*(const ValueType &s) const {
  121. Tensor<ValueType, true, Args...> M0;
  122. const auto &M1 = *this;
  123. for (Long i = 0; i < Size(); i++) {
  124. M0.begin()[i] = M1.begin()[i]*s;
  125. }
  126. return M0;
  127. }
  128. Tensor<ValueType, true, Args...> operator/(const ValueType &s) const {
  129. Tensor<ValueType, true, Args...> M0;
  130. const auto &M1 = *this;
  131. for (Long i = 0; i < Size(); i++) {
  132. M0.begin()[i] = M1.begin()[i]/s;
  133. }
  134. return M0;
  135. }
  136. template <bool own_data_> Tensor<ValueType, true, Args...> operator+(const Tensor<ValueType, own_data_, Args...> &M2) const {
  137. Tensor<ValueType, true, Args...> M0;
  138. const auto &M1 = *this;
  139. for (Long i = 0; i < Size(); i++) {
  140. M0.begin()[i] = M1.begin()[i] + M2.begin()[i];
  141. }
  142. return M0;
  143. }
  144. template <bool own_data_> Tensor<ValueType, true, Args...> operator-(const Tensor<ValueType, own_data_, Args...> &M2) const {
  145. Tensor<ValueType, true, Args...> M0;
  146. const auto &M1 = *this;
  147. for (Long i = 0; i < Size(); i++) {
  148. M0.begin()[i] = M1.begin()[i] - M2.begin()[i];
  149. }
  150. return M0;
  151. }
  152. template <bool own_data_, Long N1, Long N2> Tensor<ValueType, true, Dim<0>(), N2> operator*(const Tensor<ValueType, own_data_, N1, N2> &M2) const {
  153. static_assert(Order() == 2, "Multiplication is only defined for tensors of order two.");
  154. static_assert(Dim<1>() == N1, "Tensor dimensions dont match for multiplication.");
  155. Tensor<ValueType, true, Dim<0>(), N2> M0;
  156. const auto &M1 = *this;
  157. for (Long i = 0; i < Dim<0>(); i++) {
  158. for (Long j = 0; j < N2; j++) {
  159. ValueType Mij = 0;
  160. for (Long k = 0; k < N1; k++) {
  161. Mij += M1(i,k)*M2(k,j);
  162. }
  163. M0(i,j) = Mij;
  164. }
  165. }
  166. return M0;
  167. }
  168. private:
  169. template <Integer k> static Long offset() {
  170. return 0;
  171. }
  172. template <Integer k, class ...PackedLong> static Long offset(Long i, PackedLong... ii) {
  173. return i * SizeHelper<-(k+1),Args...>() + offset<k+1>(ii...);
  174. }
  175. void Init(Iterator<ValueType> src_iter) {
  176. if (own_data) {
  177. if (src_iter != NullIterator<ValueType>()) {
  178. memcopy((Iterator<ValueType>)buff, src_iter, Size());
  179. }
  180. } else {
  181. if (Size()) {
  182. SCTL_UNUSED(src_iter[0]);
  183. SCTL_UNUSED(src_iter[Size()-1]);
  184. iter_[0] = Ptr2Itr<ValueType>(&src_iter[0], Size());
  185. } else {
  186. iter_[0] = NullIterator<ValueType>();
  187. }
  188. }
  189. }
  190. StaticArray<ValueType,own_data?Size():0> buff;
  191. StaticArray<Iterator<ValueType>,own_data?0:1> iter_;
  192. };
  193. template <class ValueType, bool own_data, Long N1, Long N2> std::ostream& operator<<(std::ostream &output, const Tensor<ValueType, own_data, N1, N2> &M) {
  194. std::ios::fmtflags f(std::cout.flags());
  195. output << std::fixed << std::setprecision(4) << std::setiosflags(std::ios::left);
  196. for (Long i = 0; i < N1; i++) {
  197. for (Long j = 0; j < N2; j++) {
  198. float f = ((float)M(i,j));
  199. if (sctl::fabs<float>(f) < 1e-25) f = 0;
  200. output << std::setw(10) << ((double)f) << ' ';
  201. }
  202. output << ";\n";
  203. }
  204. std::cout.flags(f);
  205. return output;
  206. }
  207. } // end namespace
  208. #endif //_SCTL_TENSOR_HPP_