// mat_utils.txx
  1. #include SCTL_INCLUDE(matrix.hpp)
  2. #include SCTL_INCLUDE(math_utils.hpp)
  3. #if defined(SCTL_HAVE_CUDA)
  4. #include <cuda_runtime_api.h>
  5. #include <cublas_v2.h>
  6. #endif
  7. #if defined(SCTL_HAVE_BLAS)
  8. #include SCTL_INCLUDE(blas.h)
  9. #endif
  10. #if defined(SCTL_HAVE_LAPACK)
  11. #include SCTL_INCLUDE(lapack.h)
  12. #endif
  13. #include <omp.h>
  14. #include <cmath>
  15. #include <cassert>
  16. #include <cstdlib>
  17. #include <algorithm>
  18. #include <iostream>
  19. #include <vector>
  20. namespace SCTL_NAMESPACE {
  21. namespace mat {
  22. template <class ValueType> inline void gemm(char TransA, char TransB, int M, int N, int K, ValueType alpha, Iterator<ValueType> A, int lda, Iterator<ValueType> B, int ldb, ValueType beta, Iterator<ValueType> C, int ldc) {
  23. if ((TransA == 'N' || TransA == 'n') && (TransB == 'N' || TransB == 'n')) {
  24. #pragma omp parallel for schedule(static)
  25. for (Long n = 0; n < N; n++) { // Columns of C
  26. for (Long m = 0; m < M; m++) { // Rows of C
  27. ValueType AxB = 0;
  28. for (Long k = 0; k < K; k++) {
  29. AxB += A[m + lda * k] * B[k + ldb * n];
  30. }
  31. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  32. }
  33. }
  34. } else if (TransA == 'N' || TransA == 'n') {
  35. #pragma omp parallel for schedule(static)
  36. for (Long n = 0; n < N; n++) { // Columns of C
  37. for (Long m = 0; m < M; m++) { // Rows of C
  38. ValueType AxB = 0;
  39. for (Long k = 0; k < K; k++) {
  40. AxB += A[m + lda * k] * B[n + ldb * k];
  41. }
  42. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  43. }
  44. }
  45. } else if (TransB == 'N' || TransB == 'n') {
  46. #pragma omp parallel for schedule(static)
  47. for (Long n = 0; n < N; n++) { // Columns of C
  48. for (Long m = 0; m < M; m++) { // Rows of C
  49. ValueType AxB = 0;
  50. for (Long k = 0; k < K; k++) {
  51. AxB += A[k + lda * m] * B[k + ldb * n];
  52. }
  53. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  54. }
  55. }
  56. } else {
  57. #pragma omp parallel for schedule(static)
  58. for (Long n = 0; n < N; n++) { // Columns of C
  59. for (Long m = 0; m < M; m++) { // Rows of C
  60. ValueType AxB = 0;
  61. for (Long k = 0; k < K; k++) {
  62. AxB += A[k + lda * m] * B[n + ldb * k];
  63. }
  64. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  65. }
  66. }
  67. }
  68. }
  69. #if defined(SCTL_HAVE_BLAS)
  70. template <> inline void gemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, Iterator<float> A, int lda, Iterator<float> B, int ldb, float beta, Iterator<float> C, int ldc) { sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, &A[0], &lda, &B[0], &ldb, &beta, &C[0], &ldc); }
  71. template <> inline void gemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, Iterator<double> A, int lda, Iterator<double> B, int ldb, double beta, Iterator<double> C, int ldc) { dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, &A[0], &lda, &B[0], &ldb, &beta, &C[0], &ldc); }
  72. #endif
  73. #if defined(SCTL_HAVE_CUDA)
  74. template <> inline void cublasgemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, Iterator<float> A, int lda, Iterator<float> B, int ldb, float beta, Iterator<float> C, int ldc) {
  75. cublasOperation_t cublasTransA, cublasTransB;
  76. cublasHandle_t *handle = CUDA_Lock::acquire_handle();
  77. if (TransA == 'T' || TransA == 't')
  78. cublasTransA = CUBLAS_OP_T;
  79. else if (TransA == 'N' || TransA == 'n')
  80. cublasTransA = CUBLAS_OP_N;
  81. if (TransB == 'T' || TransB == 't')
  82. cublasTransB = CUBLAS_OP_T;
  83. else if (TransB == 'N' || TransB == 'n')
  84. cublasTransB = CUBLAS_OP_N;
  85. cublasStatus_t status = cublasSgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
  86. }
  87. template <> inline void cublasgemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, Iterator<double> A, int lda, Iterator<double> B, int ldb, double beta, Iterator<double> C, int ldc) {
  88. cublasOperation_t cublasTransA, cublasTransB;
  89. cublasHandle_t *handle = CUDA_Lock::acquire_handle();
  90. if (TransA == 'T' || TransA == 't')
  91. cublasTransA = CUBLAS_OP_T;
  92. else if (TransA == 'N' || TransA == 'n')
  93. cublasTransA = CUBLAS_OP_N;
  94. if (TransB == 'T' || TransB == 't')
  95. cublasTransB = CUBLAS_OP_T;
  96. else if (TransB == 'N' || TransB == 'n')
  97. cublasTransB = CUBLAS_OP_N;
  98. cublasStatus_t status = cublasDgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
  99. }
  100. #endif
  101. //#define SCTL_SVD_DEBUG
  102. template <class ValueType> static inline void GivensL(Iterator<ValueType> S_, const StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
  103. auto S = [S_,dim](Long i, Long j) -> ValueType& { return S_[(i) * dim[1] + (j)]; };
  104. ValueType r = sqrt<ValueType>(a * a + b * b);
  105. if (r == 0) return;
  106. ValueType c = a / r;
  107. ValueType s = -b / r;
  108. #pragma omp parallel for
  109. for (Long i = 0; i < dim[1]; i++) {
  110. ValueType S0 = S(m + 0, i);
  111. ValueType S1 = S(m + 1, i);
  112. S(m, i) += S0 * (c - 1);
  113. S(m, i) += S1 * (-s);
  114. S(m + 1, i) += S0 * (s);
  115. S(m + 1, i) += S1 * (c - 1);
  116. }
  117. }
  118. template <class ValueType> static inline void GivensR(Iterator<ValueType> S_, const StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
  119. auto S = [S_,dim](Long i, Long j) -> ValueType& { return S_[(i) * dim[1] + (j)]; };
  120. ValueType r = sqrt<ValueType>(a * a + b * b);
  121. if (r == 0) return;
  122. ValueType c = a / r;
  123. ValueType s = -b / r;
  124. #pragma omp parallel for
  125. for (Long i = 0; i < dim[0]; i++) {
  126. ValueType S0 = S(i, m + 0);
  127. ValueType S1 = S(i, m + 1);
  128. S(i, m) += S0 * (c - 1);
  129. S(i, m) += S1 * (-s);
  130. S(i, m + 1) += S0 * (s);
  131. S(i, m + 1) += S1 * (c - 1);
  132. }
  133. }
  134. template <class ValueType> static inline void SVD(const StaticArray<Long, 2> &dim, Iterator<ValueType> U_, Iterator<ValueType> S_, Iterator<ValueType> V_, ValueType eps = -1) {
  135. auto U = [U_,dim](Long i, Long j) -> ValueType& { return U_[(i) * dim[0] + (j)]; };
  136. auto S = [S_,dim](Long i, Long j) -> ValueType& { return S_[(i) * dim[1] + (j)]; };
  137. auto V = [V_,dim](Long i, Long j) -> ValueType& { return V_[(i) * dim[1] + (j)]; };
  138. assert(dim[0] >= dim[1]);
  139. #ifdef SCTL_SVD_DEBUG
  140. Matrix<ValueType> M0(dim[0], dim[1], S_);
  141. #endif
  142. { // Bi-diagonalization
  143. Long n = std::min(dim[0], dim[1]);
  144. std::vector<ValueType> house_vec(std::max(dim[0], dim[1]));
  145. for (Long i = 0; i < n; i++) {
  146. // Column Householder
  147. {
  148. ValueType x1 = S(i, i);
  149. if (x1 < 0) x1 = -x1;
  150. ValueType x_inv_norm = 0;
  151. for (Long j = i; j < dim[0]; j++) {
  152. x_inv_norm += S(j, i) * S(j, i);
  153. }
  154. if (x_inv_norm > 0) x_inv_norm = 1 / sqrt<ValueType>(x_inv_norm);
  155. ValueType alpha = sqrt<ValueType>(1 + x1 * x_inv_norm);
  156. ValueType beta = x_inv_norm / alpha;
  157. if (x_inv_norm == 0) alpha = 0; // nothing to do
  158. house_vec[i] = -alpha;
  159. for (Long j = i + 1; j < dim[0]; j++) {
  160. house_vec[j] = -beta * S(j, i);
  161. }
  162. if (S(i, i) < 0)
  163. for (Long j = i + 1; j < dim[0]; j++) {
  164. house_vec[j] = -house_vec[j];
  165. }
  166. }
  167. #pragma omp parallel for
  168. for (Long k = i; k < dim[1]; k++) {
  169. ValueType dot_prod = 0;
  170. for (Long j = i; j < dim[0]; j++) {
  171. dot_prod += S(j, k) * house_vec[j];
  172. }
  173. for (Long j = i; j < dim[0]; j++) {
  174. S(j, k) -= dot_prod * house_vec[j];
  175. }
  176. }
  177. #pragma omp parallel for
  178. for (Long k = 0; k < dim[0]; k++) {
  179. ValueType dot_prod = 0;
  180. for (Long j = i; j < dim[0]; j++) {
  181. dot_prod += U(k, j) * house_vec[j];
  182. }
  183. for (Long j = i; j < dim[0]; j++) {
  184. U(k, j) -= dot_prod * house_vec[j];
  185. }
  186. }
  187. // Row Householder
  188. if (i >= n - 1) continue;
  189. {
  190. ValueType x1 = S(i, i + 1);
  191. if (x1 < 0) x1 = -x1;
  192. ValueType x_inv_norm = 0;
  193. for (Long j = i + 1; j < dim[1]; j++) {
  194. x_inv_norm += S(i, j) * S(i, j);
  195. }
  196. if (x_inv_norm > 0) x_inv_norm = 1 / sqrt<ValueType>(x_inv_norm);
  197. ValueType alpha = sqrt<ValueType>(1 + x1 * x_inv_norm);
  198. ValueType beta = x_inv_norm / alpha;
  199. if (x_inv_norm == 0) alpha = 0; // nothing to do
  200. house_vec[i + 1] = -alpha;
  201. for (Long j = i + 2; j < dim[1]; j++) {
  202. house_vec[j] = -beta * S(i, j);
  203. }
  204. if (S(i, i + 1) < 0)
  205. for (Long j = i + 2; j < dim[1]; j++) {
  206. house_vec[j] = -house_vec[j];
  207. }
  208. }
  209. #pragma omp parallel for
  210. for (Long k = i; k < dim[0]; k++) {
  211. ValueType dot_prod = 0;
  212. for (Long j = i + 1; j < dim[1]; j++) {
  213. dot_prod += S(k, j) * house_vec[j];
  214. }
  215. for (Long j = i + 1; j < dim[1]; j++) {
  216. S(k, j) -= dot_prod * house_vec[j];
  217. }
  218. }
  219. #pragma omp parallel for
  220. for (Long k = 0; k < dim[1]; k++) {
  221. ValueType dot_prod = 0;
  222. for (Long j = i + 1; j < dim[1]; j++) {
  223. dot_prod += V(j, k) * house_vec[j];
  224. }
  225. for (Long j = i + 1; j < dim[1]; j++) {
  226. V(j, k) -= dot_prod * house_vec[j];
  227. }
  228. }
  229. }
  230. }
  231. Long k0 = 0;
  232. Long iter = 0;
  233. if (eps < 0) eps = 64.0 * machine_eps<ValueType>();
  234. while (k0 < dim[1] - 1) { // Diagonalization
  235. iter++;
  236. ValueType S_max = 0.0;
  237. for (Long i = 0; i < dim[1]; i++) S_max = (S_max > fabs<ValueType>(S(i, i)) ? S_max : fabs<ValueType>(S(i, i)));
  238. for (Long i = 0; i < dim[1] - 1; i++) S_max = (S_max > fabs<ValueType>(S(i, i + 1)) ? S_max : fabs<ValueType>(S(i, i + 1)));
  239. // while(k0<dim[1]-1 && fabs<ValueType>(S(k0,k0+1))<=eps*(fabs<ValueType>(S(k0,k0))+fabs<ValueType>(S(k0+1,k0+1)))) k0++;
  240. while (k0 < dim[1] - 1 && fabs<ValueType>(S(k0, k0 + 1)) <= eps * S_max) k0++;
  241. if (k0 == dim[1] - 1) continue;
  242. Long n = k0 + 2;
  243. // while(n<dim[1] && fabs<ValueType>(S(n-1,n))>eps*(fabs<ValueType>(S(n-1,n-1))+fabs<ValueType>(S(n,n)))) n++;
  244. while (n < dim[1] && fabs<ValueType>(S(n - 1, n)) > eps * S_max) n++;
  245. ValueType alpha = 0;
  246. ValueType beta = 0;
  247. if (n - k0 == 2 && fabs<ValueType>(S(k0, k0)) < eps * S_max && fabs<ValueType>(S(k0 + 1, k0 + 1)) < eps * S_max) { // Compute mu
  248. alpha=0;
  249. beta=1;
  250. } else {
  251. StaticArray<ValueType, 2 * 2> C;
  252. C[0 * 2 + 0] = S(n - 2, n - 2) * S(n - 2, n - 2);
  253. if (n - k0 > 2) C[0 * 2 + 0] += S(n - 3, n - 2) * S(n - 3, n - 2);
  254. C[0 * 2 + 1] = S(n - 2, n - 2) * S(n - 2, n - 1);
  255. C[1 * 2 + 0] = S(n - 2, n - 2) * S(n - 2, n - 1);
  256. C[1 * 2 + 1] = S(n - 1, n - 1) * S(n - 1, n - 1) + S(n - 2, n - 1) * S(n - 2, n - 1);
  257. ValueType b = -(C[0 * 2 + 0] + C[1 * 2 + 1]) / 2;
  258. ValueType c = C[0 * 2 + 0] * C[1 * 2 + 1] - C[0 * 2 + 1] * C[1 * 2 + 0];
  259. ValueType d = 0;
  260. if (fabs(b * b - c) > eps*b*b)
  261. d = sqrt<ValueType>(b * b - c);
  262. else {
  263. ValueType b = (C[0 * 2 + 0] - C[1 * 2 + 1]) / 2;
  264. ValueType c = -C[0 * 2 + 1] * C[1 * 2 + 0];
  265. if (b * b - c > 0) d = sqrt<ValueType>(b * b - c);
  266. }
  267. ValueType lambda1 = -b + d;
  268. ValueType lambda2 = -b - d;
  269. ValueType d1 = lambda1 - C[1 * 2 + 1];
  270. d1 = (d1 < 0 ? -d1 : d1);
  271. ValueType d2 = lambda2 - C[1 * 2 + 1];
  272. d2 = (d2 < 0 ? -d2 : d2);
  273. ValueType mu = (d1 < d2 ? lambda1 : lambda2);
  274. alpha = S(k0, k0) * S(k0, k0) - mu;
  275. beta = S(k0, k0) * S(k0, k0 + 1);
  276. }
  277. for (Long k = k0; k < n - 1; k++) {
  278. StaticArray<Long, 2> dimU;
  279. dimU[0] = dim[0];
  280. dimU[1] = dim[0];
  281. StaticArray<Long, 2> dimV;
  282. dimV[0] = dim[1];
  283. dimV[1] = dim[1];
  284. GivensR(S_, dim, k, alpha, beta);
  285. GivensL(V_, dimV, k, alpha, beta);
  286. alpha = S(k, k);
  287. beta = S(k + 1, k);
  288. GivensL(S_, dim, k, alpha, beta);
  289. GivensR(U_, dimU, k, alpha, beta);
  290. alpha = S(k, k + 1);
  291. beta = S(k, k + 2);
  292. }
  293. { // Make S bi-diagonal again
  294. for (Long i0 = k0; i0 < n - 1; i0++) {
  295. for (Long i1 = 0; i1 < dim[1]; i1++) {
  296. if (i0 > i1 || i0 + 1 < i1) S(i0, i1) = 0;
  297. }
  298. }
  299. for (Long i0 = 0; i0 < dim[0]; i0++) {
  300. for (Long i1 = k0; i1 < n - 1; i1++) {
  301. if (i0 > i1 || i0 + 1 < i1) S(i0, i1) = 0;
  302. }
  303. }
  304. for (Long i = 0; i < dim[1] - 1; i++) {
  305. if (fabs<ValueType>(S(i, i + 1)) <= eps * S_max) {
  306. S(i, i + 1) = 0;
  307. }
  308. }
  309. }
  310. // std::cout<<iter<<' '<<k0<<' '<<n<<'\n';
  311. }
  312. { // Check Error
  313. #ifdef SCTL_SVD_DEBUG
  314. Matrix<ValueType> U0(dim[0], dim[0], U_);
  315. Matrix<ValueType> S0(dim[0], dim[1], S_);
  316. Matrix<ValueType> V0(dim[1], dim[1], V_);
  317. Matrix<ValueType> E = M0 - U0 * S0 * V0;
  318. ValueType max_err = 0;
  319. ValueType max_nondiag0 = 0;
  320. ValueType max_nondiag1 = 0;
  321. for (Long i = 0; i < E.Dim(0); i++)
  322. for (Long j = 0; j < E.Dim(1); j++) {
  323. if (max_err < fabs<ValueType>(E[i][j])) max_err = fabs<ValueType>(E[i][j]);
  324. if ((i > j + 0 || i + 0 < j) && max_nondiag0 < fabs<ValueType>(S0[i][j])) max_nondiag0 = fabs<ValueType>(S0[i][j]);
  325. if ((i > j + 1 || i + 1 < j) && max_nondiag1 < fabs<ValueType>(S0[i][j])) max_nondiag1 = fabs<ValueType>(S0[i][j]);
  326. }
  327. std::cout << max_err << '\n';
  328. std::cout << max_nondiag0 << '\n';
  329. std::cout << max_nondiag1 << '\n';
  330. #endif
  331. }
  332. }
  333. #undef SCTL_SVD_DEBUG
  334. template <class ValueType> inline void svd(char *JOBU, char *JOBVT, int *M, int *N, Iterator<ValueType> A, int *LDA, Iterator<ValueType> S, Iterator<ValueType> U, int *LDU, Iterator<ValueType> VT, int *LDVT, Iterator<ValueType> WORK, int *LWORK, int *INFO) {
  335. StaticArray<Long, 2> dim;
  336. dim[0] = std::max(*N, *M);
  337. dim[1] = std::min(*N, *M);
  338. Iterator<ValueType> U_ = aligned_new<ValueType>(dim[0] * dim[0]);
  339. memset(U_, 0, dim[0] * dim[0]);
  340. Iterator<ValueType> V_ = aligned_new<ValueType>(dim[1] * dim[1]);
  341. memset(V_, 0, dim[1] * dim[1]);
  342. Iterator<ValueType> S_ = aligned_new<ValueType>(dim[0] * dim[1]);
  343. const Long lda = *LDA;
  344. const Long ldu = *LDU;
  345. const Long ldv = *LDVT;
  346. if (dim[1] == *M) {
  347. for (Long i = 0; i < dim[0]; i++)
  348. for (Long j = 0; j < dim[1]; j++) {
  349. S_[i * dim[1] + j] = A[i * lda + j];
  350. }
  351. } else {
  352. for (Long i = 0; i < dim[0]; i++)
  353. for (Long j = 0; j < dim[1]; j++) {
  354. S_[i * dim[1] + j] = A[j * lda + i];
  355. }
  356. }
  357. for (Long i = 0; i < dim[0]; i++) {
  358. U_[i * dim[0] + i] = 1;
  359. }
  360. for (Long i = 0; i < dim[1]; i++) {
  361. V_[i * dim[1] + i] = 1;
  362. }
  363. SVD<ValueType>(dim, U_, S_, V_, (ValueType) - 1);
  364. for (Long i = 0; i < dim[1]; i++) { // Set S
  365. S[i] = S_[i * dim[1] + i];
  366. }
  367. if (dim[1] == *M) { // Set U
  368. for (Long i = 0; i < dim[1]; i++)
  369. for (Long j = 0; j < *M; j++) {
  370. U[j + ldu * i] = V_[j + i * dim[1]] * (S[i] < 0.0 ? -1.0 : 1.0);
  371. }
  372. } else {
  373. for (Long i = 0; i < dim[1]; i++)
  374. for (Long j = 0; j < *M; j++) {
  375. U[j + ldu * i] = U_[i + j * dim[0]] * (S[i] < 0.0 ? -1.0 : 1.0);
  376. }
  377. }
  378. if (dim[0] == *N) { // Set V
  379. for (Long i = 0; i < *N; i++)
  380. for (Long j = 0; j < dim[1]; j++) {
  381. VT[j + ldv * i] = U_[j + i * dim[0]];
  382. }
  383. } else {
  384. for (Long i = 0; i < *N; i++)
  385. for (Long j = 0; j < dim[1]; j++) {
  386. VT[j + ldv * i] = V_[i + j * dim[1]];
  387. }
  388. }
  389. for (Long i = 0; i < dim[1]; i++) {
  390. S[i] = S[i] * (S[i] < 0.0 ? -1.0 : 1.0);
  391. }
  392. aligned_delete<ValueType>(U_);
  393. aligned_delete<ValueType>(S_);
  394. aligned_delete<ValueType>(V_);
  395. if (0) { // Verify
  396. StaticArray<Long, 2> dim;
  397. dim[0] = std::max(*N, *M);
  398. dim[1] = std::min(*N, *M);
  399. const Long lda = *LDA;
  400. const Long ldu = *LDU;
  401. const Long ldv = *LDVT;
  402. Matrix<ValueType> A1(*M, *N);
  403. Matrix<ValueType> S1(dim[1], dim[1]);
  404. Matrix<ValueType> U1(*M, dim[1]);
  405. Matrix<ValueType> V1(dim[1], *N);
  406. for (Long i = 0; i < *N; i++)
  407. for (Long j = 0; j < *M; j++) {
  408. A1[j][i] = A[j + i * lda];
  409. }
  410. S1.SetZero();
  411. for (Long i = 0; i < dim[1]; i++) { // Set S
  412. S1[i][i] = S[i];
  413. }
  414. for (Long i = 0; i < dim[1]; i++)
  415. for (Long j = 0; j < *M; j++) {
  416. U1[j][i] = U[j + ldu * i];
  417. }
  418. for (Long i = 0; i < *N; i++)
  419. for (Long j = 0; j < dim[1]; j++) {
  420. V1[j][i] = VT[j + ldv * i];
  421. }
  422. std::cout << U1 *S1 *V1 - A1 << '\n';
  423. }
  424. }
  425. #if defined(SCTL_HAVE_LAPACK)
  426. template <> inline void svd<float>(char *JOBU, char *JOBVT, int *M, int *N, Iterator<float> A, int *LDA, Iterator<float> S, Iterator<float> U, int *LDU, Iterator<float> VT, int *LDVT, Iterator<float> WORK, int *LWORK, int *INFO) { sgesvd_(JOBU, JOBVT, M, N, &A[0], LDA, &S[0], &U[0], LDU, &VT[0], LDVT, &WORK[0], LWORK, INFO); }
  427. template <> inline void svd<double>(char *JOBU, char *JOBVT, int *M, int *N, Iterator<double> A, int *LDA, Iterator<double> S, Iterator<double> U, int *LDU, Iterator<double> VT, int *LDVT, Iterator<double> WORK, int *LWORK, int *INFO) { dgesvd_(JOBU, JOBVT, M, N, &A[0], LDA, &S[0], &U[0], LDU, &VT[0], LDVT, &WORK[0], LWORK, INFO); }
  428. #endif
  429. /**
  430. * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
  431. * and returns the output M_(n2xn1). Original contents of M are destroyed.
  432. */
  433. template <class ValueType> inline void pinv(Iterator<ValueType> M, int n1, int n2, ValueType eps, Iterator<ValueType> M_) {
  434. if (n1 * n2 == 0) return;
  435. int m = n2;
  436. int n = n1;
  437. int k = (m < n ? m : n);
  438. Iterator<ValueType> tU = aligned_new<ValueType>(m * k);
  439. Iterator<ValueType> tS = aligned_new<ValueType>(k);
  440. Iterator<ValueType> tVT = aligned_new<ValueType>(k * n);
  441. // SVD
  442. int INFO = 0;
  443. char JOBU = 'S';
  444. char JOBVT = 'S';
  445. // int wssize = max(3*min(m,n)+max(m,n), 5*min(m,n));
  446. int wssize = 3 * (m < n ? m : n) + (m > n ? m : n);
  447. int wssize1 = 5 * (m < n ? m : n);
  448. wssize = (wssize > wssize1 ? wssize : wssize1);
  449. Iterator<ValueType> wsbuf = aligned_new<ValueType>(wssize);
  450. svd(&JOBU, &JOBVT, &m, &n, M, &m, tS, tU, &m, tVT, &k, wsbuf, &wssize, &INFO);
  451. if (INFO != 0) std::cout << INFO << '\n';
  452. assert(INFO == 0);
  453. aligned_delete<ValueType>(wsbuf);
  454. ValueType eps_ = tS[0] * eps;
  455. for (int i = 0; i < k; i++)
  456. if (tS[i] < eps_)
  457. tS[i] = 0;
  458. else
  459. tS[i] = 1 / tS[i];
  460. for (int i = 0; i < m; i++) {
  461. for (int j = 0; j < k; j++) {
  462. tU[i + j * m] *= tS[j];
  463. }
  464. }
  465. gemm<ValueType>('T', 'T', n, m, k, 1.0, tVT, k, tU, m, 0.0, M_, n);
  466. aligned_delete<ValueType>(tU);
  467. aligned_delete<ValueType>(tS);
  468. aligned_delete<ValueType>(tVT);
  469. }
  470. } // end namespace mat
  471. } // end namespace SCTL_NAMESPACE