comm.hpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. #ifndef _SCTL_COMM_HPP_
  2. #define _SCTL_COMM_HPP_
  3. #include SCTL_INCLUDE(common.hpp)
  4. #include <map>
  5. #include <stack>
  6. #ifdef SCTL_HAVE_MPI
  7. #include <mpi.h>
  8. #endif
  9. namespace SCTL_NAMESPACE {
  10. template <class ValueType> class Vector;
  11. class Comm {
  12. public:
  13. enum class CommOp {
  14. SUM,
  15. MIN,
  16. MAX
  17. };
  18. Comm();
  19. #ifdef SCTL_HAVE_MPI
  20. Comm(const MPI_Comm mpi_comm) { Init(mpi_comm); }
  21. #endif
  22. Comm(const Comm& c);
  23. static Comm Self();
  24. static Comm World();
  25. Comm& operator=(const Comm& c);
  26. ~Comm();
  27. #ifdef SCTL_HAVE_MPI
  28. MPI_Comm GetMPI_Comm() { return mpi_comm_; }
  29. #endif
  30. Comm Split(Integer clr) const;
  31. Integer Rank() const;
  32. Integer Size() const;
  33. void Barrier() const;
  34. template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
  35. template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
  36. void Wait(void* req_ptr) const;
  37. template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  38. template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  39. template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  40. template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
  41. template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  42. template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
  43. template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
  44. template <class Type> void PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_ = nullptr) const;
  45. template <class Type> void PartitionN(Vector<Type>& v, Long N) const;
  46. template <class Type> void PartitionS(Vector<Type>& nodeList, const Type& splitter) const;
  47. template <class Type> void HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const;
  48. template <class Type> void SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_ = nullptr) const;
  49. template <class Type> void ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const;
  50. template <class Type> void ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_ = 0) const;
  51. private:
  52. template <typename A, typename B> struct SortPair {
  53. int operator<(const SortPair<A, B>& p1) const { return key < p1.key; }
  54. A key;
  55. B data;
  56. };
  57. #ifdef SCTL_HAVE_MPI
  58. void Init(const MPI_Comm mpi_comm) {
  59. #pragma omp critical(SCTL_COMM_DUP)
  60. MPI_Comm_dup(mpi_comm, &mpi_comm_);
  61. MPI_Comm_rank(mpi_comm_, &mpi_rank_);
  62. MPI_Comm_size(mpi_comm_, &mpi_size_);
  63. }
  64. Vector<MPI_Request>* NewReq() const;
  65. void DelReq(Vector<MPI_Request>* req_ptr) const;
  66. mutable std::stack<void*> req;
  67. int mpi_rank_;
  68. int mpi_size_;
  69. MPI_Comm mpi_comm_;
  70. /**
  71. * \class CommDatatype
  72. * \brief An abstract class used for communicating messages using user-defined
  73. * datatypes. The user must implement the static member function "value()" that
  74. * returns the MPI_Datatype corresponding to this user-defined datatype.
  75. * \author Hari Sundar, hsundar@gmail.com
  76. */
  77. template <class Type> class CommDatatype {
  78. public:
  79. static MPI_Datatype value() {
  80. static bool first = true;
  81. static MPI_Datatype datatype;
  82. if (first) {
  83. first = false;
  84. MPI_Type_contiguous(sizeof(Type), MPI_BYTE, &datatype);
  85. MPI_Type_commit(&datatype);
  86. }
  87. return datatype;
  88. }
  89. static MPI_Op sum() {
  90. static bool first = true;
  91. static MPI_Op myop;
  92. if (first) {
  93. first = false;
  94. int commune = 1;
  95. MPI_Op_create(sum_fn, commune, &myop);
  96. }
  97. return myop;
  98. }
  99. static MPI_Op min() {
  100. static bool first = true;
  101. static MPI_Op myop;
  102. if (first) {
  103. first = false;
  104. int commune = 1;
  105. MPI_Op_create(min_fn, commune, &myop);
  106. }
  107. return myop;
  108. }
  109. static MPI_Op max() {
  110. static bool first = true;
  111. static MPI_Op myop;
  112. if (first) {
  113. first = false;
  114. int commune = 1;
  115. MPI_Op_create(max_fn, commune, &myop);
  116. }
  117. return myop;
  118. }
  119. private:
  120. static void sum_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  121. Type* a = (Type*)a_;
  122. Type* b = (Type*)b_;
  123. int len = *len_;
  124. for (int i = 0; i < len; i++) {
  125. b[i] = a[i] + b[i];
  126. }
  127. }
  128. static void min_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  129. Type* a = (Type*)a_;
  130. Type* b = (Type*)b_;
  131. int len = *len_;
  132. for (int i = 0; i < len; i++) {
  133. if (a[i] < b[i]) b[i] = a[i];
  134. }
  135. }
  136. static void max_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  137. Type* a = (Type*)a_;
  138. Type* b = (Type*)b_;
  139. int len = *len_;
  140. for (int i = 0; i < len; i++) {
  141. if (a[i] > b[i]) b[i] = a[i];
  142. }
  143. }
  144. };
  145. #else
  146. mutable std::multimap<Integer, ConstIterator<char>> send_req;
  147. mutable std::multimap<Integer, Iterator<char>> recv_req;
  148. #endif
  149. };
  150. } // end namespace
  151. #include SCTL_INCLUDE(comm.txx)
  152. #endif //_SCTL_COMM_HPP_