pvfmm.hpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. /**
  2. * \file pvfmm.hpp
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 1-2-2014
  5. * \brief This file contains wrapper functions for PvFMM.
  6. */
  7. #ifndef _PVFMM_HPP_
  8. #define _PVFMM_HPP_
#include <mpi.h>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <vector>
#include <pvfmm_common.hpp>
#include <fmm_cheb.hpp>
#include <fmm_node.hpp>
#include <fmm_tree.hpp>
  16. namespace pvfmm{
  17. typedef FMM_Node<Cheb_Node<double> > ChebFMM_Node;
  18. typedef FMM_Cheb<ChebFMM_Node> ChebFMM;
  19. typedef FMM_Tree<ChebFMM> ChebFMM_Tree;
  20. typedef ChebFMM_Node::NodeData ChebFMM_Data;
  21. typedef void (*ChebFn)(double* , int , double*);
  22. ChebFMM_Tree* ChebFMM_CreateTree(int cheb_deg, int data_dim, ChebFn fn_ptr, std::vector<double>& trg_coord, MPI_Comm& comm,
  23. double tol=1e-6, int max_pts=100, BoundaryType bndry=FreeSpace, int init_depth=0){
  24. int np, myrank;
  25. MPI_Comm_size(comm, &np);
  26. MPI_Comm_rank(comm, &myrank);
  27. ChebFMM_Data tree_data;
  28. tree_data.cheb_deg=cheb_deg;
  29. tree_data.data_dof=data_dim;
  30. tree_data.input_fn=fn_ptr;
  31. tree_data.tol=tol;
  32. bool adap=true;
  33. tree_data.dim=COORD_DIM;
  34. tree_data.max_depth=MAX_DEPTH;
  35. tree_data.max_pts=max_pts;
  36. { // Set points for initial tree.
  37. std::vector<double> coord;
  38. size_t N=pow(8.0,init_depth);
  39. N=(N<np?np:N)*max_pts;
  40. size_t NN=ceil(pow((double)N,1.0/3.0));
  41. size_t N_total=NN*NN*NN;
  42. size_t start= myrank *N_total/np;
  43. size_t end =(myrank+1)*N_total/np;
  44. for(size_t i=start;i<end;i++){
  45. coord.push_back(((double)((i/ 1 )%NN)+0.5)/NN);
  46. coord.push_back(((double)((i/ NN )%NN)+0.5)/NN);
  47. coord.push_back(((double)((i/(NN*NN))%NN)+0.5)/NN);
  48. }
  49. tree_data.pt_coord=coord;
  50. }
  51. // Set target points.
  52. tree_data.trg_coord=trg_coord;
  53. ChebFMM_Tree* tree=new ChebFMM_Tree(comm);
  54. tree->Initialize(&tree_data);
  55. tree->InitFMM_Tree(adap,bndry);
  56. return tree;
  57. }
  58. void ChebFMM_Evaluate(ChebFMM_Tree* tree, std::vector<double>& trg_val, size_t loc_size=0){
  59. tree->RunFMM();
  60. Vector<double> trg_value;
  61. Vector<size_t> trg_scatter;
  62. {// Collect data from each node to trg_value and trg_scatter.
  63. std::vector<double> trg_value_;
  64. std::vector<size_t> trg_scatter_;
  65. std::vector<ChebFMM_Node*>& nodes=tree->GetNodeList();
  66. for(size_t i=0;i<nodes.size();i++){
  67. if(nodes[i]->IsLeaf() && !nodes[i]->IsGhost()){
  68. Vector<double>& trg_value=nodes[i]->trg_value;
  69. Vector<size_t>& trg_scatter=nodes[i]->trg_scatter;
  70. for(size_t j=0;j<trg_value.Dim();j++) trg_value_.push_back(trg_value[j]);
  71. for(size_t j=0;j<trg_scatter.Dim();j++) trg_scatter_.push_back(trg_scatter[j]);
  72. }
  73. }
  74. trg_value=trg_value_;
  75. trg_scatter=trg_scatter_;
  76. }
  77. par::ScatterReverse(trg_value,trg_scatter,*tree->Comm(),loc_size);
  78. trg_val.assign(&trg_value[0],&trg_value[0]+trg_value.Dim());;
  79. }
  80. typedef FMM_Node<MPI_Node<double> > PtFMM_Node;
  81. typedef FMM_Pts<PtFMM_Node> PtFMM;
  82. typedef FMM_Tree<PtFMM> PtFMM_Tree;
  83. typedef PtFMM_Node::NodeData PtFMM_Data;
  84. PtFMM_Tree* PtFMM_CreateTree(std::vector<double>& src_coord, std::vector<double>& src_value, std::vector<double>& trg_coord, MPI_Comm& comm,
  85. int max_pts=100, BoundaryType bndry=FreeSpace, int init_depth=0){
  86. int np, myrank;
  87. MPI_Comm_size(comm, &np);
  88. MPI_Comm_rank(comm, &myrank);
  89. PtFMM_Data tree_data;
  90. bool adap=true;
  91. tree_data.dim=COORD_DIM;
  92. tree_data.max_depth=MAX_DEPTH;
  93. tree_data.max_pts=max_pts;
  94. // Set source points.
  95. tree_data.pt_coord=src_coord;
  96. tree_data.src_coord=src_coord;
  97. tree_data.src_value=src_value;
  98. // Set target points.
  99. tree_data.trg_coord=trg_coord;
  100. PtFMM_Tree* tree=new PtFMM_Tree(comm);
  101. tree->Initialize(&tree_data);
  102. tree->InitFMM_Tree(adap,bndry);
  103. return tree;
  104. }
  105. void PtFMM_Evaluate(PtFMM_Tree* tree, std::vector<double>& trg_val, size_t loc_size=0, std::vector<double>* src_val=NULL){
  106. if(src_val){
  107. std::vector<size_t> src_scatter_;
  108. std::vector<PtFMM_Node*>& nodes=tree->GetNodeList();
  109. for(size_t i=0;i<nodes.size();i++){
  110. if(nodes[i]->IsLeaf() && !nodes[i]->IsGhost()){
  111. Vector<size_t>& src_scatter=nodes[i]->src_scatter;
  112. for(size_t j=0;j<src_scatter.Dim();j++) src_scatter_.push_back(src_scatter[j]);
  113. }
  114. }
  115. Vector<double> src_value=*src_val;
  116. Vector<size_t> src_scatter=src_scatter_;
  117. par::ScatterForward(src_value,src_scatter,*tree->Comm());
  118. size_t indx=0;
  119. for(size_t i=0;i<nodes.size();i++){
  120. if(nodes[i]->IsLeaf() && !nodes[i]->IsGhost()){
  121. Vector<double>& src_value_=nodes[i]->src_value;
  122. for(size_t j=0;j<src_value_.Dim();j++){
  123. src_value_[j]=src_value[indx];
  124. indx++;
  125. }
  126. }
  127. }
  128. }
  129. tree->RunFMM();
  130. Vector<double> trg_value;
  131. Vector<size_t> trg_scatter;
  132. {
  133. std::vector<double> trg_value_;
  134. std::vector<size_t> trg_scatter_;
  135. std::vector<PtFMM_Node*>& nodes=tree->GetNodeList();
  136. for(size_t i=0;i<nodes.size();i++){
  137. if(nodes[i]->IsLeaf() && !nodes[i]->IsGhost()){
  138. Vector<double>& trg_value=nodes[i]->trg_value;
  139. Vector<size_t>& trg_scatter=nodes[i]->trg_scatter;
  140. for(size_t j=0;j<trg_value.Dim();j++) trg_value_.push_back(trg_value[j]);
  141. for(size_t j=0;j<trg_scatter.Dim();j++) trg_scatter_.push_back(trg_scatter[j]);
  142. }
  143. }
  144. trg_value=trg_value_;
  145. trg_scatter=trg_scatter_;
  146. }
  147. par::ScatterReverse(trg_value,trg_scatter,*tree->Comm(),loc_size);
  148. trg_val.assign(&trg_value[0],&trg_value[0]+trg_value.Dim());;
  149. }
  150. }//end namespace
  151. #endif //_PVFMM_HPP_