mat_utils.txx 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. /**
  2. * \file mat_utils.txx
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains BLAS and LAPACK wrapper functions.
  6. */
  7. #include <omp.h>
  8. #include <cmath>
  9. #include <cassert>
  10. #include <cstdlib>
  11. #include <algorithm>
  12. #include <iostream>
  13. #include <vector>
  14. #include <matrix.hpp>
  15. #include <device_wrapper.hpp>
  16. #if defined(PVFMM_HAVE_CUDA)
  17. #include <cuda_runtime_api.h>
  18. #include <cublas_v2.h>
  19. #endif
  20. namespace pvfmm{
  21. namespace mat{
  22. template <class T>
  23. inline void gemm(char TransA, char TransB, int M, int N, int K, T alpha, T *A, int lda, T *B, int ldb, T beta, T *C, int ldc){
  24. if((TransA=='N' || TransA=='n') && (TransB=='N' || TransB=='n')){
  25. for(size_t n=0;n<N;n++){ // Columns of C
  26. for(size_t m=0;m<M;m++){ // Rows of C
  27. T AxB=0;
  28. for(size_t k=0;k<K;k++){
  29. AxB+=A[m+lda*k]*B[k+ldb*n];
  30. }
  31. C[m+ldc*n]=alpha*AxB+(beta==0?0:beta*C[m+ldc*n]);
  32. }
  33. }
  34. }else if(TransA=='N' || TransA=='n'){
  35. for(size_t n=0;n<N;n++){ // Columns of C
  36. for(size_t m=0;m<M;m++){ // Rows of C
  37. T AxB=0;
  38. for(size_t k=0;k<K;k++){
  39. AxB+=A[m+lda*k]*B[n+ldb*k];
  40. }
  41. C[m+ldc*n]=alpha*AxB+(beta==0?0:beta*C[m+ldc*n]);
  42. }
  43. }
  44. }else if(TransB=='N' || TransB=='n'){
  45. for(size_t n=0;n<N;n++){ // Columns of C
  46. for(size_t m=0;m<M;m++){ // Rows of C
  47. T AxB=0;
  48. for(size_t k=0;k<K;k++){
  49. AxB+=A[k+lda*m]*B[k+ldb*n];
  50. }
  51. C[m+ldc*n]=alpha*AxB+(beta==0?0:beta*C[m+ldc*n]);
  52. }
  53. }
  54. }else{
  55. for(size_t n=0;n<N;n++){ // Columns of C
  56. for(size_t m=0;m<M;m++){ // Rows of C
  57. T AxB=0;
  58. for(size_t k=0;k<K;k++){
  59. AxB+=A[k+lda*m]*B[n+ldb*k];
  60. }
  61. C[m+ldc*n]=alpha*AxB+(beta==0?0:beta*C[m+ldc*n]);
  62. }
  63. }
  64. }
  65. }
  66. template<>
  67. void gemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc);
  68. template<>
  69. void gemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc);
  70. #if defined(PVFMM_HAVE_CUDA)
  71. template<>
  72. inline void cublasgemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc) {
  73. cublasOperation_t cublasTransA, cublasTransB;
  74. cublasHandle_t *handle = CUDA_Lock::acquire_handle();
  75. if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
  76. else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_N;
  77. if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
  78. else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_N;
  79. cublasStatus_t status = cublasSgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
  80. }
  81. template<>
  82. inline void cublasgemm<double>(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
  83. cublasOperation_t cublasTransA, cublasTransB;
  84. cublasHandle_t *handle = CUDA_Lock::acquire_handle();
  85. if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
  86. else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_N;
  87. if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
  88. else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_N;
  89. cublasStatus_t status = cublasDgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
  90. }
  91. #endif
  92. #define U(i,j) U_[(i)*dim[0]+(j)]
  93. #define S(i,j) S_[(i)*dim[1]+(j)]
  94. #define V(i,j) V_[(i)*dim[1]+(j)]
  95. //#define SVD_DEBUG
  96. template <class T>
  97. static void GivensL(T* S_, const size_t dim[2], size_t m, T a, T b){
  98. T r=pvfmm::sqrt<T>(a*a+b*b);
  99. if (r == 0) return;
  100. T c=a/r;
  101. T s=-b/r;
  102. #pragma omp parallel for
  103. for(size_t i=0;i<dim[1];i++){
  104. T S0=S(m+0,i);
  105. T S1=S(m+1,i);
  106. S(m ,i)+=S0*(c-1);
  107. S(m ,i)+=S1*(-s );
  108. S(m+1,i)+=S0*( s );
  109. S(m+1,i)+=S1*(c-1);
  110. }
  111. }
  112. template <class T>
  113. static void GivensR(T* S_, const size_t dim[2], size_t m, T a, T b){
  114. T r=pvfmm::sqrt<T>(a*a+b*b);
  115. if (r == 0) return;
  116. T c=a/r;
  117. T s=-b/r;
  118. #pragma omp parallel for
  119. for(size_t i=0;i<dim[0];i++){
  120. T S0=S(i,m+0);
  121. T S1=S(i,m+1);
  122. S(i,m )+=S0*(c-1);
  123. S(i,m )+=S1*(-s );
  124. S(i,m+1)+=S0*( s );
  125. S(i,m+1)+=S1*(c-1);
  126. }
  127. }
  128. template <class T>
  129. static void SVD(const size_t dim[2], T* U_, T* S_, T* V_, T eps=-1){
  130. assert(dim[0]>=dim[1]);
  131. #ifdef SVD_DEBUG
  132. Matrix<T> M0(dim[0],dim[1],S_);
  133. #endif
  134. { // Bi-diagonalization
  135. size_t n=std::min(dim[0],dim[1]);
  136. std::vector<T> house_vec(std::max(dim[0],dim[1]));
  137. for(size_t i=0;i<n;i++){
  138. // Column Householder
  139. {
  140. T x1=S(i,i);
  141. if(x1<0) x1=-x1;
  142. T x_inv_norm=0;
  143. for(size_t j=i;j<dim[0];j++){
  144. x_inv_norm+=S(j,i)*S(j,i);
  145. }
  146. if(x_inv_norm>0) x_inv_norm=1/pvfmm::sqrt<T>(x_inv_norm);
  147. T alpha=pvfmm::sqrt<T>(1+x1*x_inv_norm);
  148. T beta=x_inv_norm/alpha;
  149. if (x_inv_norm == 0) alpha = 0; // nothing to do
  150. house_vec[i]=-alpha;
  151. for(size_t j=i+1;j<dim[0];j++){
  152. house_vec[j]=-beta*S(j,i);
  153. }
  154. if(S(i,i)<0) for(size_t j=i+1;j<dim[0];j++){
  155. house_vec[j]=-house_vec[j];
  156. }
  157. }
  158. #pragma omp parallel for
  159. for(size_t k=i;k<dim[1];k++){
  160. T dot_prod=0;
  161. for(size_t j=i;j<dim[0];j++){
  162. dot_prod+=S(j,k)*house_vec[j];
  163. }
  164. for(size_t j=i;j<dim[0];j++){
  165. S(j,k)-=dot_prod*house_vec[j];
  166. }
  167. }
  168. #pragma omp parallel for
  169. for(size_t k=0;k<dim[0];k++){
  170. T dot_prod=0;
  171. for(size_t j=i;j<dim[0];j++){
  172. dot_prod+=U(k,j)*house_vec[j];
  173. }
  174. for(size_t j=i;j<dim[0];j++){
  175. U(k,j)-=dot_prod*house_vec[j];
  176. }
  177. }
  178. // Row Householder
  179. if(i>=n-1) continue;
  180. {
  181. T x1=S(i,i+1);
  182. if(x1<0) x1=-x1;
  183. T x_inv_norm=0;
  184. for(size_t j=i+1;j<dim[1];j++){
  185. x_inv_norm+=S(i,j)*S(i,j);
  186. }
  187. if(x_inv_norm>0) x_inv_norm=1/pvfmm::sqrt<T>(x_inv_norm);
  188. T alpha=pvfmm::sqrt<T>(1+x1*x_inv_norm);
  189. T beta=x_inv_norm/alpha;
  190. if (x_inv_norm == 0) alpha = 0; // nothing to do
  191. house_vec[i+1]=-alpha;
  192. for(size_t j=i+2;j<dim[1];j++){
  193. house_vec[j]=-beta*S(i,j);
  194. }
  195. if(S(i,i+1)<0) for(size_t j=i+2;j<dim[1];j++){
  196. house_vec[j]=-house_vec[j];
  197. }
  198. }
  199. #pragma omp parallel for
  200. for(size_t k=i;k<dim[0];k++){
  201. T dot_prod=0;
  202. for(size_t j=i+1;j<dim[1];j++){
  203. dot_prod+=S(k,j)*house_vec[j];
  204. }
  205. for(size_t j=i+1;j<dim[1];j++){
  206. S(k,j)-=dot_prod*house_vec[j];
  207. }
  208. }
  209. #pragma omp parallel for
  210. for(size_t k=0;k<dim[1];k++){
  211. T dot_prod=0;
  212. for(size_t j=i+1;j<dim[1];j++){
  213. dot_prod+=V(j,k)*house_vec[j];
  214. }
  215. for(size_t j=i+1;j<dim[1];j++){
  216. V(j,k)-=dot_prod*house_vec[j];
  217. }
  218. }
  219. }
  220. }
  221. size_t k0=0;
  222. size_t iter=0;
  223. if(eps<0){
  224. eps=1.0;
  225. while(eps+(T)1.0>1.0) eps*=0.5;
  226. eps*=64.0;
  227. }
  228. while(k0<dim[1]-1){ // Diagonalization
  229. iter++;
  230. T S_max=0.0;
  231. for(size_t i=0;i<dim[1];i++) S_max=(S_max>S(i,i)?S_max:S(i,i));
  232. //while(k0<dim[1]-1 && pvfmm::fabs<T>(S(k0,k0+1))<=eps*(pvfmm::fabs<T>(S(k0,k0))+pvfmm::fabs<T>(S(k0+1,k0+1)))) k0++;
  233. while(k0<dim[1]-1 && pvfmm::fabs<T>(S(k0,k0+1))<=eps*S_max) k0++;
  234. if(k0==dim[1]-1) continue;
  235. size_t n=k0+2;
  236. //while(n<dim[1] && pvfmm::fabs<T>(S(n-1,n))>eps*(pvfmm::fabs<T>(S(n-1,n-1))+pvfmm::fabs<T>(S(n,n)))) n++;
  237. while(n<dim[1] && pvfmm::fabs<T>(S(n-1,n))>eps*S_max) n++;
  238. T alpha=0;
  239. T beta=0;
  240. if (n - k0 == 2 && S(k0, k0) == 0 && S(k0 + 1, k0 + 1) == 0) { // Compute mu
  241. alpha=0;
  242. beta=1;
  243. } else {
  244. T C[2][2];
  245. C[0][0]=S(n-2,n-2)*S(n-2,n-2);
  246. if(n-k0>2) C[0][0]+=S(n-3,n-2)*S(n-3,n-2);
  247. C[0][1]=S(n-2,n-2)*S(n-2,n-1);
  248. C[1][0]=S(n-2,n-2)*S(n-2,n-1);
  249. C[1][1]=S(n-1,n-1)*S(n-1,n-1)+S(n-2,n-1)*S(n-2,n-1);
  250. T b=-(C[0][0]+C[1][1])/2;
  251. T c= C[0][0]*C[1][1] - C[0][1]*C[1][0];
  252. T d=0;
  253. if(b*b-c>0) d=pvfmm::sqrt<T>(b*b-c);
  254. else{
  255. T b=(C[0][0]-C[1][1])/2;
  256. T c=-C[0][1]*C[1][0];
  257. if(b*b-c>0) d=pvfmm::sqrt<T>(b*b-c);
  258. }
  259. T lambda1=-b+d;
  260. T lambda2=-b-d;
  261. T d1=lambda1-C[1][1]; d1=(d1<0?-d1:d1);
  262. T d2=lambda2-C[1][1]; d2=(d2<0?-d2:d2);
  263. T mu=(d1<d2?lambda1:lambda2);
  264. alpha=S(k0,k0)*S(k0,k0)-mu;
  265. beta=S(k0,k0)*S(k0,k0+1);
  266. }
  267. for(size_t k=k0;k<n-1;k++){
  268. size_t dimU[2]={dim[0],dim[0]};
  269. size_t dimV[2]={dim[1],dim[1]};
  270. GivensR(S_,dim ,k,alpha,beta);
  271. GivensL(V_,dimV,k,alpha,beta);
  272. alpha=S(k,k);
  273. beta=S(k+1,k);
  274. GivensL(S_,dim ,k,alpha,beta);
  275. GivensR(U_,dimU,k,alpha,beta);
  276. alpha=S(k,k+1);
  277. beta=S(k,k+2);
  278. }
  279. { // Make S bi-diagonal again
  280. for(size_t i0=k0;i0<n-1;i0++){
  281. for(size_t i1=0;i1<dim[1];i1++){
  282. if(i0>i1 || i0+1<i1) S(i0,i1)=0;
  283. }
  284. }
  285. for(size_t i0=0;i0<dim[0];i0++){
  286. for(size_t i1=k0;i1<n-1;i1++){
  287. if(i0>i1 || i0+1<i1) S(i0,i1)=0;
  288. }
  289. }
  290. for(size_t i=0;i<dim[1]-1;i++){
  291. if(pvfmm::fabs<T>(S(i,i+1))<=eps*S_max){
  292. S(i,i+1)=0;
  293. }
  294. }
  295. }
  296. //std::cout<<iter<<' '<<k0<<' '<<n<<'\n';
  297. }
  298. { // Check Error
  299. #ifdef SVD_DEBUG
  300. Matrix<T> U0(dim[0],dim[0],U_);
  301. Matrix<T> S0(dim[0],dim[1],S_);
  302. Matrix<T> V0(dim[1],dim[1],V_);
  303. Matrix<T> E=M0-U0*S0*V0;
  304. T max_err=0;
  305. T max_nondiag0=0;
  306. T max_nondiag1=0;
  307. for(size_t i=0;i<E.Dim(0);i++)
  308. for(size_t j=0;j<E.Dim(1);j++){
  309. if(max_err<pvfmm::fabs<T>(E[i][j])) max_err=pvfmm::fabs<T>(E[i][j]);
  310. if((i>j+0 || i+0<j) && max_nondiag0<pvfmm::fabs<T>(S0[i][j])) max_nondiag0=pvfmm::fabs<T>(S0[i][j]);
  311. if((i>j+1 || i+1<j) && max_nondiag1<pvfmm::fabs<T>(S0[i][j])) max_nondiag1=pvfmm::fabs<T>(S0[i][j]);
  312. }
  313. std::cout<<max_err<<'\n';
  314. std::cout<<max_nondiag0<<'\n';
  315. std::cout<<max_nondiag1<<'\n';
  316. #endif
  317. }
  318. }
  319. #undef U
  320. #undef S
  321. #undef V
  322. #undef SVD_DEBUG
  323. template<class T>
  324. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, T *A, int *LDA,
  325. T *S, T *U, int *LDU, T *VT, int *LDVT, T *WORK, int *LWORK,
  326. int *INFO){
  327. const size_t dim[2]={std::max(*N,*M), std::min(*N,*M)};
  328. T* U_=mem::aligned_new<T>(dim[0]*dim[0]); memset(U_, 0, dim[0]*dim[0]*sizeof(T));
  329. T* V_=mem::aligned_new<T>(dim[1]*dim[1]); memset(V_, 0, dim[1]*dim[1]*sizeof(T));
  330. T* S_=mem::aligned_new<T>(dim[0]*dim[1]);
  331. const size_t lda=*LDA;
  332. const size_t ldu=*LDU;
  333. const size_t ldv=*LDVT;
  334. if(dim[1]==*M){
  335. for(size_t i=0;i<dim[0];i++)
  336. for(size_t j=0;j<dim[1];j++){
  337. S_[i*dim[1]+j]=A[i*lda+j];
  338. }
  339. }else{
  340. for(size_t i=0;i<dim[0];i++)
  341. for(size_t j=0;j<dim[1];j++){
  342. S_[i*dim[1]+j]=A[j*lda+i];
  343. }
  344. }
  345. for(size_t i=0;i<dim[0];i++){
  346. U_[i*dim[0]+i]=1;
  347. }
  348. for(size_t i=0;i<dim[1];i++){
  349. V_[i*dim[1]+i]=1;
  350. }
  351. SVD<T>(dim, U_, S_, V_, (T)-1);
  352. for(size_t i=0;i<dim[1];i++){ // Set S
  353. S[i]=S_[i*dim[1]+i];
  354. }
  355. if(dim[1]==*M){ // Set U
  356. for(size_t i=0;i<dim[1];i++)
  357. for(size_t j=0;j<*M;j++){
  358. U[j+ldu*i]=V_[j+i*dim[1]]*(S[i]<0.0?-1.0:1.0);
  359. }
  360. }else{
  361. for(size_t i=0;i<dim[1];i++)
  362. for(size_t j=0;j<*M;j++){
  363. U[j+ldu*i]=U_[i+j*dim[0]]*(S[i]<0.0?-1.0:1.0);
  364. }
  365. }
  366. if(dim[0]==*N){ // Set V
  367. for(size_t i=0;i<*N;i++)
  368. for(size_t j=0;j<dim[1];j++){
  369. VT[j+ldv*i]=U_[j+i*dim[0]];
  370. }
  371. }else{
  372. for(size_t i=0;i<*N;i++)
  373. for(size_t j=0;j<dim[1];j++){
  374. VT[j+ldv*i]=V_[i+j*dim[1]];
  375. }
  376. }
  377. for(size_t i=0;i<dim[1];i++){
  378. S[i]=S[i]*(S[i]<0.0?-1.0:1.0);
  379. }
  380. mem::aligned_delete<T>(U_);
  381. mem::aligned_delete<T>(S_);
  382. mem::aligned_delete<T>(V_);
  383. if(0){ // Verify
  384. const size_t dim[2]={std::max(*N,*M), std::min(*N,*M)};
  385. const size_t lda=*LDA;
  386. const size_t ldu=*LDU;
  387. const size_t ldv=*LDVT;
  388. Matrix<T> A1(*M,*N);
  389. Matrix<T> S1(dim[1],dim[1]);
  390. Matrix<T> U1(*M,dim[1]);
  391. Matrix<T> V1(dim[1],*N);
  392. for(size_t i=0;i<*N;i++)
  393. for(size_t j=0;j<*M;j++){
  394. A1[j][i]=A[j+i*lda];
  395. }
  396. S1.SetZero();
  397. for(size_t i=0;i<dim[1];i++){ // Set S
  398. S1[i][i]=S[i];
  399. }
  400. for(size_t i=0;i<dim[1];i++)
  401. for(size_t j=0;j<*M;j++){
  402. U1[j][i]=U[j+ldu*i];
  403. }
  404. for(size_t i=0;i<*N;i++)
  405. for(size_t j=0;j<dim[1];j++){
  406. V1[j][i]=VT[j+ldv*i];
  407. }
  408. std::cout<<U1*S1*V1-A1<<'\n';
  409. }
  410. }
  411. template<>
  412. void svd<float>(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
  413. float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
  414. int *INFO);
  415. template<>
  416. void svd<double>(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
  417. double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
  418. int *INFO);
  419. /**
  420. * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
  421. * and returns the output M_(n2xn1). Original contents of M are destroyed.
  422. */
  423. template <class T>
  424. void pinv(T* M, int n1, int n2, T eps, T* M_){
  425. if(n1*n2==0) return;
  426. int m = n2;
  427. int n = n1;
  428. int k = (m<n?m:n);
  429. T* tU =mem::aligned_new<T>(m*k);
  430. T* tS =mem::aligned_new<T>(k);
  431. T* tVT=mem::aligned_new<T>(k*n);
  432. //SVD
  433. int INFO=0;
  434. char JOBU = 'S';
  435. char JOBVT = 'S';
  436. //int wssize = max(3*min(m,n)+max(m,n), 5*min(m,n));
  437. int wssize = 3*(m<n?m:n)+(m>n?m:n);
  438. int wssize1 = 5*(m<n?m:n);
  439. wssize = (wssize>wssize1?wssize:wssize1);
  440. T* wsbuf = mem::aligned_new<T>(wssize);
  441. svd(&JOBU, &JOBVT, &m, &n, &M[0], &m, &tS[0], &tU[0], &m, &tVT[0], &k,
  442. wsbuf, &wssize, &INFO);
  443. if(INFO!=0)
  444. std::cout<<INFO<<'\n';
  445. assert(INFO==0);
  446. mem::aligned_delete<T>(wsbuf);
  447. T eps_=tS[0]*eps;
  448. for(int i=0;i<k;i++)
  449. if(tS[i]<eps_)
  450. tS[i]=0;
  451. else
  452. tS[i]=1.0/tS[i];
  453. for(int i=0;i<m;i++){
  454. for(int j=0;j<k;j++){
  455. tU[i+j*m]*=tS[j];
  456. }
  457. }
  458. gemm<T>('T','T',n,m,k,1.0,&tVT[0],k,&tU[0],m,0.0,M_,n);
  459. mem::aligned_delete<T>(tU);
  460. mem::aligned_delete<T>(tS);
  461. mem::aligned_delete<T>(tVT);
  462. }
  463. }//end namespace
  464. }//end namespace