// nearinterac.cpp (3.9 KB)
#include <cstdlib>      // drand48, srand48
#include <cstring>      // memcpy
#include <iostream>
#include <stdexcept>    // std::length_error
#include <type_traits>  // std::is_trivially_copyable
#include <vector>

#ifdef _OPENMP
#include <omp.h>        // omp_get_num_threads, omp_get_thread_num
#endif

#include <NearInteraction.hpp>
  3. template <class Real, int DIM> class MyObject {
  4. public:
  5. MyObject() {
  6. for (int i = 0; i < DIM; i++) coord[i] = drand48() * 3;
  7. rad = drand48() * 0.01;
  8. }
  9. const Real* Coord() const { return coord; }
  10. Real Rad() const { return rad; }
  11. void Pack(std::vector<char>& buff) const {
  12. size_t count = sizeof(MyObject<Real, DIM>);
  13. buff.resize(count);
  14. memcpy(buff.data(), this, count);
  15. }
  16. void Unpack(const std::vector<char>& buff) {
  17. size_t count = sizeof(MyObject<Real, DIM>);
  18. memcpy(this, buff.data(), count);
  19. }
  20. private:
  21. Real coord[DIM];
  22. Real rad;
  23. };
  24. typedef double Real;
  25. constexpr int DIM = 3;
  26. typedef MyObject<Real,DIM> SrcObj;
  27. typedef MyObject<Real,DIM> TrgObj;
  28. int main(int argc, char **argv) {
  29. int comm_rank = 0;
  30. int comm_size = 1;
  31. #ifdef SCTL_HAVE_MPI
  32. MPI_Init(&argc, &argv);
  33. MPI_Comm mpi_comm = MPI_COMM_WORLD;
  34. MPI_Comm_rank(mpi_comm, &comm_rank);
  35. MPI_Comm_size(mpi_comm, &comm_size);
  36. #endif
  37. // Generate source and target points
  38. srand48(comm_rank);
  39. std::vector<SrcObj> src_vec(10000/comm_size);
  40. std::vector<TrgObj> trg_vec(20000/comm_size);
  41. { // Compute near interactions
  42. sctl::Comm comm;
  43. #ifdef SCTL_HAVE_MPI
  44. comm = mpi_comm;
  45. #endif
  46. NearInteraction<Real,DIM> near_interac(comm);
  47. sctl::Profile::Enable(true);
  48. { // Repartition data
  49. // Setup for repartition
  50. sctl::Profile::Tic("RepartSetup", &comm, true);
  51. near_interac.SetupRepartition(src_vec, trg_vec);
  52. sctl::Profile::Toc();
  53. // Distribute source and target vectors
  54. sctl::Profile::Tic("Repart", &comm, true);
  55. std::vector<SrcObj> src_new;
  56. std::vector<TrgObj> trg_new;
  57. near_interac.ForwardScatterSrc<SrcObj>(src_vec, src_new);
  58. near_interac.ForwardScatterTrg<TrgObj>(trg_vec, trg_new);
  59. src_vec.swap(src_new);
  60. trg_vec.swap(trg_new);
  61. sctl::Profile::Toc();
  62. }
  63. { // Compute near interactions
  64. // Setup for near interaction
  65. sctl::Profile::Tic("NearSetup", &comm, true);
  66. near_interac.SetupNearInterac(src_vec, trg_vec);
  67. sctl::Profile::Toc();
  68. // Following code can repeat multiple times without calling Setup again as long as
  69. // the position and shape of the sources and the targets do not change.
  70. // Forward scatter
  71. sctl::Profile::Tic("Near", &comm, true);
  72. std::vector<SrcObj> src_near;
  73. std::vector<TrgObj> trg_near;
  74. near_interac.ForwardScatterSrc<SrcObj>(src_vec, src_near);
  75. near_interac.ForwardScatterTrg<TrgObj>(trg_vec, trg_near);
  76. const auto& trg_src_interac = near_interac.GetInteractionList(); // Get interaction list
  77. //for (auto interac : trg_src_interac) { // compute near interactions (single thread)
  78. // SrcObj& t = trg_near[interac.first];
  79. // TrgObj& s = src_near[interac.second];
  80. // // ( compute interaction between t and s )
  81. //}
  82. #pragma omp parallel
  83. { // compute near interactions (multiple threads)
  84. int omp_p = omp_get_num_threads();
  85. int tid = omp_get_thread_num();
  86. long N = trg_src_interac.size();
  87. long a = (tid + 0) * N / omp_p;
  88. long b = (tid + 1) * N / omp_p;
  89. // Ensure each thread works on a different target
  90. if (tid > 0) while (a + 1 < N && trg_src_interac[a].first == trg_src_interac[a + 1].first) a++;
  91. while (b + 1 < N && trg_src_interac[b].first == trg_src_interac[b + 1].first) b++;
  92. for (long i = a; i < b; i++) {
  93. SrcObj& t = trg_near[trg_src_interac[i].first];
  94. TrgObj& s = src_near[trg_src_interac[i].second];
  95. // ( compute interaction between t and s )
  96. }
  97. }
  98. // Reverse scatter
  99. near_interac.ReverseScatterTrg<TrgObj>(trg_near, trg_vec);
  100. sctl::Profile::Toc();
  101. }
  102. sctl::Profile::print(&comm);
  103. }
  104. #ifdef SCTL_HAVE_MPI
  105. MPI_Finalize();
  106. #endif
  107. return 0;
  108. }