mat_utils.txx 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. /**
  2. * \file mat_utils.txx
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains BLAS and LAPACK wrapper functions.
  6. */
  7. #include <cassert>
  8. #include <vector>
  9. #include <iostream>
  10. #include <stdint.h>
  11. #include <math.h>
  12. #include <blas.h>
  13. #include <lapack.h>
  14. #include <fft_wrapper.hpp>
  15. namespace pvfmm{
  16. namespace mat{
  17. inline void gemm(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc){
  18. sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  19. }
  20. inline void gemm(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
  21. dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
  22. }
  23. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
  24. float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
  25. int *INFO){
  26. sgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  27. }
  28. inline void svd(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
  29. double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
  30. int *INFO){
  31. dgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
  32. }
  33. /**
  34. * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
  35. * and returns the output M_(n2xn1).
  36. */
  37. template <class T>
  38. void pinv(T* M, int n1, int n2, T eps, T* M_){
  39. int m = n2;
  40. int n = n1;
  41. int k = (m<n?m:n);
  42. std::vector<T> tU(m*k);
  43. std::vector<T> tS(k);
  44. std::vector<T> tVT(k*n);
  45. //SVD
  46. int INFO=0;
  47. char JOBU = 'S';
  48. char JOBVT = 'S';
  49. //int wssize = max(3*min(m,n)+max(m,n), 5*min(m,n));
  50. int wssize = 3*(m<n?m:n)+(m>n?m:n);
  51. int wssize1 = 5*(m<n?m:n);
  52. wssize = (wssize>wssize1?wssize:wssize1);
  53. T* wsbuf = new T[wssize];
  54. svd(&JOBU, &JOBVT, &m, &n, &M[0], &m, &tS[0], &tU[0], &m, &tVT[0], &k,
  55. wsbuf, &wssize, &INFO);
  56. if(INFO!=0)
  57. std::cout<<INFO<<'\n';
  58. assert(INFO==0);
  59. delete [] wsbuf;
  60. T eps_=tS[0]*eps;
  61. for(int i=0;i<k;i++)
  62. if(tS[i]<eps_)
  63. tS[i]=0;
  64. else
  65. tS[i]=1.0/tS[i];
  66. for(int i=0;i<m;i++){
  67. for(int j=0;j<k;j++){
  68. tU[i+j*m]*=tS[j];
  69. }
  70. }
  71. gemm('T','T',n,m,k,1.0,&tVT[0],k,&tU[0],m,0.0,M_,n);
  72. }
  73. }//end namespace
  74. }//end namespace