mat_utils.txx 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. #include SCTL_INCLUDE(matrix.hpp)
  2. #if defined(SCTL_HAVE_CUDA)
  3. #include <cuda_runtime_api.h>
  4. #include <cublas_v2.h>
  5. #endif
  6. #if defined(SCTL_HAVE_BLAS)
  7. #include SCTL_INCLUDE(blas.h)
  8. #endif
  9. #if defined(SCTL_HAVE_LAPACK)
  10. #include SCTL_INCLUDE(lapack.h)
  11. #endif
  12. #include <omp.h>
  13. #include <cmath>
  14. #include <cassert>
  15. #include <cstdlib>
  16. #include <algorithm>
  17. #include <iostream>
  18. #include <vector>
  19. namespace SCTL_NAMESPACE {
  20. namespace mat {
  21. template <class ValueType> inline void gemm(char TransA, char TransB, int M, int N, int K, ValueType alpha, Iterator<ValueType> A, int lda, Iterator<ValueType> B, int ldb, ValueType beta, Iterator<ValueType> C, int ldc) {
  22. if ((TransA == 'N' || TransA == 'n') && (TransB == 'N' || TransB == 'n')) {
  23. for (Long n = 0; n < N; n++) { // Columns of C
  24. for (Long m = 0; m < M; m++) { // Rows of C
  25. ValueType AxB = 0;
  26. for (Long k = 0; k < K; k++) {
  27. AxB += A[m + lda * k] * B[k + ldb * n];
  28. }
  29. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  30. }
  31. }
  32. } else if (TransA == 'N' || TransA == 'n') {
  33. for (Long n = 0; n < N; n++) { // Columns of C
  34. for (Long m = 0; m < M; m++) { // Rows of C
  35. ValueType AxB = 0;
  36. for (Long k = 0; k < K; k++) {
  37. AxB += A[m + lda * k] * B[n + ldb * k];
  38. }
  39. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  40. }
  41. }
  42. } else if (TransB == 'N' || TransB == 'n') {
  43. for (Long n = 0; n < N; n++) { // Columns of C
  44. for (Long m = 0; m < M; m++) { // Rows of C
  45. ValueType AxB = 0;
  46. for (Long k = 0; k < K; k++) {
  47. AxB += A[k + lda * m] * B[k + ldb * n];
  48. }
  49. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  50. }
  51. }
  52. } else {
  53. for (Long n = 0; n < N; n++) { // Columns of C
  54. for (Long m = 0; m < M; m++) { // Rows of C
  55. ValueType AxB = 0;
  56. for (Long k = 0; k < K; k++) {
  57. AxB += A[k + lda * m] * B[n + ldb * k];
  58. }
  59. C[m + ldc * n] = alpha * AxB + (beta == 0 ? 0 : beta * C[m + ldc * n]);
  60. }
  61. }
  62. }
  63. }
  #if defined(SCTL_HAVE_BLAS)
  // With BLAS available, float/double GEMM is forwarded to the Fortran routines
  // sgemm_/dgemm_. Every argument is passed by address, and iterators are passed
  // as the address of their first element.
  template <> inline void gemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, Iterator<float> A, int lda, Iterator<float> B, int ldb, float beta, Iterator<float> C, int ldc) {
    sgemm_(&TransA, &TransB,     // op(A), op(B) flags
           &M, &N, &K,           // op(A) is M x K, op(B) is K x N
           &alpha, &A[0], &lda,  // alpha, A and its leading dimension
           &B[0], &ldb,          // B and its leading dimension
           &beta, &C[0], &ldc);  // beta, C and its leading dimension
  }
  template <> inline void gemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, Iterator<double> A, int lda, Iterator<double> B, int ldb, double beta, Iterator<double> C, int ldc) {
    dgemm_(&TransA, &TransB,     // op(A), op(B) flags
           &M, &N, &K,           // op(A) is M x K, op(B) is K x N
           &alpha, &A[0], &lda,  // alpha, A and its leading dimension
           &B[0], &ldb,          // B and its leading dimension
           &beta, &C[0], &ldc);  // beta, C and its leading dimension
  }
  #endif
  #if defined(SCTL_HAVE_CUDA)
  /**
   * Translates a BLAS-style transpose flag into a cuBLAS operation:
   * 'N'/'n' -> CUBLAS_OP_N, 'T'/'t' -> CUBLAS_OP_T, 'C'/'c' -> CUBLAS_OP_C.
   * Any other flag is a caller error. It asserts in debug builds and otherwise
   * falls back to CUBLAS_OP_N, so cuBLAS never receives an indeterminate value.
   */
  inline cublasOperation_t cublas_trans_op(char trans) {
    switch (trans) {
      case 'N':
      case 'n':
        return CUBLAS_OP_N;
      case 'T':
      case 't':
        return CUBLAS_OP_T;
      case 'C':
      case 'c':
        return CUBLAS_OP_C;
      default:
        assert(false && "cublasgemm: invalid transpose flag");
        return CUBLAS_OP_N;
    }
  }
  /**
   * Single-precision cuBLAS GEMM: C = alpha * op(A) * op(B) + beta * C (column-major).
   * Uses the handle from CUDA_Lock::acquire_handle(). A, B and C are passed to
   * cublasSgemm unchanged; presumably they are device pointers (TODO confirm with callers).
   * A non-success cuBLAS status triggers an assert in debug builds.
   */
  template <> inline void cublasgemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, Iterator<float> A, int lda, Iterator<float> B, int ldb, float beta, Iterator<float> C, int ldc) {
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    const cublasStatus_t status = cublasSgemm(*handle, cublas_trans_op(TransA), cublas_trans_op(TransB), M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    assert(status == CUBLAS_STATUS_SUCCESS);
    (void)status;  // only consumed by the assert
  }
  /**
   * Double-precision cuBLAS GEMM: C = alpha * op(A) * op(B) + beta * C (column-major).
   * Uses the handle from CUDA_Lock::acquire_handle(). A, B and C are passed to
   * cublasDgemm unchanged; presumably they are device pointers (TODO confirm with callers).
   * A non-success cuBLAS status triggers an assert in debug builds.
   */
  template <> inline void cublasgemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, Iterator<double> A, int lda, Iterator<double> B, int ldb, double beta, Iterator<double> C, int ldc) {
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    const cublasStatus_t status = cublasDgemm(*handle, cublas_trans_op(TransA), cublas_trans_op(TransB), M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    assert(status == CUBLAS_STATUS_SUCCESS);
    (void)status;  // only consumed by the assert
  }
  #endif
  // Row-major element accessors used by GivensL, GivensR and SVD. They expand
  // textually, so they bind to whatever U_, S_, V_ and dim are in scope where
  // they are used. U strides by dim[0]; S and V stride by dim[1].
  #define U(i, j) U_[(j) + dim[0] * (i)]
  #define S(i, j) S_[(j) + dim[1] * (i)]
  #define V(i, j) V_[(j) + dim[1] * (i)]
  // Uncomment to make SVD() print its reconstruction and off-diagonal errors.
  //#define SVD_DEBUG
  100. template <class ValueType> static inline void GivensL(Iterator<ValueType> S_, StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
  101. ValueType r = sqrt<ValueType>(a * a + b * b);
  102. if (r == 0) return;
  103. ValueType c = a / r;
  104. ValueType s = -b / r;
  105. #pragma omp parallel for
  106. for (Long i = 0; i < dim[1]; i++) {
  107. ValueType S0 = S(m + 0, i);
  108. ValueType S1 = S(m + 1, i);
  109. S(m, i) += S0 * (c - 1);
  110. S(m, i) += S1 * (-s);
  111. S(m + 1, i) += S0 * (s);
  112. S(m + 1, i) += S1 * (c - 1);
  113. }
  114. }
  115. template <class ValueType> static inline void GivensR(Iterator<ValueType> S_, StaticArray<Long, 2> &dim, Long m, ValueType a, ValueType b) {
  116. ValueType r = sqrt<ValueType>(a * a + b * b);
  117. if (r == 0) return;
  118. ValueType c = a / r;
  119. ValueType s = -b / r;
  120. #pragma omp parallel for
  121. for (Long i = 0; i < dim[0]; i++) {
  122. ValueType S0 = S(i, m + 0);
  123. ValueType S1 = S(i, m + 1);
  124. S(i, m) += S0 * (c - 1);
  125. S(i, m) += S1 * (-s);
  126. S(i, m + 1) += S0 * (s);
  127. S(i, m + 1) += S1 * (c - 1);
  128. }
  129. }
  130. template <class ValueType> static inline void SVD(StaticArray<Long, 2> &dim, Iterator<ValueType> U_, Iterator<ValueType> S_, Iterator<ValueType> V_, ValueType eps = -1) {
  131. assert(dim[0] >= dim[1]);
  132. #ifdef SVD_DEBUG
  133. Matrix<ValueType> M0(dim[0], dim[1], S_);
  134. #endif
  135. { // Bi-diagonalization
  136. Long n = std::min(dim[0], dim[1]);
  137. std::vector<ValueType> house_vec(std::max(dim[0], dim[1]));
  138. for (Long i = 0; i < n; i++) {
  139. // Column Householder
  140. {
  141. ValueType x1 = S(i, i);
  142. if (x1 < 0) x1 = -x1;
  143. ValueType x_inv_norm = 0;
  144. for (Long j = i; j < dim[0]; j++) {
  145. x_inv_norm += S(j, i) * S(j, i);
  146. }
  147. if (x_inv_norm > 0) x_inv_norm = 1 / sqrt<ValueType>(x_inv_norm);
  148. ValueType alpha = sqrt<ValueType>(1 + x1 * x_inv_norm);
  149. ValueType beta = x_inv_norm / alpha;
  150. if (x_inv_norm == 0) alpha = 0; // nothing to do
  151. house_vec[i] = -alpha;
  152. for (Long j = i + 1; j < dim[0]; j++) {
  153. house_vec[j] = -beta * S(j, i);
  154. }
  155. if (S(i, i) < 0)
  156. for (Long j = i + 1; j < dim[0]; j++) {
  157. house_vec[j] = -house_vec[j];
  158. }
  159. }
  160. #pragma omp parallel for
  161. for (Long k = i; k < dim[1]; k++) {
  162. ValueType dot_prod = 0;
  163. for (Long j = i; j < dim[0]; j++) {
  164. dot_prod += S(j, k) * house_vec[j];
  165. }
  166. for (Long j = i; j < dim[0]; j++) {
  167. S(j, k) -= dot_prod * house_vec[j];
  168. }
  169. }
  170. #pragma omp parallel for
  171. for (Long k = 0; k < dim[0]; k++) {
  172. ValueType dot_prod = 0;
  173. for (Long j = i; j < dim[0]; j++) {
  174. dot_prod += U(k, j) * house_vec[j];
  175. }
  176. for (Long j = i; j < dim[0]; j++) {
  177. U(k, j) -= dot_prod * house_vec[j];
  178. }
  179. }
  180. // Row Householder
  181. if (i >= n - 1) continue;
  182. {
  183. ValueType x1 = S(i, i + 1);
  184. if (x1 < 0) x1 = -x1;
  185. ValueType x_inv_norm = 0;
  186. for (Long j = i + 1; j < dim[1]; j++) {
  187. x_inv_norm += S(i, j) * S(i, j);
  188. }
  189. if (x_inv_norm > 0) x_inv_norm = 1 / sqrt<ValueType>(x_inv_norm);
  190. ValueType alpha = sqrt<ValueType>(1 + x1 * x_inv_norm);
  191. ValueType beta = x_inv_norm / alpha;
  192. if (x_inv_norm == 0) alpha = 0; // nothing to do
  193. house_vec[i + 1] = -alpha;
  194. for (Long j = i + 2; j < dim[1]; j++) {
  195. house_vec[j] = -beta * S(i, j);
  196. }
  197. if (S(i, i + 1) < 0)
  198. for (Long j = i + 2; j < dim[1]; j++) {
  199. house_vec[j] = -house_vec[j];
  200. }
  201. }
  202. #pragma omp parallel for
  203. for (Long k = i; k < dim[0]; k++) {
  204. ValueType dot_prod = 0;
  205. for (Long j = i + 1; j < dim[1]; j++) {
  206. dot_prod += S(k, j) * house_vec[j];
  207. }
  208. for (Long j = i + 1; j < dim[1]; j++) {
  209. S(k, j) -= dot_prod * house_vec[j];
  210. }
  211. }
  212. #pragma omp parallel for
  213. for (Long k = 0; k < dim[1]; k++) {
  214. ValueType dot_prod = 0;
  215. for (Long j = i + 1; j < dim[1]; j++) {
  216. dot_prod += V(j, k) * house_vec[j];
  217. }
  218. for (Long j = i + 1; j < dim[1]; j++) {
  219. V(j, k) -= dot_prod * house_vec[j];
  220. }
  221. }
  222. }
  223. }
  224. Long k0 = 0;
  225. Long iter = 0;
  226. if (eps < 0) {
  227. eps = 1.0;
  228. while (eps + (ValueType)1.0 > 1.0) eps *= 0.5;
  229. eps *= 64.0;
  230. }
  231. while (k0 < dim[1] - 1) { // Diagonalization
  232. iter++;
  233. ValueType S_max = 0.0;
  234. for (Long i = 0; i < dim[1]; i++) S_max = (S_max > fabs<ValueType>(S(i, i)) ? S_max : fabs<ValueType>(S(i, i)));
  235. for (Long i = 0; i < dim[1] - 1; i++) S_max = (S_max > fabs<ValueType>(S(i, i + 1)) ? S_max : fabs<ValueType>(S(i, i + 1)));
  236. // while(k0<dim[1]-1 && fabs<ValueType>(S(k0,k0+1))<=eps*(fabs<ValueType>(S(k0,k0))+fabs<ValueType>(S(k0+1,k0+1)))) k0++;
  237. while (k0 < dim[1] - 1 && fabs<ValueType>(S(k0, k0 + 1)) <= eps * S_max) k0++;
  238. if (k0 == dim[1] - 1) continue;
  239. Long n = k0 + 2;
  240. // while(n<dim[1] && fabs<ValueType>(S(n-1,n))>eps*(fabs<ValueType>(S(n-1,n-1))+fabs<ValueType>(S(n,n)))) n++;
  241. while (n < dim[1] && fabs<ValueType>(S(n - 1, n)) > eps * S_max) n++;
  242. ValueType alpha = 0;
  243. ValueType beta = 0;
  244. if (n - k0 == 2 && fabs<ValueType>(S(k0, k0)) < eps * S_max && fabs<ValueType>(S(k0 + 1, k0 + 1)) < eps * S_max) { // Compute mu
  245. alpha=0;
  246. beta=1;
  247. } else {
  248. StaticArray<ValueType, 2 * 2> C;
  249. C[0 * 2 + 0] = S(n - 2, n - 2) * S(n - 2, n - 2);
  250. if (n - k0 > 2) C[0 * 2 + 0] += S(n - 3, n - 2) * S(n - 3, n - 2);
  251. C[0 * 2 + 1] = S(n - 2, n - 2) * S(n - 2, n - 1);
  252. C[1 * 2 + 0] = S(n - 2, n - 2) * S(n - 2, n - 1);
  253. C[1 * 2 + 1] = S(n - 1, n - 1) * S(n - 1, n - 1) + S(n - 2, n - 1) * S(n - 2, n - 1);
  254. ValueType b = -(C[0 * 2 + 0] + C[1 * 2 + 1]) / 2;
  255. ValueType c = C[0 * 2 + 0] * C[1 * 2 + 1] - C[0 * 2 + 1] * C[1 * 2 + 0];
  256. ValueType d = 0;
  257. if (b * b - c > 0)
  258. d = sqrt<ValueType>(b * b - c);
  259. else {
  260. ValueType b = (C[0 * 2 + 0] - C[1 * 2 + 1]) / 2;
  261. ValueType c = -C[0 * 2 + 1] * C[1 * 2 + 0];
  262. if (b * b - c > 0) d = sqrt<ValueType>(b * b - c);
  263. }
  264. ValueType lambda1 = -b + d;
  265. ValueType lambda2 = -b - d;
  266. ValueType d1 = lambda1 - C[1 * 2 + 1];
  267. d1 = (d1 < 0 ? -d1 : d1);
  268. ValueType d2 = lambda2 - C[1 * 2 + 1];
  269. d2 = (d2 < 0 ? -d2 : d2);
  270. ValueType mu = (d1 < d2 ? lambda1 : lambda2);
  271. alpha = S(k0, k0) * S(k0, k0) - mu;
  272. beta = S(k0, k0) * S(k0, k0 + 1);
  273. }
  274. for (Long k = k0; k < n - 1; k++) {
  275. StaticArray<Long, 2> dimU;
  276. dimU[0] = dim[0];
  277. dimU[1] = dim[0];
  278. StaticArray<Long, 2> dimV;
  279. dimV[0] = dim[1];
  280. dimV[1] = dim[1];
  281. GivensR(S_, dim, k, alpha, beta);
  282. GivensL(V_, dimV, k, alpha, beta);
  283. alpha = S(k, k);
  284. beta = S(k + 1, k);
  285. GivensL(S_, dim, k, alpha, beta);
  286. GivensR(U_, dimU, k, alpha, beta);
  287. alpha = S(k, k + 1);
  288. beta = S(k, k + 2);
  289. }
  290. { // Make S bi-diagonal again
  291. for (Long i0 = k0; i0 < n - 1; i0++) {
  292. for (Long i1 = 0; i1 < dim[1]; i1++) {
  293. if (i0 > i1 || i0 + 1 < i1) S(i0, i1) = 0;
  294. }
  295. }
  296. for (Long i0 = 0; i0 < dim[0]; i0++) {
  297. for (Long i1 = k0; i1 < n - 1; i1++) {
  298. if (i0 > i1 || i0 + 1 < i1) S(i0, i1) = 0;
  299. }
  300. }
  301. for (Long i = 0; i < dim[1] - 1; i++) {
  302. if (fabs<ValueType>(S(i, i + 1)) <= eps * S_max) {
  303. S(i, i + 1) = 0;
  304. }
  305. }
  306. }
  307. // std::cout<<iter<<' '<<k0<<' '<<n<<'\n';
  308. }
  309. { // Check Error
  310. #ifdef SVD_DEBUG
  311. Matrix<ValueType> U0(dim[0], dim[0], U_);
  312. Matrix<ValueType> S0(dim[0], dim[1], S_);
  313. Matrix<ValueType> V0(dim[1], dim[1], V_);
  314. Matrix<ValueType> E = M0 - U0 * S0 * V0;
  315. ValueType max_err = 0;
  316. ValueType max_nondiag0 = 0;
  317. ValueType max_nondiag1 = 0;
  318. for (Long i = 0; i < E.Dim(0); i++)
  319. for (Long j = 0; j < E.Dim(1); j++) {
  320. if (max_err < fabs<ValueType>(E[i][j])) max_err = fabs<ValueType>(E[i][j]);
  321. if ((i > j + 0 || i + 0 < j) && max_nondiag0 < fabs<ValueType>(S0[i][j])) max_nondiag0 = fabs<ValueType>(S0[i][j]);
  322. if ((i > j + 1 || i + 1 < j) && max_nondiag1 < fabs<ValueType>(S0[i][j])) max_nondiag1 = fabs<ValueType>(S0[i][j]);
  323. }
  324. std::cout << max_err << '\n';
  325. std::cout << max_nondiag0 << '\n';
  326. std::cout << max_nondiag1 << '\n';
  327. #endif
  328. }
  329. }
  // The accessor macros and the optional SVD_DEBUG switch are private to the
  // SVD helpers above. Drop them so they do not leak into code that follows.
  #undef SVD_DEBUG
  #undef V
  #undef S
  #undef U
  334. template <class ValueType> inline void svd(char *JOBU, char *JOBVT, int *M, int *N, Iterator<ValueType> A, int *LDA, Iterator<ValueType> S, Iterator<ValueType> U, int *LDU, Iterator<ValueType> VT, int *LDVT, Iterator<ValueType> WORK, int *LWORK, int *INFO) {
  335. StaticArray<Long, 2> dim;
  336. dim[0] = std::max(*N, *M);
  337. dim[1] = std::min(*N, *M);
  338. Iterator<ValueType> U_ = aligned_new<ValueType>(dim[0] * dim[0]);
  339. memset(U_, 0, dim[0] * dim[0]);
  340. Iterator<ValueType> V_ = aligned_new<ValueType>(dim[1] * dim[1]);
  341. memset(V_, 0, dim[1] * dim[1]);
  342. Iterator<ValueType> S_ = aligned_new<ValueType>(dim[0] * dim[1]);
  343. const Long lda = *LDA;
  344. const Long ldu = *LDU;
  345. const Long ldv = *LDVT;
  346. if (dim[1] == *M) {
  347. for (Long i = 0; i < dim[0]; i++)
  348. for (Long j = 0; j < dim[1]; j++) {
  349. S_[i * dim[1] + j] = A[i * lda + j];
  350. }
  351. } else {
  352. for (Long i = 0; i < dim[0]; i++)
  353. for (Long j = 0; j < dim[1]; j++) {
  354. S_[i * dim[1] + j] = A[j * lda + i];
  355. }
  356. }
  357. for (Long i = 0; i < dim[0]; i++) {
  358. U_[i * dim[0] + i] = 1;
  359. }
  360. for (Long i = 0; i < dim[1]; i++) {
  361. V_[i * dim[1] + i] = 1;
  362. }
  363. SVD<ValueType>(dim, U_, S_, V_, (ValueType) - 1);
  364. for (Long i = 0; i < dim[1]; i++) { // Set S
  365. S[i] = S_[i * dim[1] + i];
  366. }
  367. if (dim[1] == *M) { // Set U
  368. for (Long i = 0; i < dim[1]; i++)
  369. for (Long j = 0; j < *M; j++) {
  370. U[j + ldu * i] = V_[j + i * dim[1]] * (S[i] < 0.0 ? -1.0 : 1.0);
  371. }
  372. } else {
  373. for (Long i = 0; i < dim[1]; i++)
  374. for (Long j = 0; j < *M; j++) {
  375. U[j + ldu * i] = U_[i + j * dim[0]] * (S[i] < 0.0 ? -1.0 : 1.0);
  376. }
  377. }
  378. if (dim[0] == *N) { // Set V
  379. for (Long i = 0; i < *N; i++)
  380. for (Long j = 0; j < dim[1]; j++) {
  381. VT[j + ldv * i] = U_[j + i * dim[0]];
  382. }
  383. } else {
  384. for (Long i = 0; i < *N; i++)
  385. for (Long j = 0; j < dim[1]; j++) {
  386. VT[j + ldv * i] = V_[i + j * dim[1]];
  387. }
  388. }
  389. for (Long i = 0; i < dim[1]; i++) {
  390. S[i] = S[i] * (S[i] < 0.0 ? -1.0 : 1.0);
  391. }
  392. aligned_delete<ValueType>(U_);
  393. aligned_delete<ValueType>(S_);
  394. aligned_delete<ValueType>(V_);
  395. if (0) { // Verify
  396. StaticArray<Long, 2> dim;
  397. dim[0] = std::max(*N, *M);
  398. dim[1] = std::min(*N, *M);
  399. const Long lda = *LDA;
  400. const Long ldu = *LDU;
  401. const Long ldv = *LDVT;
  402. Matrix<ValueType> A1(*M, *N);
  403. Matrix<ValueType> S1(dim[1], dim[1]);
  404. Matrix<ValueType> U1(*M, dim[1]);
  405. Matrix<ValueType> V1(dim[1], *N);
  406. for (Long i = 0; i < *N; i++)
  407. for (Long j = 0; j < *M; j++) {
  408. A1[j][i] = A[j + i * lda];
  409. }
  410. S1.SetZero();
  411. for (Long i = 0; i < dim[1]; i++) { // Set S
  412. S1[i][i] = S[i];
  413. }
  414. for (Long i = 0; i < dim[1]; i++)
  415. for (Long j = 0; j < *M; j++) {
  416. U1[j][i] = U[j + ldu * i];
  417. }
  418. for (Long i = 0; i < *N; i++)
  419. for (Long j = 0; j < dim[1]; j++) {
  420. V1[j][i] = VT[j + ldv * i];
  421. }
  422. std::cout << U1 *S1 *V1 - A1 << '\n';
  423. }
  424. }
  #if defined(SCTL_HAVE_LAPACK)
  // With LAPACK available, float/double SVDs go straight to sgesvd_/dgesvd_.
  // The pointer arguments are forwarded unchanged, and iterators are passed as
  // the address of their first element.
  template <> inline void svd<float>(char *JOBU, char *JOBVT, int *M, int *N, Iterator<float> A, int *LDA, Iterator<float> S, Iterator<float> U, int *LDU, Iterator<float> VT, int *LDVT, Iterator<float> WORK, int *LWORK, int *INFO) {
    sgesvd_(JOBU, JOBVT, M, N,  // job flags and matrix shape
            &A[0], LDA,         // input matrix
            &S[0],              // singular values
            &U[0], LDU,         // left singular vectors
            &VT[0], LDVT,       // right singular vectors (transposed)
            &WORK[0], LWORK,    // workspace
            INFO);
  }
  template <> inline void svd<double>(char *JOBU, char *JOBVT, int *M, int *N, Iterator<double> A, int *LDA, Iterator<double> S, Iterator<double> U, int *LDU, Iterator<double> VT, int *LDVT, Iterator<double> WORK, int *LWORK, int *INFO) {
    dgesvd_(JOBU, JOBVT, M, N,  // job flags and matrix shape
            &A[0], LDA,         // input matrix
            &S[0],              // singular values
            &U[0], LDU,         // left singular vectors
            &VT[0], LDVT,       // right singular vectors (transposed)
            &WORK[0], LWORK,    // workspace
            INFO);
  }
  #endif
  429. /**
  430. * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
  431. * and returns the output M_(n2xn1). Original contents of M are destroyed.
  432. */
  433. template <class ValueType> inline void pinv(Iterator<ValueType> M, int n1, int n2, ValueType eps, Iterator<ValueType> M_) {
  434. if (n1 * n2 == 0) return;
  435. int m = n2;
  436. int n = n1;
  437. int k = (m < n ? m : n);
  438. Iterator<ValueType> tU = aligned_new<ValueType>(m * k);
  439. Iterator<ValueType> tS = aligned_new<ValueType>(k);
  440. Iterator<ValueType> tVT = aligned_new<ValueType>(k * n);
  441. // SVD
  442. int INFO = 0;
  443. char JOBU = 'S';
  444. char JOBVT = 'S';
  445. // int wssize = max(3*min(m,n)+max(m,n), 5*min(m,n));
  446. int wssize = 3 * (m < n ? m : n) + (m > n ? m : n);
  447. int wssize1 = 5 * (m < n ? m : n);
  448. wssize = (wssize > wssize1 ? wssize : wssize1);
  449. Iterator<ValueType> wsbuf = aligned_new<ValueType>(wssize);
  450. svd(&JOBU, &JOBVT, &m, &n, M, &m, tS, tU, &m, tVT, &k, wsbuf, &wssize, &INFO);
  451. if (INFO != 0) std::cout << INFO << '\n';
  452. assert(INFO == 0);
  453. aligned_delete<ValueType>(wsbuf);
  454. ValueType eps_ = tS[0] * eps;
  455. for (int i = 0; i < k; i++)
  456. if (tS[i] < eps_)
  457. tS[i] = 0;
  458. else
  459. tS[i] = 1.0 / tS[i];
  460. for (int i = 0; i < m; i++) {
  461. for (int j = 0; j < k; j++) {
  462. tU[i + j * m] *= tS[j];
  463. }
  464. }
  465. gemm<ValueType>('T', 'T', n, m, k, 1.0, tVT, k, tU, m, 0.0, M_, n);
  466. aligned_delete<ValueType>(tU);
  467. aligned_delete<ValueType>(tS);
  468. aligned_delete<ValueType>(tVT);
  469. }
  470. } // end namespace mat
  471. } // end namespace SCTL_NAMESPACE