mat_utils.txx 17 KB


  1. #include <pvfmm/matrix.hpp>
  2. #if defined(PVFMM_HAVE_CUDA)
  3. #include <cuda_runtime_api.h>
  4. #include <cublas_v2.h>
  5. #endif
  6. #if defined(PVFMM_HAVE_BLAS)
  7. #include <pvfmm/blas.h>
  8. #endif
  9. #if defined(PVFMM_HAVE_LAPACK)
  10. #include <pvfmm/lapack.h>
  11. #endif
  12. #include <omp.h>
  13. #include <cmath>
  14. #include <cassert>
  15. #include <cstdlib>
  16. #include <algorithm>
  17. #include <iostream>
  18. #include <vector>
  19. namespace pvfmm {
  20. namespace mat {
  21. template <class ValueType> inline void gemm(char TransA, char TransB, int M, int N, int K, ValueType alpha, Iterator<ValueType> A, int lda, Iterator<ValueType> B, int ldb, ValueType beta, Iterator<ValueType> C, int ldc) {
  22. if ((TransA == 'N' || TransA == 'n') && (TransB == 'N' || TransB == 'n')) {
  23. for (Long n = 0; n < N; n++) { // Columns of C
  24. for (Long m = 0; m < M; m++) { // Rows of C
  25. ValueType AxB = 0;
  26. for (Long k = 0; k < K; k++) {
  27. AxB += A[m + lda * k] * B[k + ldb * n];
  28. }
  29. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  30. }
  31. }
  32. } else if (TransA == 'N' || TransA == 'n') {
  33. for (Long n = 0; n < N; n++) { // Columns of C
  34. for (Long m = 0; m < M; m++) { // Rows of C
  35. ValueType AxB = 0;
  36. for (Long k = 0; k < K; k++) {
  37. AxB += A[m + lda * k] * B[n + ldb * k];
  38. }
  39. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  40. }
  41. }
  42. } else if (TransB == 'N' || TransB == 'n') {
  43. for (Long n = 0; n < N; n++) { // Columns of C
  44. for (Long m = 0; m < M; m++) { // Rows of C
  45. ValueType AxB = 0;
  46. for (Long k = 0; k < K; k++) {
  47. AxB += A[k + lda * m] * B[k + ldb * n];
  48. }
  49. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  50. }
  51. }
  52. } else {
  53. for (Long n = 0; n < N; n++) { // Columns of C
  54. for (Long m = 0; m < M; m++) { // Rows of C
  55. ValueType AxB = 0;
  56. for (Long k = 0; k < K; k++) {
  57. AxB += A[k + lda * m] * B[n + ldb * k];
  58. }
  59. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  60. }
  61. }
  62. }
  63. }
  #if defined(PVFMM_HAVE_BLAS)
  // BLAS-backed specializations of gemm: every argument is forwarded unchanged to
  // sgemm_ / dgemm_ (declared in pvfmm/blas.h), with all scalars passed by address.
  template <> inline void gemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, Iterator<float> A, int lda, Iterator<float> B, int ldb, float beta, Iterator<float> C, int ldc) {
    float *a_ptr = &A[0];
    float *b_ptr = &B[0];
    float *c_ptr = &C[0];
    sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
  }
  template <> inline void gemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, Iterator<double> A, int lda, Iterator<double> B, int ldb, double beta, Iterator<double> C, int ldc) {
    double *a_ptr = &A[0];
    double *b_ptr = &B[0];
    double *c_ptr = &C[0];
    dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
  }
  #endif
  #if defined(PVFMM_HAVE_CUDA)
  /**
   * \brief Converts a BLAS transpose flag into the matching cuBLAS operation.
   *
   * 'N'/'n' -> CUBLAS_OP_N, 'T'/'t' -> CUBLAS_OP_T, 'C'/'c' -> CUBLAS_OP_C.
   * Any other character is a caller error: it fails an assert in debug builds and is
   * mapped to CUBLAS_OP_N otherwise, so cuBLAS never receives an indeterminate value.
   */
  static inline cublasOperation_t cublas_trans_op(char trans) {
    switch (trans) {
      case 'N':
      case 'n':
        return CUBLAS_OP_N;
      case 'T':
      case 't':
        return CUBLAS_OP_T;
      case 'C':
      case 'c':
        return CUBLAS_OP_C;
      default:
        assert(false && "cublasgemm: invalid transpose flag");
        return CUBLAS_OP_N;
    }
  }

  /**
   * \brief Reports a failed cuBLAS call on std::cerr and asserts success.
   * \param status   value returned by the cuBLAS routine.
   * \param routine  name of the routine, used in the message.
   */
  static inline void cublas_check_status(cublasStatus_t status, const char *routine) {
    if (status != CUBLAS_STATUS_SUCCESS) {
      std::cerr << routine << " failed with cublasStatus_t " << static_cast<int>(status) << '\n';
    }
    assert(status == CUBLAS_STATUS_SUCCESS);
  }

  /**
   * \brief C := alpha * op(A) * op(B) + beta * C via cublasSgemm (column-major, like gemm).
   *
   * Uses the handle returned by CUDA_Lock::acquire_handle(). A, B and C are handed to
   * cuBLAS as-is, so they are presumably device pointers (NOTE(review): confirm at call sites).
   */
  template <> inline void cublasgemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, Iterator<float> A, int lda, Iterator<float> B, int ldb, float beta, Iterator<float> C, int ldc) {
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    const cublasOperation_t opA = cublas_trans_op(TransA);
    const cublasOperation_t opB = cublas_trans_op(TransB);
    cublasStatus_t status = cublasSgemm(*handle, opA, opB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    cublas_check_status(status, "cublasSgemm");
  }

  /**
   * \brief Double-precision counterpart of cublasgemm<float> (calls cublasDgemm).
   */
  template <> inline void cublasgemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, Iterator<double> A, int lda, Iterator<double> B, int ldb, double beta, Iterator<double> C, int ldc) {
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    const cublasOperation_t opA = cublas_trans_op(TransA);
    const cublasOperation_t opB = cublas_trans_op(TransB);
    cublasStatus_t status = cublasDgemm(*handle, opA, opB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    cublas_check_status(status, "cublasDgemm");
  }
  #endif
  // Row-major element accessors used by GivensL, GivensR and SVD. Each one expands
  // against whatever `dim` array and U_/S_/V_ pointer are in scope where it is used:
  //   U(i, j): row stride dim[0]   S(i, j): row stride dim[1]   V(i, j): row stride dim[1]
  #define U(i, j) U_[(j) + (i) * dim[0]]
  #define S(i, j) S_[(j) + (i) * dim[1]]
  #define V(i, j) V_[(j) + (i) * dim[1]]
  // Uncomment to make SVD() print its reconstruction and off-diagonal errors.
  //#define SVD_DEBUG
  100. template <class ValueType> static inline void GivensL(Iterator<ValueType> S_, StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
  101. ValueType r = pvfmm::sqrt<ValueType>(a * a + b * b);
  102. ValueType c = a / r;
  103. ValueType s = -b / r;
  104. #pragma omp parallel for
  105. for (Long i = 0; i < dim[1]; i++) {
  106. ValueType S0 = S(m + 0, i);
  107. ValueType S1 = S(m + 1, i);
  108. S(m, i) += S0 * (c - 1);
  109. S(m, i) += S1 * (-s);
  110. S(m + 1, i) += S0 * (s);
  111. S(m + 1, i) += S1 * (c - 1);
  112. }
  113. }
  114. template <class ValueType> static inline void GivensR(Iterator<ValueType> S_, StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
  115. ValueType r = pvfmm::sqrt<ValueType>(a * a + b * b);
  116. ValueType c = a / r;
  117. ValueType s = -b / r;
  118. #pragma omp parallel for
  119. for (Long i = 0; i < dim[0]; i++) {
  120. ValueType S0 = S(i, m + 0);
  121. ValueType S1 = S(i, m + 1);
  122. S(i, m) += S0 * (c - 1);
  123. S(i, m) += S1 * (-s);
  124. S(i, m + 1) += S0 * (s);
  125. S(i, m + 1) += S1 * (c - 1);
  126. }
  127. }
  128. template <class ValueType> static inline void SVD(StaticArray<Long, 2> &dim, Iterator<ValueType> U_, Iterator<ValueType> S_, Iterator<ValueType> V_, ValueType eps = -1) {
  129. assert(dim[0] >= dim[1]);
  130. #ifdef SVD_DEBUG
  131. Matrix<ValueType> M0(dim[0], dim[1], S_);
  132. #endif
  133. { // Bi-diagonalization
  134. Long n = std::min(dim[0], dim[1]);
  135. std::vector<ValueType> house_vec(std::max(dim[0], dim[1]));
  136. for (Long i = 0; i < n; i++) {
  137. // Column Householder
  138. {
  139. ValueType x1 = S(i, i);
  140. if (x1 < 0) x1 = -x1;
  141. ValueType x_inv_norm = 0;
  142. for (Long j = i; j < dim[0]; j++) {
  143. x_inv_norm += S(j, i) * S(j, i);
  144. }
  145. if (x_inv_norm > 0) x_inv_norm = 1 / pvfmm::sqrt<ValueType>(x_inv_norm);
  146. ValueType alpha = pvfmm::sqrt<ValueType>(1 + x1 * x_inv_norm);
  147. ValueType beta = x_inv_norm / alpha;
  148. house_vec[i] = -alpha;
  149. for (Long j = i + 1; j < dim[0]; j++) {
  150. house_vec[j] = -beta * S(j, i);
  151. }
  152. if (S(i, i) < 0)
  153. for (Long j = i + 1; j < dim[0]; j++) {
  154. house_vec[j] = -house_vec[j];
  155. }
  156. }
  157. #pragma omp parallel for
  158. for (Long k = i; k < dim[1]; k++) {
  159. ValueType dot_prod = 0;
  160. for (Long j = i; j < dim[0]; j++) {
  161. dot_prod += S(j, k) * house_vec[j];
  162. }
  163. for (Long j = i; j < dim[0]; j++) {
  164. S(j, k) -= dot_prod * house_vec[j];
  165. }
  166. }
  167. #pragma omp parallel for
  168. for (Long k = 0; k < dim[0]; k++) {
  169. ValueType dot_prod = 0;
  170. for (Long j = i; j < dim[0]; j++) {
  171. dot_prod += U(k, j) * house_vec[j];
  172. }
  173. for (Long j = i; j < dim[0]; j++) {
  174. U(k, j) -= dot_prod * house_vec[j];
  175. }
  176. }
  177. // Row Householder
  178. if (i >= n - 1) continue;
  179. {
  180. ValueType x1 = S(i, i + 1);
  181. if (x1 < 0) x1 = -x1;
  182. ValueType x_inv_norm = 0;
  183. for (Long j = i + 1; j < dim[1]; j++) {
  184. x_inv_norm += S(i, j) * S(i, j);
  185. }
  186. if (x_inv_norm > 0) x_inv_norm = 1 / pvfmm::sqrt<ValueType>(x_inv_norm);
  187. ValueType alpha = pvfmm::sqrt<ValueType>(1 + x1 * x_inv_norm);
  188. ValueType beta = x_inv_norm / alpha;
  189. house_vec[i + 1] = -alpha;
  190. for (Long j = i + 2; j < dim[1]; j++) {
  191. house_vec[j] = -beta * S(i, j);
  192. }
  193. if (S(i, i + 1) < 0)
  194. for (Long j = i + 2; j < dim[1]; j++) {
  195. house_vec[j] = -house_vec[j];
  196. }
  197. }
  198. #pragma omp parallel for
  199. for (Long k = i; k < dim[0]; k++) {
  200. ValueType dot_prod = 0;
  201. for (Long j = i + 1; j < dim[1]; j++) {
  202. dot_prod += S(k, j) * house_vec[j];
  203. }
  204. for (Long j = i + 1; j < dim[1]; j++) {
  205. S(k, j) -= dot_prod * house_vec[j];
  206. }
  207. }
  208. #pragma omp parallel for
  209. for (Long k = 0; k < dim[1]; k++) {
  210. ValueType dot_prod = 0;
  211. for (Long j = i + 1; j < dim[1]; j++) {
  212. dot_prod += V(j, k) * house_vec[j];
  213. }
  214. for (Long j = i + 1; j < dim[1]; j++) {
  215. V(j, k) -= dot_prod * house_vec[j];
  216. }
  217. }
  218. }
  219. }
  220. Long k0 = 0;
  221. Long iter = 0;
  222. if (eps < 0) {
  223. eps = 1.0;
  224. while (eps + (ValueType)1.0 > 1.0) eps *= 0.5;
  225. eps *= 64.0;
  226. }
  227. while (k0 < dim[1] - 1) { // Diagonalization
  228. iter++;
  229. ValueType S_max = 0.0;
  230. for (Long i = 0; i < dim[1]; i++) S_max = (S_max > S(i, i) ? S_max : S(i, i));
  231. // while(k0<dim[1]-1 && pvfmm::fabs<ValueType>(S(k0,k0+1))<=eps*(pvfmm::fabs<ValueType>(S(k0,k0))+pvfmm::fabs<ValueType>(S(k0+1,k0+1)))) k0++;
  232. while (k0 < dim[1] - 1 && pvfmm::fabs<ValueType>(S(k0, k0 + 1)) <= eps * S_max) k0++;
  233. if (k0 == dim[1] - 1) continue;
  234. Long n = k0 + 2;
  235. // while(n<dim[1] && pvfmm::fabs<ValueType>(S(n-1,n))>eps*(pvfmm::fabs<ValueType>(S(n-1,n-1))+pvfmm::fabs<ValueType>(S(n,n)))) n++;
  236. while (n < dim[1] && pvfmm::fabs<ValueType>(S(n - 1, n)) > eps * S_max) n++;
  237. ValueType alpha = 0;
  238. ValueType beta = 0;
  239. { // Compute mu
  240. StaticArray<ValueType, 2 * 2> C;
  241. C[0 * 2 + 0] = S(n - 2, n - 2) * S(n - 2, n - 2);
  242. if (n - k0 > 2) C[0 * 2 + 0] += S(n - 3, n - 2) * S(n - 3, n - 2);
  243. C[0 * 2 + 1] = S(n - 2, n - 2) * S(n - 2, n - 1);
  244. C[1 * 2 + 0] = S(n - 2, n - 2) * S(n - 2, n - 1);
  245. C[1 * 2 + 1] = S(n - 1, n - 1) * S(n - 1, n - 1) + S(n - 2, n - 1) * S(n - 2, n - 1);
  246. ValueType b = -(C[0 * 2 + 0] + C[1 * 2 + 1]) / 2;
  247. ValueType c = C[0 * 2 + 0] * C[1 * 2 + 1] - C[0 * 2 + 1] * C[1 * 2 + 0];
  248. ValueType d = 0;
  249. if (b * b - c > 0)
  250. d = pvfmm::sqrt<ValueType>(b * b - c);
  251. else {
  252. ValueType b = (C[0 * 2 + 0] - C[1 * 2 + 1]) / 2;
  253. ValueType c = -C[0 * 2 + 1] * C[1 * 2 + 0];
  254. if (b * b - c > 0) d = pvfmm::sqrt<ValueType>(b * b - c);
  255. }
  256. ValueType lambda1 = -b + d;
  257. ValueType lambda2 = -b - d;
  258. ValueType d1 = lambda1 - C[1 * 2 + 1];
  259. d1 = (d1 < 0 ? -d1 : d1);
  260. ValueType d2 = lambda2 - C[1 * 2 + 1];
  261. d2 = (d2 < 0 ? -d2 : d2);
  262. ValueType mu = (d1 < d2 ? lambda1 : lambda2);
  263. alpha = S(k0, k0) * S(k0, k0) - mu;
  264. beta = S(k0, k0) * S(k0, k0 + 1);
  265. }
  266. for (Long k = k0; k < n - 1; k++) {
  267. StaticArray<Long, 2> dimU;
  268. dimU[0] = dim[0];
  269. dimU[1] = dim[0];
  270. StaticArray<Long, 2> dimV;
  271. dimV[0] = dim[1];
  272. dimV[1] = dim[1];
  273. GivensR(S_, dim, k, alpha, beta);
  274. GivensL(V_, dimV, k, alpha, beta);
  275. alpha = S(k, k);
  276. beta = S(k + 1, k);
  277. GivensL(S_, dim, k, alpha, beta);
  278. GivensR(U_, dimU, k, alpha, beta);
  279. alpha = S(k, k + 1);
  280. beta = S(k, k + 2);
  281. }
  282. { // Make S bi-diagonal again
  283. for (Long i0 = k0; i0 < n - 1; i0++) {
  284. for (Long i1 = 0; i1 < dim[1]; i1++) {
  285. if (i0 > i1 || i0 + 1 < i1) S(i0, i1) = 0;
  286. }
  287. }
  288. for (Long i0 = 0; i0 < dim[0]; i0++) {
  289. for (Long i1 = k0; i1 < n - 1; i1++) {
  290. if (i0 > i1 || i0 + 1 < i1) S(i0, i1) = 0;
  291. }
  292. }
  293. for (Long i = 0; i < dim[1] - 1; i++) {
  294. if (pvfmm::fabs<ValueType>(S(i, i + 1)) <= eps * S_max) {
  295. S(i, i + 1) = 0;
  296. }
  297. }
  298. }
  299. // std::cout<<iter<<' '<<k0<<' '<<n<<'\n';
  300. }
  301. { // Check Error
  302. #ifdef SVD_DEBUG
  303. Matrix<ValueType> U0(dim[0], dim[0], U_);
  304. Matrix<ValueType> S0(dim[0], dim[1], S_);
  305. Matrix<ValueType> V0(dim[1], dim[1], V_);
  306. Matrix<ValueType> E = M0 - U0 * S0 * V0;
  307. ValueType max_err = 0;
  308. ValueType max_nondiag0 = 0;
  309. ValueType max_nondiag1 = 0;
  310. for (Long i = 0; i < E.Dim(0); i++)
  311. for (Long j = 0; j < E.Dim(1); j++) {
  312. if (max_err < pvfmm::fabs<ValueType>(E[i][j])) max_err = pvfmm::fabs<ValueType>(E[i][j]);
  313. if ((i > j + 0 || i + 0 < j) && max_nondiag0 < pvfmm::fabs<ValueType>(S0[i][j])) max_nondiag0 = pvfmm::fabs<ValueType>(S0[i][j]);
  314. if ((i > j + 1 || i + 1 < j) && max_nondiag1 < pvfmm::fabs<ValueType>(S0[i][j])) max_nondiag1 = pvfmm::fabs<ValueType>(S0[i][j]);
  315. }
  316. std::cout << max_err << '\n';
  317. std::cout << max_nondiag0 << '\n';
  318. std::cout << max_nondiag1 << '\n';
  319. #endif
  320. }
  321. }
  // The accessor macros and the optional debug switch are private to the SVD
  // implementation; remove them so they cannot leak into the code that follows.
  #undef SVD_DEBUG
  #undef V
  #undef S
  #undef U
  326. template <class ValueType> inline void svd(char *JOBU, char *JOBVT, int *M, int *N, Iterator<ValueType> A, int *LDA, Iterator<ValueType> S, Iterator<ValueType> U, int *LDU, Iterator<ValueType> VT, int *LDVT, Iterator<ValueType> WORK, int *LWORK, int *INFO) {
  327. StaticArray<Long, 2> dim;
  328. dim[0] = std::max(*N, *M);
  329. dim[1] = std::min(*N, *M);
  330. Iterator<ValueType> U_ = aligned_new<ValueType>(dim[0] * dim[0]);
  331. memset(U_, 0, dim[0] * dim[0]);
  332. Iterator<ValueType> V_ = aligned_new<ValueType>(dim[1] * dim[1]);
  333. memset(V_, 0, dim[1] * dim[1]);
  334. Iterator<ValueType> S_ = aligned_new<ValueType>(dim[0] * dim[1]);
  335. const Long lda = *LDA;
  336. const Long ldu = *LDU;
  337. const Long ldv = *LDVT;
  338. if (dim[1] == *M) {
  339. for (Long i = 0; i < dim[0]; i++)
  340. for (Long j = 0; j < dim[1]; j++) {
  341. S_[i * dim[1] + j] = A[i * lda + j];
  342. }
  343. } else {
  344. for (Long i = 0; i < dim[0]; i++)
  345. for (Long j = 0; j < dim[1]; j++) {
  346. S_[i * dim[1] + j] = A[j * lda + i];
  347. }
  348. }
  349. for (Long i = 0; i < dim[0]; i++) {
  350. U_[i * dim[0] + i] = 1;
  351. }
  352. for (Long i = 0; i < dim[1]; i++) {
  353. V_[i * dim[1] + i] = 1;
  354. }
  355. SVD<ValueType>(dim, U_, S_, V_, (ValueType) - 1);
  356. for (Long i = 0; i < dim[1]; i++) { // Set S
  357. S[i] = S_[i * dim[1] + i];
  358. }
  359. if (dim[1] == *M) { // Set U
  360. for (Long i = 0; i < dim[1]; i++)
  361. for (Long j = 0; j < *M; j++) {
  362. U[j + ldu * i] = V_[j + i * dim[1]] * (S[i] < 0.0 ? -1.0 : 1.0);
  363. }
  364. } else {
  365. for (Long i = 0; i < dim[1]; i++)
  366. for (Long j = 0; j < *M; j++) {
  367. U[j + ldu * i] = U_[i + j * dim[0]] * (S[i] < 0.0 ? -1.0 : 1.0);
  368. }
  369. }
  370. if (dim[0] == *N) { // Set V
  371. for (Long i = 0; i < *N; i++)
  372. for (Long j = 0; j < dim[1]; j++) {
  373. VT[j + ldv * i] = U_[j + i * dim[0]];
  374. }
  375. } else {
  376. for (Long i = 0; i < *N; i++)
  377. for (Long j = 0; j < dim[1]; j++) {
  378. VT[j + ldv * i] = V_[i + j * dim[1]];
  379. }
  380. }
  381. for (Long i = 0; i < dim[1]; i++) {
  382. S[i] = S[i] * (S[i] < 0.0 ? -1.0 : 1.0);
  383. }
  384. aligned_delete<ValueType>(U_);
  385. aligned_delete<ValueType>(S_);
  386. aligned_delete<ValueType>(V_);
  387. if (0) { // Verify
  388. StaticArray<Long, 2> dim;
  389. dim[0] = std::max(*N, *M);
  390. dim[1] = std::min(*N, *M);
  391. const Long lda = *LDA;
  392. const Long ldu = *LDU;
  393. const Long ldv = *LDVT;
  394. Matrix<ValueType> A1(*M, *N);
  395. Matrix<ValueType> S1(dim[1], dim[1]);
  396. Matrix<ValueType> U1(*M, dim[1]);
  397. Matrix<ValueType> V1(dim[1], *N);
  398. for (Long i = 0; i < *N; i++)
  399. for (Long j = 0; j < *M; j++) {
  400. A1[j][i] = A[j + i * lda];
  401. }
  402. S1.SetZero();
  403. for (Long i = 0; i < dim[1]; i++) { // Set S
  404. S1[i][i] = S[i];
  405. }
  406. for (Long i = 0; i < dim[1]; i++)
  407. for (Long j = 0; j < *M; j++) {
  408. U1[j][i] = U[j + ldu * i];
  409. }
  410. for (Long i = 0; i < *N; i++)
  411. for (Long j = 0; j < dim[1]; j++) {
  412. V1[j][i] = VT[j + ldv * i];
  413. }
  414. std::cout << U1 *S1 *V1 - A1 << '\n';
  415. }
  416. }
  #if defined(PVFMM_HAVE_LAPACK)
  // LAPACK-backed specializations of svd: every argument is forwarded unchanged to
  // sgesvd_ / dgesvd_ (declared in pvfmm/lapack.h). Per the LAPACK documentation these
  // honour JOBU/JOBVT and WORK/LWORK, and overwrite A.
  template <> inline void svd<float>(char *JOBU, char *JOBVT, int *M, int *N, Iterator<float> A, int *LDA, Iterator<float> S, Iterator<float> U, int *LDU, Iterator<float> VT, int *LDVT, Iterator<float> WORK, int *LWORK, int *INFO) {
    float *a_ptr = &A[0];
    float *s_ptr = &S[0];
    float *u_ptr = &U[0];
    float *vt_ptr = &VT[0];
    float *work_ptr = &WORK[0];
    sgesvd_(JOBU, JOBVT, M, N, a_ptr, LDA, s_ptr, u_ptr, LDU, vt_ptr, LDVT, work_ptr, LWORK, INFO);
  }
  template <> inline void svd<double>(char *JOBU, char *JOBVT, int *M, int *N, Iterator<double> A, int *LDA, Iterator<double> S, Iterator<double> U, int *LDU, Iterator<double> VT, int *LDVT, Iterator<double> WORK, int *LWORK, int *INFO) {
    double *a_ptr = &A[0];
    double *s_ptr = &S[0];
    double *u_ptr = &U[0];
    double *vt_ptr = &VT[0];
    double *work_ptr = &WORK[0];
    dgesvd_(JOBU, JOBVT, M, N, a_ptr, LDA, s_ptr, u_ptr, LDU, vt_ptr, LDVT, work_ptr, LWORK, INFO);
  }
  #endif
  421. /**
  422. * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
  423. * and returns the output M_(n2xn1). Original contents of M are destroyed.
  424. */
  425. template <class ValueType> inline void pinv(Iterator<ValueType> M, int n1, int n2, ValueType eps, Iterator<ValueType> M_) {
  426. if (n1 * n2 == 0) return;
  427. int m = n2;
  428. int n = n1;
  429. int k = (m < n ? m : n);
  430. Iterator<ValueType> tU = aligned_new<ValueType>(m * k);
  431. Iterator<ValueType> tS = aligned_new<ValueType>(k);
  432. Iterator<ValueType> tVT = aligned_new<ValueType>(k * n);
  433. // SVD
  434. int INFO = 0;
  435. char JOBU = 'S';
  436. char JOBVT = 'S';
  437. // int wssize = max(3*min(m,n)+max(m,n), 5*min(m,n));
  438. int wssize = 3 * (m < n ? m : n) + (m > n ? m : n);
  439. int wssize1 = 5 * (m < n ? m : n);
  440. wssize = (wssize > wssize1 ? wssize : wssize1);
  441. Iterator<ValueType> wsbuf = aligned_new<ValueType>(wssize);
  442. svd(&JOBU, &JOBVT, &m, &n, M, &m, tS, tU, &m, tVT, &k, wsbuf, &wssize, &INFO);
  443. if (INFO != 0) std::cout << INFO << '\n';
  444. assert(INFO == 0);
  445. aligned_delete<ValueType>(wsbuf);
  446. ValueType eps_ = tS[0] * eps;
  447. for (int i = 0; i < k; i++)
  448. if (tS[i] < eps_)
  449. tS[i] = 0;
  450. else
  451. tS[i] = 1.0 / tS[i];
  452. for (int i = 0; i < m; i++) {
  453. for (int j = 0; j < k; j++) {
  454. tU[i + j * m] *= tS[j];
  455. }
  456. }
  457. gemm<ValueType>('T', 'T', n, m, k, 1.0, tVT, k, tU, m, 0.0, M_, n);
  458. aligned_delete<ValueType>(tU);
  459. aligned_delete<ValueType>(tS);
  460. aligned_delete<ValueType>(tVT);
  461. }
  462. } // end namespace mat
  463. } // end namespace pvfmm