cuda_func.hpp 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. #ifndef _CUDA_FUNC_HPP_
  2. #define _CUDA_FUNC_HPP_
  3. #include <pvfmm_common.hpp>
  4. #include <stdint.h>
  5. #include <stdio.h>
  6. #include <stdlib.h>
  7. #include <assert.h>
  8. #include <cstring>
  9. #include <device_wrapper.hpp>
  10. #include <matrix.hpp>
  11. #include <vector.hpp>
  12. #ifdef __cplusplus
  13. extern "C" {
  14. #endif
  15. void test_d(uintptr_t, uintptr_t, uintptr_t, uintptr_t, int, cudaStream_t*);
  16. //void in_perm_d (uintptr_t, uintptr_t , uintptr_t, uintptr_t, size_t, size_t, size_t, cudaStream_t*);
  17. void in_perm_d (char*, size_t*, char*, char*, size_t, size_t, size_t, cudaStream_t*);
  18. void out_perm_d (double*, char*, size_t*, char*, char*, size_t, size_t, size_t, cudaStream_t*);
  19. #ifdef __cplusplus
  20. }
  21. #endif
  /**
   * Host-side static wrappers that forward permutation requests to the
   * C-linkage routines in_perm_d / out_perm_d.
   *
   * All buffers are passed as untyped char pointers; the wrappers cast the
   * permutation (and, for the output path, the scaling) pointers to the types
   * the underlying routines expect.
   *
   * NOTE(review): Real_t is not referenced by either member signature.
   */
  template <class Real_t>
  class cuda_func {
   public:
    /** Forwards to in_perm_d; input_perm is passed on as a size_t pointer. */
    static void in_perm_h(char* precomp_data,
                          char* input_perm,
                          char* input_data,
                          char* buff_in,
                          size_t interac_indx,
                          size_t M_dim0,
                          size_t vec_cnt);

    /**
     * Forwards to out_perm_d; scaling is passed on as a double pointer and
     * output_perm as a size_t pointer.
     */
    static void out_perm_h(char* scaling,
                           char* precomp_data,
                           char* output_perm,
                           char* output_data,
                           char* buff_out,
                           size_t interac_indx,
                           size_t M_dim1,  // named M_dim1 to match the definition
                           size_t vec_cnt);
  };
  34. template <class Real_t>
  35. void cuda_func<Real_t>::in_perm_h (
  36. /*
  37. uintptr_t precomp_data,
  38. uintptr_t input_perm,
  39. uintptr_t input_data,
  40. uintptr_t buff_in,
  41. */
  42. char *precomp_data,
  43. char *input_perm,
  44. char *input_data,
  45. char *buff_in,
  46. size_t interac_indx,
  47. size_t M_dim0,
  48. size_t vec_cnt )
  49. {
  50. cudaStream_t *stream;
  51. stream = pvfmm::CUDA_Lock::acquire_stream(0);
  52. /*
  53. intptr_t precomp_data_d = precomp_data[0];
  54. intptr_t input_perm_d = input_perm[0];
  55. intptr_t input_data_d = input_data[0];
  56. intptr_t buff_in_d = buff_in[0];
  57. */
  58. in_perm_d(precomp_data, (size_t *) input_perm, input_data, buff_in,
  59. interac_indx, M_dim0, vec_cnt, stream);
  60. //test_d(precomp_data, input_perm, input_data, buff_in, interac_indx, stream);
  61. };
  62. template <class Real_t>
  63. void cuda_func<Real_t>::out_perm_h (
  64. char *scaling,
  65. char *precomp_data,
  66. char *output_perm,
  67. char *output_data,
  68. char *buff_out,
  69. size_t interac_indx,
  70. size_t M_dim1,
  71. size_t vec_cnt )
  72. {
  73. cudaStream_t *stream;
  74. //stream = DeviceWrapper::CUDA_Lock::acquire_stream(0);
  75. stream = pvfmm::CUDA_Lock::acquire_stream(0);
  76. out_perm_d((double *) scaling, precomp_data, (size_t *) output_perm, output_data, buff_out,
  77. interac_indx, M_dim1, vec_cnt, stream);
  78. }
  79. #endif //_CUDA_FUNC_HPP_