|
@@ -31,33 +31,35 @@ namespace mat{
|
|
|
dgemm_(&TransA, &TransB, &M, &N, &K, &alpha, A, &lda, B, &ldb, &beta, C, &ldc);
|
|
|
}
|
|
|
|
|
|
+#if defined(PVFMM_HAVE_CUDA)
|
|
|
// cublasDgemm wrapper
|
|
|
inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, double alpha, double *A, int lda, double *B, int ldb, double beta, double *C, int ldc){
|
|
|
- cublasOperation_t cublasTransA, cublasTransB;
|
|
|
- cublasStatus_t status;
|
|
|
- cublasHandle_t *handle;
|
|
|
- handle = DeviceWrapper::CUDA_Lock::acquire_handle();
|
|
|
- /* Need exeception handling if (handle) */
|
|
|
- if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
|
|
|
- else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_T;
|
|
|
- if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
|
|
|
- else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_T;
|
|
|
+ cublasOperation_t cublasTransA, cublasTransB;
|
|
|
+ cublasStatus_t status;
|
|
|
+ cublasHandle_t *handle;
|
|
|
+ handle = CUDA_Lock::acquire_handle();
|
|
|
+ /* Need exeception handling if (handle) */
|
|
|
+ if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
|
|
|
+ else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_T;
|
|
|
+ if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
|
|
|
+ else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_T;
|
|
|
status = cublasDgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
|
|
|
}
|
|
|
|
|
|
// cublasDgemm wrapper
|
|
|
inline void cublasXgemm(char TransA, char TransB, int M, int N, int K, float alpha, float *A, int lda, float *B, int ldb, float beta, float *C, int ldc){
|
|
|
- cublasOperation_t cublasTransA, cublasTransB;
|
|
|
- cublasStatus_t status;
|
|
|
- cublasHandle_t *handle;
|
|
|
- handle = DeviceWrapper::CUDA_Lock::acquire_handle();
|
|
|
- /* Need exeception handling if (handle) */
|
|
|
- if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
|
|
|
- else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_T;
|
|
|
- if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
|
|
|
- else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_T;
|
|
|
+ cublasOperation_t cublasTransA, cublasTransB;
|
|
|
+ cublasStatus_t status;
|
|
|
+ cublasHandle_t *handle;
|
|
|
+ handle = CUDA_Lock::acquire_handle();
|
|
|
+ /* Need exeception handling if (handle) */
|
|
|
+ if (TransA == 'T' || TransA == 't') cublasTransA = CUBLAS_OP_T;
|
|
|
+ else if (TransA == 'N' || TransA == 'n') cublasTransA = CUBLAS_OP_T;
|
|
|
+ if (TransB == 'T' || TransB == 't') cublasTransB = CUBLAS_OP_T;
|
|
|
+ else if (TransB == 'N' || TransB == 'n') cublasTransB = CUBLAS_OP_T;
|
|
|
status = cublasSgemm(*handle, cublasTransA, cublasTransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc);
|
|
|
}
|
|
|
+#endif
|
|
|
|
|
|
inline void svd(char *JOBU, char *JOBVT, int *M, int *N, float *A, int *LDA,
|
|
|
float *S, float *U, int *LDU, float *VT, int *LDVT, float *WORK, int *LWORK,
|