comm.hpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. #ifndef _PVFMM_COMM_HPP_
  2. #define _PVFMM_COMM_HPP_
  3. #include <pvfmm/common.hpp>
  4. #include <map>
  5. #include <stack>
  6. #ifdef PVFMM_HAVE_MPI
  7. #include <mpi.h>
  8. #endif
  9. namespace pvfmm {
  10. template <class ValueType> class Vector;
  11. class Comm {
  12. public:
  13. enum class CommOp {
  14. SUM,
  15. MIN,
  16. MAX
  17. };
  18. Comm();
  19. Comm(const Comm& c);
  20. static Comm Self();
  21. static Comm World();
  22. Comm& operator=(const Comm& c);
  23. ~Comm();
  24. Comm Split(Integer clr) const;
  25. Integer Rank() const;
  26. Integer Size() const;
  27. void Barrier() const;
  28. template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
  29. template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
  30. void Wait(void* req_ptr) const;
  31. template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  32. template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  33. template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  34. template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
  35. template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  36. template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
  37. template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
  38. template <class Type> void PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_ = NULL) const;
  39. template <class Type> void PartitionN(Vector<Type>& v, Long N) const;
  40. template <class Type> void PartitionS(Vector<Type>& nodeList, const Type& splitter) const;
  41. template <class Type> void HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const;
  42. template <class Type> void SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_ = NULL) const;
  43. template <class Type> void ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const;
  44. template <class Type> void ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_ = 0) const;
  45. private:
  46. template <typename A, typename B> struct SortPair {
  47. int operator<(const SortPair<A, B>& p1) const { return key < p1.key; }
  48. A key;
  49. B data;
  50. };
  51. #ifdef PVFMM_HAVE_MPI
  52. Comm(const MPI_Comm mpi_comm) { Init(mpi_comm); }
  53. void Init(const MPI_Comm mpi_comm) {
  54. MPI_Comm_dup(mpi_comm, &mpi_comm_);
  55. MPI_Comm_rank(mpi_comm_, &mpi_rank_);
  56. MPI_Comm_size(mpi_comm_, &mpi_size_);
  57. }
  58. Vector<MPI_Request>* NewReq() const;
  59. void DelReq(Vector<MPI_Request>* req_ptr) const;
  60. mutable std::stack<void*> req;
  61. int mpi_rank_;
  62. int mpi_size_;
  63. MPI_Comm mpi_comm_;
  64. /**
  65. * \class CommDatatype
  66. * \brief An abstract class used for communicating messages using user-defined
  67. * datatypes. The user must implement the static member function "value()" that
  68. * returns the MPI_Datatype corresponding to this user-defined datatype.
  69. * \author Hari Sundar, hsundar@gmail.com
  70. */
  71. template <class Type> class CommDatatype {
  72. public:
  73. static MPI_Datatype value() {
  74. static bool first = true;
  75. static MPI_Datatype datatype;
  76. if (first) {
  77. first = false;
  78. MPI_Type_contiguous(sizeof(Type), MPI_BYTE, &datatype);
  79. MPI_Type_commit(&datatype);
  80. }
  81. return datatype;
  82. }
  83. static MPI_Op sum() {
  84. static bool first = true;
  85. static MPI_Op myop;
  86. if (first) {
  87. first = false;
  88. int commune = 1;
  89. MPI_Op_create(sum_fn, commune, &myop);
  90. }
  91. return myop;
  92. }
  93. static MPI_Op min() {
  94. static bool first = true;
  95. static MPI_Op myop;
  96. if (first) {
  97. first = false;
  98. int commune = 1;
  99. MPI_Op_create(min_fn, commune, &myop);
  100. }
  101. return myop;
  102. }
  103. static MPI_Op max() {
  104. static bool first = true;
  105. static MPI_Op myop;
  106. if (first) {
  107. first = false;
  108. int commune = 1;
  109. MPI_Op_create(max_fn, commune, &myop);
  110. }
  111. return myop;
  112. }
  113. private:
  114. static void sum_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  115. Type* a = (Type*)a_;
  116. Type* b = (Type*)b_;
  117. int len = *len_;
  118. for (int i = 0; i < len; i++) {
  119. b[i] = a[i] + b[i];
  120. }
  121. }
  122. static void min_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  123. Type* a = (Type*)a_;
  124. Type* b = (Type*)b_;
  125. int len = *len_;
  126. for (int i = 0; i < len; i++) {
  127. if (a[i] < b[i]) b[i] = a[i];
  128. }
  129. }
  130. static void max_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  131. Type* a = (Type*)a_;
  132. Type* b = (Type*)b_;
  133. int len = *len_;
  134. for (int i = 0; i < len; i++) {
  135. if (a[i] > b[i]) b[i] = a[i];
  136. }
  137. }
  138. };
  139. #else
  140. mutable std::multimap<Integer, ConstIterator<char>> send_req;
  141. mutable std::multimap<Integer, Iterator<char>> recv_req;
  142. #endif
  143. };
  144. } // end namespace
  145. #include <pvfmm/comm.txx>
  146. #endif //_PVFMM_COMM_HPP_