comm.hpp 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. #ifndef _SCTL_COMM_HPP_
  2. #define _SCTL_COMM_HPP_
  3. #include SCTL_INCLUDE(common.hpp)
  4. #include <map>
  5. #include <stack>
  6. #ifdef SCTL_HAVE_MPI
  7. #include <mpi.h>
  8. #endif
  9. namespace SCTL_NAMESPACE {
  10. template <class ValueType> class Vector;
  11. class Comm {
  12. public:
  13. enum class CommOp {
  14. SUM,
  15. MIN,
  16. MAX
  17. };
  18. Comm();
  19. #ifdef SCTL_HAVE_MPI
  20. Comm(const MPI_Comm mpi_comm) { Init(mpi_comm); }
  21. #endif
  22. Comm(const Comm& c);
  23. static Comm Self();
  24. static Comm World();
  25. Comm& operator=(const Comm& c);
  26. ~Comm();
  27. #ifdef SCTL_HAVE_MPI
  28. MPI_Comm GetMPI_Comm() { return mpi_comm_; }
  29. #endif
  30. Comm Split(Integer clr) const;
  31. Integer Rank() const;
  32. Integer Size() const;
  33. void Barrier() const;
  34. template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
  35. template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
  36. void Wait(void* req_ptr) const;
  37. template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  38. template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  39. template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  40. template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
  41. template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  42. template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
  43. template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
  44. template <class Type> void PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_ = nullptr) const;
  45. template <class Type> void PartitionN(Vector<Type>& v, Long N) const;
  46. template <class Type> void PartitionS(Vector<Type>& nodeList, const Type& splitter) const;
  47. template <class Type> void HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const;
  48. template <class Type> void SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_ = nullptr) const;
  49. template <class Type> void ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const;
  50. template <class Type> void ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_ = 0) const;
  51. private:
  52. template <typename A, typename B> struct SortPair {
  53. int operator<(const SortPair<A, B>& p1) const { return key < p1.key; }
  54. A key;
  55. B data;
  56. };
  57. #ifdef SCTL_HAVE_MPI
  58. void Init(const MPI_Comm mpi_comm);
  59. Vector<MPI_Request>* NewReq() const;
  60. void DelReq(Vector<MPI_Request>* req_ptr) const;
  61. mutable std::stack<void*> req;
  62. int mpi_rank_;
  63. int mpi_size_;
  64. MPI_Comm mpi_comm_;
  65. /**
  66. * \class CommDatatype
  67. * \brief An abstract class used for communicating messages using user-defined
  68. * datatypes. The user must implement the static member function "value()" that
  69. * returns the MPI_Datatype corresponding to this user-defined datatype.
  70. * \author Hari Sundar, hsundar@gmail.com
  71. */
  72. template <class Type> class CommDatatype {
  73. public:
  74. static MPI_Datatype value() {
  75. static bool first = true;
  76. static MPI_Datatype datatype;
  77. if (first) {
  78. first = false;
  79. MPI_Type_contiguous(sizeof(Type), MPI_BYTE, &datatype);
  80. MPI_Type_commit(&datatype);
  81. }
  82. return datatype;
  83. }
  84. static MPI_Op sum() {
  85. static bool first = true;
  86. static MPI_Op myop;
  87. if (first) {
  88. first = false;
  89. int commune = 1;
  90. MPI_Op_create(sum_fn, commune, &myop);
  91. }
  92. return myop;
  93. }
  94. static MPI_Op min() {
  95. static bool first = true;
  96. static MPI_Op myop;
  97. if (first) {
  98. first = false;
  99. int commune = 1;
  100. MPI_Op_create(min_fn, commune, &myop);
  101. }
  102. return myop;
  103. }
  104. static MPI_Op max() {
  105. static bool first = true;
  106. static MPI_Op myop;
  107. if (first) {
  108. first = false;
  109. int commune = 1;
  110. MPI_Op_create(max_fn, commune, &myop);
  111. }
  112. return myop;
  113. }
  114. private:
  115. static void sum_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  116. Type* a = (Type*)a_;
  117. Type* b = (Type*)b_;
  118. int len = *len_;
  119. for (int i = 0; i < len; i++) {
  120. b[i] = a[i] + b[i];
  121. }
  122. }
  123. static void min_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  124. Type* a = (Type*)a_;
  125. Type* b = (Type*)b_;
  126. int len = *len_;
  127. for (int i = 0; i < len; i++) {
  128. if (a[i] < b[i]) b[i] = a[i];
  129. }
  130. }
  131. static void max_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  132. Type* a = (Type*)a_;
  133. Type* b = (Type*)b_;
  134. int len = *len_;
  135. for (int i = 0; i < len; i++) {
  136. if (a[i] > b[i]) b[i] = a[i];
  137. }
  138. }
  139. };
  140. #else
  141. mutable std::multimap<Integer, ConstIterator<char>> send_req;
  142. mutable std::multimap<Integer, Iterator<char>> recv_req;
  143. #endif
  144. };
  145. } // end namespace
  146. #include SCTL_INCLUDE(comm.txx)
  147. #endif //_SCTL_COMM_HPP_