// comm.hpp (5.5 KB)

  1. #ifndef _SCTL_COMM_HPP_
  2. #define _SCTL_COMM_HPP_
  3. #include SCTL_INCLUDE(common.hpp)
  4. #include <map>
  5. #include <stack>
  6. #ifdef SCTL_HAVE_MPI
  7. #include <mpi.h>
  8. #endif
  9. namespace SCTL_NAMESPACE {
  10. template <class ValueType> class Vector;
  11. class Comm {
  12. public:
  13. enum class CommOp {
  14. SUM,
  15. MIN,
  16. MAX
  17. };
  18. Comm();
  19. #ifdef SCTL_HAVE_MPI
  20. Comm(const MPI_Comm mpi_comm) { Init(mpi_comm); }
  21. #endif
  22. Comm(const Comm& c);
  23. static Comm Self();
  24. static Comm World();
  25. Comm& operator=(const Comm& c);
  26. ~Comm();
  27. Comm Split(Integer clr) const;
  28. Integer Rank() const;
  29. Integer Size() const;
  30. void Barrier() const;
  31. template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
  32. template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
  33. void Wait(void* req_ptr) const;
  34. template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  35. template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  36. template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  37. template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
  38. template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  39. template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
  40. template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
  41. template <class Type> void PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_ = nullptr) const;
  42. template <class Type> void PartitionN(Vector<Type>& v, Long N) const;
  43. template <class Type> void PartitionS(Vector<Type>& nodeList, const Type& splitter) const;
  44. template <class Type> void HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const;
  45. template <class Type> void SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_ = nullptr) const;
  46. template <class Type> void ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const;
  47. template <class Type> void ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_ = 0) const;
  48. private:
  49. template <typename A, typename B> struct SortPair {
  50. int operator<(const SortPair<A, B>& p1) const { return key < p1.key; }
  51. A key;
  52. B data;
  53. };
  54. #ifdef SCTL_HAVE_MPI
  55. void Init(const MPI_Comm mpi_comm) {
  56. MPI_Comm_dup(mpi_comm, &mpi_comm_);
  57. MPI_Comm_rank(mpi_comm_, &mpi_rank_);
  58. MPI_Comm_size(mpi_comm_, &mpi_size_);
  59. }
  60. Vector<MPI_Request>* NewReq() const;
  61. void DelReq(Vector<MPI_Request>* req_ptr) const;
  62. mutable std::stack<void*> req;
  63. int mpi_rank_;
  64. int mpi_size_;
  65. MPI_Comm mpi_comm_;
  66. /**
  67. * \class CommDatatype
  68. * \brief An abstract class used for communicating messages using user-defined
  69. * datatypes. The user must implement the static member function "value()" that
  70. * returns the MPI_Datatype corresponding to this user-defined datatype.
  71. * \author Hari Sundar, hsundar@gmail.com
  72. */
  73. template <class Type> class CommDatatype {
  74. public:
  75. static MPI_Datatype value() {
  76. static bool first = true;
  77. static MPI_Datatype datatype;
  78. if (first) {
  79. first = false;
  80. MPI_Type_contiguous(sizeof(Type), MPI_BYTE, &datatype);
  81. MPI_Type_commit(&datatype);
  82. }
  83. return datatype;
  84. }
  85. static MPI_Op sum() {
  86. static bool first = true;
  87. static MPI_Op myop;
  88. if (first) {
  89. first = false;
  90. int commune = 1;
  91. MPI_Op_create(sum_fn, commune, &myop);
  92. }
  93. return myop;
  94. }
  95. static MPI_Op min() {
  96. static bool first = true;
  97. static MPI_Op myop;
  98. if (first) {
  99. first = false;
  100. int commune = 1;
  101. MPI_Op_create(min_fn, commune, &myop);
  102. }
  103. return myop;
  104. }
  105. static MPI_Op max() {
  106. static bool first = true;
  107. static MPI_Op myop;
  108. if (first) {
  109. first = false;
  110. int commune = 1;
  111. MPI_Op_create(max_fn, commune, &myop);
  112. }
  113. return myop;
  114. }
  115. private:
  116. static void sum_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  117. Type* a = (Type*)a_;
  118. Type* b = (Type*)b_;
  119. int len = *len_;
  120. for (int i = 0; i < len; i++) {
  121. b[i] = a[i] + b[i];
  122. }
  123. }
  124. static void min_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  125. Type* a = (Type*)a_;
  126. Type* b = (Type*)b_;
  127. int len = *len_;
  128. for (int i = 0; i < len; i++) {
  129. if (a[i] < b[i]) b[i] = a[i];
  130. }
  131. }
  132. static void max_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  133. Type* a = (Type*)a_;
  134. Type* b = (Type*)b_;
  135. int len = *len_;
  136. for (int i = 0; i < len; i++) {
  137. if (a[i] > b[i]) b[i] = a[i];
  138. }
  139. }
  140. };
  141. #else
  142. mutable std::multimap<Integer, ConstIterator<char>> send_req;
  143. mutable std::multimap<Integer, Iterator<char>> recv_req;
  144. #endif
  145. };
  146. } // end namespace
  147. #include SCTL_INCLUDE(comm.txx)
  148. #endif //_SCTL_COMM_HPP_