mat_utils.txx 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /**
  2. * \file mat_utils.txx
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains BLAS and LAPACK wrapper functions.
  6. */
  7. #include <cassert>
  8. #include <vector>
  9. #include <iostream>
  10. #include <stdint.h>
  11. #include <math.h>
  12. #include <blas.h>
  13. #include <lapack.h>
  14. #include <fft_wrapper.hpp>
  15. #include <device_wrapper.hpp>
  16. #if defined(PVFMM_HAVE_CUDA)
  17. #include <cuda_runtime_api.h>
  18. #include <cublas_v2.h>
  19. #endif
  20. namespace pvfmm{
  21. namespace mat{
  22. inline void gemm(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc){
  23. sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  24. }
  25. inline void gemm(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
  26. dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  27. }
  #if defined(PVFMM_HAVE_CUDA)
  /**
   * \brief Maps a BLAS-style transpose flag to the matching cuBLAS operation.
   *
   * 'N'/'n' selects no transpose. 'T'/'t' and 'C'/'c' both select transpose:
   * for the real types wrapped below, the conjugate transpose equals the
   * transpose, as in reference BLAS. Any other character is reported, trips
   * the assert in debug builds, and falls back to CUBLAS_OP_N, so the
   * operation is never left undefined.
   */
  inline cublasOperation_t cublas_trans_op(char Trans){
    switch(Trans){
      case 'N': case 'n': return CUBLAS_OP_N;
      case 'T': case 't': return CUBLAS_OP_T;
      case 'C': case 'c': return CUBLAS_OP_T;
      default:
        std::cerr<<"cublasXgemm: invalid transpose flag '"<<Trans<<"'\n";
        assert(false);
        return CUBLAS_OP_N;
    }
  }

  /**
   * \brief Reports a failed cuBLAS call instead of silently dropping its status.
   * \param status  Value returned by the cuBLAS routine.
   * \param fn      Name of the routine, used in the error message.
   */
  inline void cublas_check_status(cublasStatus_t status, const char* fn){
    if(status!=CUBLAS_STATUS_SUCCESS)
      std::cerr<<fn<<" failed with cublasStatus_t = "<<static_cast<int>(status)<<'\n';
    assert(status==CUBLAS_STATUS_SUCCESS);
  }

  /**
   * \brief cublasDgemm wrapper: C = alpha*op(A)*op(B) + beta*C (column major).
   *
   * Uses the handle returned by CUDA_Lock::acquire_handle(). alpha and beta
   * are passed by host address. A, B and C are handed to cuBLAS unchanged,
   * which requires them to be device pointers.
   * \param TransA, TransB  'N'/'n', 'T'/'t' or 'C'/'c' (see cublas_trans_op).
   */
  inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, double alpha,
      double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    cublasStatus_t status = cublasDgemm(*handle,
        cublas_trans_op(TransA), cublas_trans_op(TransB),
        M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    cublas_check_status(status, "cublasDgemm");
  }

  /**
   * \brief cublasSgemm wrapper: C = alpha*op(A)*op(B) + beta*C (column major).
   *
   * Single-precision counterpart of the double overload above. It uses the
   * same handle and flag handling, and it produces no diagnostic output on
   * success.
   */
  inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, float alpha,
      float *A, int lda, float *B, int ldb, float beta, float *C, int ldc){
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    cublasStatus_t status = cublasSgemm(*handle,
        cublas_trans_op(TransA), cublas_trans_op(TransB),
        M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    cublas_check_status(status, "cublasSgemm");
  }
  #endif
  58. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
  59. float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
  60. int *INFO){
  61. sgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  62. }
  63. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
  64. double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
  65. int *INFO){
  66. dgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  67. }
  /**
   * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
   * and returns the output M_(n2xn1).
   *
   * The row-major n1 x n2 input is handed to LAPACK as a column-major
   * n2 x n1 matrix, i.e. its transpose. The thin SVD U*S*VT of that matrix
   * is combined into V * S^+ * U^T, which, read back in row-major order, is
   * the n2 x n1 pseudo-inverse of M.
   *
   * \param[in,out] M  n1*n2 entries, row major. It is passed as the A argument
   *                   of ?gesvd, which overwrites it, so its contents are
   *                   destroyed on return.
   * \param[in] n1     Number of rows of M.
   * \param[in] n2     Number of columns of M.
   * \param[in] eps    Relative cutoff. Singular values below eps times the
   *                   largest singular value, and exact zeros, are treated as
   *                   zero instead of being inverted.
   * \param[out] M_    Buffer of n2*n1 entries that receives the result (row
   *                   major). Nothing is written when n1 or n2 is 0.
   */
  template <class T>
  void pinv(T* M, int n1, int n2, T eps, T* M_){
    // Column-major view of the row-major input: m x n with leading dim m.
    int m = n2;
    int n = n1;
    int k = (m<n?m:n);

    // An empty matrix has an empty pseudo-inverse. Returning here also avoids
    // reading tS[0], taking &v[0] of empty vectors, and passing LWORK = 0
    // (LAPACK requires LWORK >= 1).
    if(k<=0) return;

    std::vector<T> tU(m*k);   // left singular vectors, m x k, column major
    std::vector<T> tS(k);     // singular values (?gesvd returns them largest first)
    std::vector<T> tVT(k*n);  // right singular vectors transposed, k x n, column major

    // SVD
    int INFO=0;
    char JOBU = 'S';
    char JOBVT = 'S';
    // Minimum ?gesvd workspace: max(3*min(m,n)+max(m,n), 5*min(m,n)).
    int wssize = 3*k+(m>n?m:n);
    int wssize1 = 5*k;
    wssize = (wssize>wssize1?wssize:wssize1);
    std::vector<T> wsbuf(wssize);
    svd(&JOBU, &JOBVT, &m, &n, &M[0], &m, &tS[0], &tU[0], &m, &tVT[0], &k,
        &wsbuf[0], &wssize, &INFO);
    if(INFO!=0)
      std::cout<<INFO<<'\n';
    assert(INFO==0);

    // Invert the singular values, zeroing those below the relative cutoff.
    // The explicit zero test covers the zero matrix or eps == 0. In that case
    // eps_ is 0, "tS[i]<eps_" is false for tS[i]==0, and 1/tS[i] would give inf.
    T eps_=tS[0]*eps;
    for(int i=0;i<k;i++){
      if(tS[i]<eps_ || tS[i]==T(0)) tS[i]=T(0);
      else tS[i]=T(1)/tS[i];
    }

    // Scale column j of U by the inverted singular value. Columns are
    // contiguous in column-major storage, so j is the outer loop.
    for(int j=0;j<k;j++){
      T s=tS[j];
      T* col=&tU[j*m];
      for(int i=0;i<m;i++) col[i]*=s;
    }

    // M_ (column major n x m, ldc = n) = VT^T * (U*S^+)^T = V * S^+ * U^T.
    gemm('T','T',n,m,k,T(1),&tVT[0],k,&tU[0],m,T(0),M_,n);
  }
  108. }//end namespace
  109. }//end namespace