Browse Source

Moved BLAS, LAPACK wrappers to src/mat_utils.cpp

As a consequence, the wrapper functions can no longer be inlined. This
was necessary because other applications using PVFMM may include blas.h
and lapack.h, which would collide with our declarations of the BLAS and
LAPACK functions.
Dhairya Malhotra 10 years ago
parent
commit
af98b59c23
3 changed files with 47 additions and 22 deletions
  1. 1 0
      Makefile.am
  2. 6 22
      include/mat_utils.txx
  3. 40 0
      src/mat_utils.cpp

+ 1 - 0
Makefile.am

@@ -107,6 +107,7 @@ lib_libpvfmm_a_SOURCES = \
 									src/device_wrapper.cpp \
 									src/device_wrapper.cpp \
 									src/fmm_gll.cpp \
 									src/fmm_gll.cpp \
 									src/legendre_rule.cpp \
 									src/legendre_rule.cpp \
+									src/mat_utils.cpp \
 									src/mem_mgr.cpp \
 									src/mem_mgr.cpp \
 									src/mortonid.cpp \
 									src/mortonid.cpp \
 									src/profile.cpp \
 									src/profile.cpp \

+ 6 - 22
include/mat_utils.txx

@@ -13,10 +13,7 @@
 #include <iostream>
 #include <iostream>
 #include <vector>
 #include <vector>
 
 
-#include <blas.h>
-#include <lapack.h>
 #include <matrix.hpp>
 #include <matrix.hpp>
-
 #include <device_wrapper.hpp>
 #include <device_wrapper.hpp>
 #if defined(PVFMM_HAVE_CUDA)
 #if defined(PVFMM_HAVE_CUDA)
 #include <cuda_runtime_api.h>
 #include <cuda_runtime_api.h>
@@ -72,21 +69,12 @@ namespace mat{
   }
   }
 
 
   template<>
   template<>
-  inline void gemm<float>(char TransA, char TransB,  int M,  int N,  int K,  float alpha,  float *A,  int lda,  float *B,  int ldb,  float beta, float *C,  int ldc){
-      sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
-  }
+  void gemm<float>(char TransA, char TransB,  int M,  int N,  int K,  float alpha,  float *A,  int lda,  float *B,  int ldb,  float beta, float *C,  int ldc);
 
 
   template<>
   template<>
-  inline void gemm<double>(char TransA, char TransB,  int M,  int N,  int K,  double alpha,  double *A,  int lda,  double *B,  int ldb,  double beta, double *C,  int ldc){
-      dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
-  }
+  void gemm<double>(char TransA, char TransB,  int M,  int N,  int K,  double alpha,  double *A,  int lda,  double *B,  int ldb,  double beta, double *C,  int ldc);
 
 
   #if defined(PVFMM_HAVE_CUDA)
   #if defined(PVFMM_HAVE_CUDA)
-  //template <class T>
-  //inline void cublasgemm(char TransA, char TransB, int M, int N, int K, T alpha, T*A, int lda, T *B, int ldb, T beta, T *C, int ldc){
-  //  assert(false);
-  //}
-
   template<>
   template<>
   inline void cublasgemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc) {
   inline void cublasgemm<float>(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc) {
     cublasOperation_t cublasTransA, cublasTransB;
     cublasOperation_t cublasTransA, cublasTransB;
@@ -434,18 +422,14 @@ namespace mat{
   }
   }
 
 
   template<>
   template<>
-  inline void svd<float>(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
+  void svd<float>(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
       float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
       float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
-      int *INFO){
-    sgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
-  }
+      int *INFO);
 
 
   template<>
   template<>
-  inline void svd<double>(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
+  void svd<double>(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
       double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
       double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
-      int *INFO){
-    dgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO);
-  }
+      int *INFO);
 
 
   /**
   /**
    * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)
    * \brief Computes the pseudo inverse of matrix M(n1xn2) (in row major form)

+ 40 - 0
src/mat_utils.cpp

@@ -0,0 +1,40 @@
+/**
+ * \file mat_utils.cpp
+ * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
+ * \date November, 2014
+ * \brief This file contains implementation of BLAS and LAPACK wrapper functions.
+ */
+
+#include <blas.h>
+#include <lapack.h>
+#include <mat_utils.hpp>
+
+namespace pvfmm{
+namespace mat{ // Out-of-line definitions of the gemm/svd explicit specializations for float and double.
+
+template<> // Single-precision gemm: forwards to sgemm_.
+void gemm<float>(char TransA, char TransB,  int M,  int N,  int K,  float alpha,  float *A,  int lda,  float *B,  int ldb,  float beta, float *C,  int ldc){
+    sgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc); // Scalars go by address (copies of the by-value parameters); A, B, C are forwarded as given.
+}
+
+template<> // Double-precision gemm: forwards to dgemm_.
+void gemm<double>(char TransA, char TransB,  int M,  int N,  int K,  double alpha,  double *A,  int lda,  double *B,  int ldb,  double beta, double *C,  int ldc){
+    dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc); // Same calling pattern as the float version above.
+}
+
+template<> // Single-precision svd: forwards to sgesvd_.
+void svd<float>(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
+    float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
+    int *INFO){
+  sgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO); // Every argument is already a pointer and is passed through unchanged; INFO is not inspected here.
+}
+
+template<> // Double-precision svd: forwards to dgesvd_.
+void svd<double>(char *JOBU, char *JOBVT, int *M, int *N, double *A, int *LDA,
+    double *S, double *U, int *LDU, double *VT, int *LDVT, double *WORK, int *LWORK,
+    int *INFO){
+  dgesvd_(JOBU,JOBVT,M,N,A,LDA,S,U,LDU,VT,LDVT,WORK,LWORK,INFO); // Every argument is already a pointer and is passed through unchanged; INFO is not inspected here.
+}
+
+}//end namespace mat
+}//end namespace pvfmm