cuda_func.hpp 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #ifndef _CUDA_FUNC_HPP_
  2. #define _CUDA_FUNC_HPP_
  3. #include <pvfmm_common.hpp>
  4. #include <assert.h>
  5. #include <cstring>
  6. #include <device_wrapper.hpp>
  7. #include <matrix.hpp>
  8. #include <vector.hpp>
  9. //namespace pvfmm {
  10. // external functions
  11. extern "C" void in_perm_d (uintptr_t, uintptr_t, uintptr_t, uintptr_t, size_t, size_t, size_t, cudaStream_t*);
  12. extern "C" void out_perm_d (uintptr_t, uintptr_t, uintptr_t, uintptr_t, uintptr_t, size_t, size_t, size_t, cudaStream_t*);
  13. template <class Real_t>
  14. class cuda_func {
  15. public:
  16. static void in_perm_h (uintptr_t precomp_data, uintptr_t input_perm,
  17. uintptr_t input_data, uintptr_t buff_in, size_t interac_indx,
  18. size_t M_dim0, size_t vec_cnt);
  19. static void out_perm_h (uintptr_t scaling, uintptr_t precomp_data, uintptr_t output_perm,
  20. uintptr_t output_data, uintptr_t buff_out, size_t interac_indx,
  21. size_t M_dim0, size_t vec_cnt);
  22. };
  23. template <class Real_t>
  24. void cuda_func<Real_t>::in_perm_h (
  25. uintptr_t precomp_data,
  26. uintptr_t input_perm,
  27. uintptr_t input_data,
  28. uintptr_t buff_in,
  29. size_t interac_indx,
  30. size_t M_dim0,
  31. size_t vec_cnt )
  32. {
  33. cudaStream_t *stream;
  34. //stream = DeviceWrapper::CUDA_Lock::acquire_stream(0);
  35. stream = pvfmm::CUDA_Lock::acquire_stream(0);
  36. /*
  37. intptr_t precomp_data_d = precomp_data[0];
  38. intptr_t input_perm_d = input_perm[0];
  39. intptr_t input_data_d = input_data[0];
  40. intptr_t buff_in_d = buff_in[0];
  41. */
  42. in_perm_d(precomp_data, input_perm, input_data, buff_in, interac_indx, M_dim0, vec_cnt, stream);
  43. };
  44. template <class Real_t>
  45. void cuda_func<Real_t>::out_perm_h (
  46. uintptr_t scaling,
  47. uintptr_t precomp_data,
  48. uintptr_t output_perm,
  49. uintptr_t output_data,
  50. uintptr_t buff_out,
  51. size_t interac_indx,
  52. size_t M_dim1,
  53. size_t vec_cnt )
  54. {
  55. cudaStream_t *stream;
  56. //stream = DeviceWrapper::CUDA_Lock::acquire_stream(0);
  57. stream = pvfmm::CUDA_Lock::acquire_stream(0);
  58. out_perm_d(scaling, precomp_data, output_perm, output_data, buff_out, interac_indx, M_dim1, vec_cnt, stream);
  59. }
  60. //};
  61. #endif //_CUDA_FUNC_HPP_