mat_utils.txx 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * \file mat_utils.txx
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains BLAS and LAPACK wrapper functions.
  6. */
  7. #include <cassert>
  8. #include <vector>
  9. #include <iostream>
  10. #include <stdint.h>
  11. #include <math.h>
  12. #include <blas.h>
  13. #include <lapack.h>
  14. #include <fft_wrapper.hpp>
  15. #include <device_wrapper.hpp>
  16. #if defined(PVFMM_HAVE_CUDA)
  17. #include <cuda_runtime_api.h>
  18. #include <cublas_v2.h>
  19. #endif
  20. namespace pvfmm{
  21. namespace mat{
  22. inline void gemm(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc){
  23. sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  24. }
  25. inline void gemm(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
  26. dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  27. }
  28. #if defined(PVFMM_HAVE_CUDA)
  29. // cublasDgemm wrapper
  30. inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, double alpha,
  31. double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
  32. cublasOperation_t cublasTransA, cublasTransB;
  33. cublasStatus_t status;
  34. cublasHandle_t *handle;
  35. handle = CUDA_Lock::acquire_handle();
  36. /* Need exeception handling if (handle) */
  37. if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
  38. else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_N;
  39. if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
  40. else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_N;
  41. status = cublasDgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
  42. }
  43. // cublasSgemm wrapper
  44. inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, float alpha,
  45. float *A, int lda, float *B, int ldb, float beta, float *C, int ldc) {
  46. cublasOperation_t cublasTransA, cublasTransB;
  47. cublasStatus_t status;
  48. cublasHandle_t *handle;
  49. handle = CUDA_Lock::acquire_handle();
  50. if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
  51. else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_N;
  52. if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
  53. else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_N;
  54. status = cublasSgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
  55. }
  56. #endif
  57. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
  58. float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
  59. int *INFO){
  60. sgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  61. }
  62. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
  63. double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
  64. int *INFO){
  65. dgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  66. }
  67. /**
  68. * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
  69. * and returns the output M_(n2xn1).
  70. */
  71. template <class T>
  72. void pinv(T* M, int n1, int n2, T eps, T* M_){
  73. int m = n2;
  74. int n = n1;
  75. int k = (m<n?m:n);
  76. std::vector<T> tU(m*k);
  77. std::vector<T> tS(k);
  78. std::vector<T> tVT(k*n);
  79. //SVD
  80. int INFO=0;
  81. char JOBU = 'S';
  82. char JOBVT = 'S';
  83. //int wssize = max(3*min(m,n)+max(m,n), 5*min(m,n));
  84. int wssize = 3*(m<n?m:n)+(m>n?m:n);
  85. int wssize1 = 5*(m<n?m:n);
  86. wssize = (wssize>wssize1?wssize:wssize1);
  87. T* wsbuf = new T[wssize];
  88. svd(&JOBU, &JOBVT, &m, &n, &M[0], &m, &tS[0], &tU[0], &m, &tVT[0], &k,
  89. wsbuf, &wssize, &INFO);
  90. if(INFO!=0)
  91. std::cout<<INFO<<'\n';
  92. assert(INFO==0);
  93. delete [] wsbuf;
  94. T eps_=tS[0]*eps;
  95. for(int i=0;i<k;i++)
  96. if(tS[i]<eps_)
  97. tS[i]=0;
  98. else
  99. tS[i]=1.0/tS[i];
  100. for(int i=0;i<m;i++){
  101. for(int j=0;j<k;j++){
  102. tU[i+j*m]*=tS[j];
  103. }
  104. }
  105. gemm('T','T',n,m,k,1.0,&tVT[0],k,&tU[0],m,0.0,M_,n);
  106. }
  107. }//end namespace
  108. }//end namespace