mat.hpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. #ifndef _SCTL_MAT_HPP_
  2. #define _SCTL_MAT_HPP_
  3. namespace biest {
  4. template <class Real, sctl::Integer N1, sctl::Integer N2, bool own_data = true> class Mat {
  5. public:
  6. Mat() {
  7. static_assert(own_data,"A data pointer must be provided when own_data=false.");
  8. iter_ = buff;
  9. }
  10. Mat(sctl::Iterator<Real> src_iter) { Init(src_iter); }
  11. Mat(sctl::ConstIterator<Real> src_iter) { ConstInit(src_iter); }
  12. Mat(const Mat &M) { ConstInit(M.begin()); }
  13. template <bool own_data_> Mat(Mat<Real, N1, N2, own_data_> &M) { Init(M.begin()); }
  14. template <bool own_data_> Mat(const Mat<Real, N1, N2, own_data_> &M) { ConstInit(M.begin()); }
  15. Mat &operator=(const Mat &M) {
  16. auto src_iter = M.begin();
  17. for (sctl::Integer i = 0; i < N1 * N2; i++) this->begin()[i] = src_iter[i];
  18. return *this;
  19. }
  20. template <bool own_data_> Mat &operator=(const Mat<Real, N1, N2, own_data_> &M) {
  21. auto src_iter = M.begin();
  22. for (sctl::Integer i = 0; i < N1 * N2; i++) this->begin()[i] = src_iter[i];
  23. return *this;
  24. }
  25. sctl::Integer Dim0() const { return N1; }
  26. sctl::Integer Dim1() const { return N2; }
  27. sctl::Iterator<Real> begin() { return iter_; }
  28. sctl::ConstIterator<Real> begin() const { return iter_; }
  29. Mat<Real, N1, N2> operator*(const Real &s) const {
  30. Mat<Real, N1, N2> M0;
  31. const auto &M1 = *this;
  32. for (sctl::Integer i1 = 0; i1 < N1; i1++) {
  33. for (sctl::Integer i2 = 0; i2 < N2; i2++) {
  34. M0[i1][i2] = M1[i1][i2] * s;
  35. }
  36. }
  37. return M0;
  38. }
  39. template <sctl::Integer N3, bool own_data_> Mat<Real, N1, N3> operator*(const Mat<Real, N2, N3, own_data_> &M2) const {
  40. Mat<Real, N1, N3> M0;
  41. const auto &M1 = *this;
  42. for (sctl::Integer i1 = 0; i1 < N1; i1++) {
  43. for (sctl::Integer i3 = 0; i3 < N3; i3++) {
  44. Real v = 0;
  45. for (sctl::Integer i2 = 0; i2 < N2; i2++) {
  46. v += M1[i1][i2] * M2[i2][i3];
  47. }
  48. M0[i1][i3] = v;
  49. }
  50. }
  51. return M0;
  52. }
  53. Mat<Real, N1, N2> operator+(const Mat<Real, N1, N2> &M2) const {
  54. Mat<Real, N1, N2> M0;
  55. const auto &M1 = *this;
  56. for (sctl::Integer i1 = 0; i1 < N1; i1++) {
  57. for (sctl::Integer i2 = 0; i2 < N2; i2++) {
  58. M0[i1][i2] = M1[i1][i2] + M2[i1][i2];
  59. }
  60. }
  61. return M0;
  62. }
  63. Mat<Real, N1, N2> operator-(const Mat<Real, N1, N2> &M2) const {
  64. Mat<Real, N1, N2> M0;
  65. const auto &M1 = *this;
  66. for (sctl::Integer i1 = 0; i1 < N1; i1++) {
  67. for (sctl::Integer i2 = 0; i2 < N2; i2++) {
  68. M0[i1][i2] = M1[i1][i2] - M2[i1][i2];
  69. }
  70. }
  71. return M0;
  72. }
  73. sctl::Iterator<Real> operator[](sctl::Integer i) {
  74. #ifdef SCTL_MEMDEBUG
  75. SCTL_ASSERT(i < N1);
  76. #endif
  77. return iter_ + i * N2;
  78. }
  79. sctl::ConstIterator<Real> operator[](sctl::Integer i) const {
  80. #ifdef SCTL_MEMDEBUG
  81. SCTL_ASSERT(i < N1);
  82. #endif
  83. return iter_ + i * N2;
  84. }
  85. Mat<Real, N2, N1> Transpose() const {
  86. Mat<Real, N2, N1> M0;
  87. const auto &M1 = *this;
  88. for (sctl::Integer i1 = 0; i1 < N1; i1++) {
  89. for (sctl::Integer i2 = 0; i2 < N2; i2++) {
  90. M0[i2][i1] = M1[i1][i2];
  91. }
  92. }
  93. return M0;
  94. }
  95. Real Trace() const {
  96. Real sum = 0;
  97. const auto &M1 = *this;
  98. static_assert(N1 == N2,"Cannot compute trace of non-square matrix.");
  99. for (sctl::Integer i = 0; i < N1; i++) sum += M1[i][i];
  100. return sum;
  101. }
  102. bool OwnData() const { return own_data; }
  103. private:
  104. void ConstInit(sctl::ConstIterator<Real> src_iter) {
  105. iter_ = buff;
  106. static_assert(own_data,"Data must be modifiable when own_data=false.");
  107. for (sctl::Integer i = 0; i < N1 * N2; i++) this->begin()[i] = src_iter[i];
  108. }
  109. void Init(sctl::Iterator<Real> src_iter) {
  110. if (own_data) {
  111. iter_ = buff;
  112. for (sctl::Integer i = 0; i < N1 * N2; i++) this->begin()[i] = src_iter[i];
  113. } else {
  114. iter_ = src_iter;
  115. #ifdef SCTL_MEMDEBUG
  116. if (N1 && N2) {
  117. SCTL_UNUSED(src_iter[0]);
  118. SCTL_UNUSED(src_iter[N1 * N2 - 1]);
  119. }
  120. #endif
  121. }
  122. }
  123. sctl::Iterator<Real> iter_;
  124. sctl::StaticArray<Real, own_data ? N1 * N2 : 0> buff;
  125. };
  126. template <class Real, sctl::Integer N1, sctl::Integer N2, bool own_data> Mat<Real, N1, N2> operator*(Real s, const Mat<Real, N1, N2, own_data> &M1) {
  127. Mat<Real, N1, N2> M0;
  128. for (sctl::Integer i1 = 0; i1 < N1; i1++) {
  129. for (sctl::Integer i2 = 0; i2 < N2; i2++) {
  130. M0[i1][i2] = M1[i1][i2] * s;
  131. }
  132. }
  133. return M0;
  134. }
  135. template <class Real, sctl::Integer N1, sctl::Integer N2, bool own_data> std::ostream &operator<<(std::ostream &output, const Mat<Real, N1, N2, own_data> &M) {
  136. std::ios::fmtflags f(std::cout.flags());
  137. output << std::fixed << std::setprecision(4) << std::setiosflags(std::ios::left);
  138. for (sctl::Long i = 0; i < N1; i++) {
  139. for (sctl::Long j = 0; j < N2; j++) {
  140. float f = ((float)M[i][j]);
  141. if (sctl::fabs<Real>(f) < 1e-25) f = 0;
  142. output << std::setw(10) << ((double)f) << ' ';
  143. }
  144. output << ";\n";
  145. }
  146. std::cout.flags(f);
  147. return output;
  148. }
  149. }
  150. #endif //_SCTL_MAT_HPP_