// fmm-wrapper.hpp
  1. #ifndef _SCTL_FMM_WRAPPER_HPP_
  2. #define _SCTL_FMM_WRAPPER_HPP_
  3. #include <sctl/common.hpp>
  4. #include SCTL_INCLUDE(comm.hpp)
  5. #include SCTL_INCLUDE(mem_mgr.hpp)
  6. #include <map>
  7. #include <string>
  8. #ifdef SCTL_HAVE_PVFMM
  9. namespace pvfmm {
  10. template <class Real> struct Kernel;
  11. template <class Real> class MPI_Node;
  12. template <class Node> class FMM_Node;
  13. template <class FMM_Node> class FMM_Pts;
  14. template <class FMM_Mat> class FMM_Tree;
  15. template <class Real> using PtFMM_Node = FMM_Node<MPI_Node<Real>>;
  16. template <class Real> using PtFMM = FMM_Pts<PtFMM_Node<Real>>;
  17. template <class Real> using PtFMM_Tree = FMM_Tree<PtFMM<Real>>;
  18. }
  19. #endif
  20. namespace SCTL_NAMESPACE {
  21. template <class ValueType> class Vector;
  22. template <class Real, Integer DIM> class ParticleFMM {
  23. public:
  24. ParticleFMM(const ParticleFMM&) = delete;
  25. ParticleFMM& operator= (const ParticleFMM&) = delete;
  26. ParticleFMM(const Comm& comm = Comm::Self());
  27. ~ParticleFMM();
  28. void SetComm(const Comm& comm);
  29. void SetAccuracy(Integer digits);
  30. template <class KerM2M, class KerM2L, class KerL2L> void SetKernels(const KerM2M& ker_m2m, const KerM2L& ker_m2l, const KerL2L& ker_l2l);
  31. template <class KerS2M, class KerS2L> void AddSrc(const std::string& name, const KerS2M& ker_s2m, const KerS2L& ker_s2l);
  32. template <class KerM2T, class KerL2T> void AddTrg(const std::string& name, const KerM2T& ker_m2t, const KerL2T& ker_l2t);
  33. template <class KerS2T> void SetKernelS2T(const std::string& src_name, const std::string& trg_name, const KerS2T& ker_s2t);
  34. void DeleteSrc(const std::string& name);
  35. void DeleteTrg(const std::string& name);
  36. void SetSrcCoord(const std::string& name, const Vector<Real>& src_coord, const Vector<Real>& src_normal = Vector<Real>());
  37. void SetSrcDensity(const std::string& name, const Vector<Real>& src_density);
  38. void SetTrgCoord(const std::string& name, const Vector<Real>& trg_coord);
  39. void Eval(Vector<Real>& U, const std::string& trg_name) const;
  40. void EvalDirect(Vector<Real>& U, const std::string& trg_name) const;
  41. static void test(const Comm& comm);
  42. private:
  43. struct FMMKernels {
  44. Iterator<char> ker_m2m, ker_m2l, ker_l2l;
  45. Integer dim_mul_ch, dim_mul_eq;
  46. Integer dim_loc_ch, dim_loc_eq;
  47. void (*ker_m2m_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  48. void (*ker_m2l_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  49. void (*ker_l2l_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  50. void (*delete_ker_m2m)(Iterator<char> ker);
  51. void (*delete_ker_m2l)(Iterator<char> ker);
  52. void (*delete_ker_l2l)(Iterator<char> ker);
  53. #ifdef SCTL_HAVE_PVFMM
  54. pvfmm::Kernel<Real> pvfmm_ker_m2m;
  55. pvfmm::Kernel<Real> pvfmm_ker_m2l;
  56. pvfmm::Kernel<Real> pvfmm_ker_l2l;
  57. #endif
  58. };
  59. struct SrcData {
  60. Vector<Real> X, Xn, F;
  61. Iterator<char> ker_s2m, ker_s2l;
  62. Integer dim_src, dim_mul_ch, dim_loc_ch, dim_normal;
  63. void (*ker_s2m_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  64. void (*ker_s2l_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  65. void (*delete_ker_s2m)(Iterator<char> ker);
  66. void (*delete_ker_s2l)(Iterator<char> ker);
  67. #ifdef SCTL_HAVE_PVFMM
  68. pvfmm::Kernel<Real> pvfmm_ker_s2m;
  69. pvfmm::Kernel<Real> pvfmm_ker_s2l;
  70. StaticArray<Real, DIM*2> bbox;
  71. #endif
  72. };
  73. struct TrgData {
  74. Vector<Real> X, U;
  75. Iterator<char> ker_m2t, ker_l2t;
  76. Integer dim_mul_eq, dim_loc_eq, dim_trg;
  77. void (*ker_m2t_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  78. void (*ker_l2t_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  79. void (*delete_ker_m2t)(Iterator<char> ker);
  80. void (*delete_ker_l2t)(Iterator<char> ker);
  81. #ifdef SCTL_HAVE_PVFMM
  82. pvfmm::Kernel<Real> pvfmm_ker_m2t;
  83. pvfmm::Kernel<Real> pvfmm_ker_l2t;
  84. StaticArray<Real, DIM*2> bbox;
  85. #endif
  86. };
  87. struct S2TData {
  88. Iterator<char> ker_s2t;
  89. Integer dim_src, dim_trg, dim_normal;
  90. void (*ker_s2t_eval)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  91. void (*ker_s2t_eval_omp)(Vector<Real>& v_trg, const Vector<Real>& r_trg, const Vector<Real>& r_src, const Vector<Real>& n_src, const Vector<Real>& v_src, Integer digits, ConstIterator<char> self);
  92. void (*delete_ker_s2t)(Iterator<char> ker);
  93. #ifdef SCTL_HAVE_PVFMM
  94. mutable Real bbox_scale;
  95. mutable StaticArray<Real,DIM> bbox_offset;
  96. mutable Vector<Real> src_scal_exp, trg_scal_exp;
  97. mutable Vector<Real> src_scal, trg_scal;
  98. mutable pvfmm::Kernel<Real> pvfmm_ker_s2t;
  99. mutable pvfmm::PtFMM_Tree<Real>* tree_ptr;
  100. mutable pvfmm::PtFMM<Real> fmm_ctx;
  101. mutable bool setup_tree;
  102. mutable bool setup_ker;
  103. #endif
  104. };
  105. static void BuildSrcTrgScal(const S2TData& s2t_data, bool verbose);
  106. template <class Ker> static void DeleteKer(Iterator<char> ker);
  107. void CheckKernelDims() const;
  108. void DeleteS2T(const std::string& src_name, const std::string& trg_name);
  109. #ifdef SCTL_HAVE_PVFMM
  110. template <class SCTLKernel, bool use_dummy_normal=false> struct PVFMMKernelFn; // construct PVFMMKernel from SCTLKernel
  111. void EvalPVFMM(Vector<Real>& U, const std::string& trg_name) const;
  112. #endif
  113. FMMKernels fmm_ker;
  114. std::map<std::string, SrcData> src_map;
  115. std::map<std::string, TrgData> trg_map;
  116. std::map<std::pair<std::string,std::string>, S2TData> s2t_map;
  117. Comm comm_;
  118. Integer digits_;
  119. };
  120. } // end namespace
  121. #include SCTL_INCLUDE(fmm-wrapper.txx)
  122. #endif //_SCTL_FMM_WRAPPER_HPP_