mat_utils.txx 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. /**
  2. * \file mat_utils.txx
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains BLAS and LAPACK wrapper functions.
  6. */
  7. #include <cassert>
  8. #include <vector>
  9. #include <iostream>
  10. #include <stdint.h>
  11. #include <math.h>
  12. #include <blas.h>
  13. #include <lapack.h>
  14. #include <fft_wrapper.hpp>
  15. #include <device_wrapper.hpp>
  16. #if defined(PVFMM_HAVE_CUDA)
  17. #include <cuda_runtime_api.h>
  18. #include <cublas_v2.h>
  19. #endif
  20. namespace pvfmm{
  21. namespace mat{
  22. inline void gemm(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc){
  23. sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  24. }
  25. inline void gemm(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
  26. dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  27. }
  #if defined(PVFMM_HAVE_CUDA)
  /**
   * \brief Translates a BLAS-style transpose flag into the matching cuBLAS
   * operation: 'N'/'n' -> CUBLAS_OP_N, 'T'/'t' -> CUBLAS_OP_T,
   * 'C'/'c' -> CUBLAS_OP_C (same flags the CPU gemm() passes to BLAS).
   * Any other character is reported and trips an assert; with NDEBUG it
   * falls back to CUBLAS_OP_N instead of leaving the value undefined.
   */
  inline cublasOperation_t cublasTransOp(char trans){
    switch(trans){
      case 'N': case 'n': return CUBLAS_OP_N;
      case 'T': case 't': return CUBLAS_OP_T;
      case 'C': case 'c': return CUBLAS_OP_C;
      default:
        std::cerr<<"cublasXgemm: invalid transpose flag '"<<trans<<"'\n";
        assert(false);
        return CUBLAS_OP_N;
    }
  }

  /**
   * \brief cublasDgemm wrapper with the same argument order as gemm():
   * C = alpha*op(A)*op(B) + beta*C, column-major (as in BLAS).
   * A, B and C are handed straight to cuBLAS, so they must be device pointers.
   * alpha/beta are passed by host address -- assumes the acquired handle uses
   * CUBLAS_POINTER_MODE_HOST (the cuBLAS default); TODO confirm in CUDA_Lock.
   * A missing handle or a non-success cuBLAS status is reported on stderr and
   * trips an assert.
   */
  inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    if(!handle){
      std::cerr<<"cublasXgemm: failed to acquire a cuBLAS handle\n";
      assert(false);
      return; // never dereference a NULL handle, even with NDEBUG
    }
    cublasStatus_t status = cublasDgemm(*handle, cublasTransOp(TransA), cublasTransOp(TransB),
                                        M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    if(status!=CUBLAS_STATUS_SUCCESS){
      std::cerr<<"cublasXgemm: cublasDgemm failed with status "<<(int)status<<'\n';
      assert(false);
    }
  }

  /**
   * \brief cublasSgemm wrapper; single-precision counterpart of the double
   * overload above with identical conventions and error handling:
   * C = alpha*op(A)*op(B) + beta*C, column-major, device pointers for A/B/C,
   * host-resident alpha/beta.
   */
  inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc){
    cublasHandle_t *handle = CUDA_Lock::acquire_handle();
    if(!handle){
      std::cerr<<"cublasXgemm: failed to acquire a cuBLAS handle\n";
      assert(false);
      return; // never dereference a NULL handle, even with NDEBUG
    }
    cublasStatus_t status = cublasSgemm(*handle, cublasTransOp(TransA), cublasTransOp(TransB),
                                        M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
    if(status!=CUBLAS_STATUS_SUCCESS){
      std::cerr<<"cublasXgemm: cublasSgemm failed with status "<<(int)status<<'\n';
      assert(false);
    }
  }
  #endif
  56. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
  57. float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
  58. int *INFO){
  59. sgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  60. }
  61. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
  62. double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
  63. int *INFO){
  64. dgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  65. }
  /**
   * \brief Computes the pseudo-inverse of matrix M (n1 x n2, row-major) and
   * writes it to M_ (n2 x n1, row-major).
   *
   * The row-major buffer M is handed to LAPACK as the column-major m x n
   * matrix A = M^T (m = n2, n = n1). With the thin SVD A = U*S*VT, the code
   * forms V*S^+*U^T as an n x m column-major matrix; read back row-major that
   * is pinv(A)^T = pinv(M), an n2 x n1 matrix.
   *
   * \param M   Input matrix (n1*n2 entries). Its contents are DESTROYED:
   *            ?gesvd overwrites A when JOBU/JOBVT are 'S'.
   * \param n1  Rows of M.  \param n2  Columns of M.
   * \param eps Relative cutoff: singular values below eps * sigma_max (and
   *            exactly-zero ones) are treated as zero instead of inverted.
   * \param M_  Output buffer with room for n2*n1 entries.
   * If n1 or n2 is not positive, the result is empty and nothing is written.
   */
  template <class T>
  void pinv(T* M, int n1, int n2, T eps, T* M_){
    int m = n2;
    int n = n1;
    if(m<=0 || n<=0) return; // empty pseudo-inverse; avoids tS[0] / LDA=0 below
    int k = (m<n?m:n);
    std::vector<T> tU(m*k);
    std::vector<T> tS(k);
    std::vector<T> tVT(k*n);

    // Thin SVD: A = U (m x k) * diag(S) * VT (k x n); S is sorted descending.
    int INFO=0;
    char JOBU = 'S';
    char JOBVT = 'S';
    // LAPACK minimum: LWORK >= max(3*min(m,n)+max(m,n), 5*min(m,n)).
    int wssize = 3*k+(m>n?m:n);
    int wssize1 = 5*k;
    wssize = (wssize>wssize1?wssize:wssize1);
    std::vector<T> wsbuf(wssize);
    svd(&JOBU, &JOBVT, &m, &n, &M[0], &m, &tS[0], &tU[0], &m, &tVT[0], &k,
        &wsbuf[0], &wssize, &INFO);
    if(INFO!=0)
      std::cerr<<"pinv: ?gesvd failed, INFO="<<INFO<<'\n';
    assert(INFO==0);

    // Invert the significant singular values and zero the rest. The explicit
    // zero test matters for an all-zero matrix: there eps_ == 0, so "< eps_"
    // alone would let 1/0 = inf through and turn the result into NaNs.
    const T eps_=tS[0]*eps;
    for(int i=0;i<k;i++)
      tS[i] = (tS[i]<eps_ || tS[i]==T(0)) ? T(0) : T(1)/tS[i];

    // U <- U * diag(S^+): scale column j of U (column-major, leading dim m).
    for(int j=0;j<k;j++)
      for(int i=0;i<m;i++)
        tU[i+j*m]*=tS[j];

    // M_ (n x m, column-major) = VT^T * (U*S^+)^T = V * S^+ * U^T.
    gemm('T','T',n,m,k,T(1),&tVT[0],k,&tU[0],m,T(0),M_,n);
  }
  106. }//end namespace
  107. }//end namespace