cuda_func.hpp 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #ifndef _CUDA_FUNC_HPP_
  2. #define _CUDA_FUNC_HPP_
  3. #include <pvfmm_common.hpp>
  4. #include <stdint.h>
  5. #include <stdio.h>
  6. #include <stdlib.h>
  7. #include <assert.h>
  8. #include <cstring>
  9. #include <device_wrapper.hpp>
  10. #include <matrix.hpp>
  11. #include <vector.hpp>
  12. #ifdef __cplusplus
  13. extern "C" {
  14. #endif
  15. void in_perm_d (char*, size_t*, char*, char*, size_t, size_t, size_t, cudaStream_t*);
  16. void out_perm_d (double*, char*, size_t*, char*, char*, size_t, size_t, size_t, cudaStream_t*);
  17. #ifdef __cplusplus
  18. }
  19. #endif
  // Collection of static wrappers around the C-linkage routines in_perm_d and
  // out_perm_d. Each wrapper obtains a stream from pvfmm::CUDA_Lock and passes
  // it, together with its own arguments, to the corresponding *_d routine.
  //
  // Real_t is not used by the declarations themselves.
  template <class Real_t>
  class cuda_func {
    public:
      // Forwards to in_perm_d; input_perm is handed over as a size_t*.
      static void in_perm_h(char *precomp_data, char *input_perm, char *input_data,
                            char *buff_in, size_t interac_indx, size_t M_dim0,
                            size_t vec_cnt);

      // Forwards to out_perm_d; scaling is handed over as a double* and
      // output_perm as a size_t*. The dimension parameter is named M_dim1 to
      // agree with the out-of-class definition.
      static void out_perm_h(char *scaling, char *precomp_data, char *output_perm,
                             char *output_data, char *buff_out, size_t interac_indx,
                             size_t M_dim1, size_t vec_cnt);
  };
  28. template <class Real_t>
  29. void cuda_func<Real_t>::in_perm_h (
  30. char *precomp_data,
  31. char *input_perm,
  32. char *input_data,
  33. char *buff_in,
  34. size_t interac_indx,
  35. size_t M_dim0,
  36. size_t vec_cnt )
  37. {
  38. cudaStream_t *stream;
  39. stream = pvfmm::CUDA_Lock::acquire_stream(0);
  40. in_perm_d(precomp_data, (size_t *) input_perm, input_data, buff_in,
  41. interac_indx, M_dim0, vec_cnt, stream);
  42. };
  43. template <class Real_t>
  44. void cuda_func<Real_t>::out_perm_h (
  45. char *scaling,
  46. char *precomp_data,
  47. char *output_perm,
  48. char *output_data,
  49. char *buff_out,
  50. size_t interac_indx,
  51. size_t M_dim1,
  52. size_t vec_cnt )
  53. {
  54. cudaStream_t *stream;
  55. stream = pvfmm::CUDA_Lock::acquire_stream(0);
  56. size_t *a_d, *b_d;
  57. out_perm_d((double *) scaling, precomp_data, (size_t *) output_perm, output_data, buff_out,
  58. interac_indx, M_dim1, vec_cnt, stream );
  59. };
  60. #endif //_CUDA_FUNC_HPP_