comm.hpp 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. #ifndef _SCTL_COMM_HPP_
  2. #define _SCTL_COMM_HPP_
  3. #include <sctl/common.hpp>
  4. #include <map>
  5. #include <stack>
  6. #ifdef SCTL_HAVE_MPI
  7. #include <mpi.h>
  8. #endif
  9. #ifdef SCTL_HAVE_PETSC
  10. #include <petscsys.h>
  11. #endif
  12. namespace SCTL_NAMESPACE {
  13. template <class ValueType> class Vector;
  14. class Comm {
  15. public:
  16. enum class CommOp {
  17. SUM,
  18. MIN,
  19. MAX
  20. };
  21. static void MPI_Init(int* argc, char*** argv) {
  22. #ifdef SCTL_HAVE_PETSC
  23. PetscInitialize(argc, argv, NULL, NULL);
  24. #elif defined(SCTL_HAVE_MPI)
  25. ::MPI_Init(argc, argv);
  26. #endif
  27. }
  28. static void MPI_Finalize() {
  29. #ifdef SCTL_HAVE_PETSC
  30. PetscFinalize();
  31. #elif defined(SCTL_HAVE_MPI)
  32. ::MPI_Finalize();
  33. #endif
  34. }
  35. Comm();
  36. #ifdef SCTL_HAVE_MPI
  37. explicit Comm(const MPI_Comm mpi_comm) { Init(mpi_comm); }
  38. #endif
  39. Comm(const Comm& c);
  40. static Comm Self();
  41. static Comm World();
  42. Comm& operator=(const Comm& c);
  43. ~Comm();
  44. #ifdef SCTL_HAVE_MPI
  45. MPI_Comm& GetMPI_Comm() { return mpi_comm_; }
  46. const MPI_Comm& GetMPI_Comm() const { return mpi_comm_; }
  47. #endif
  48. Comm Split(Integer clr) const;
  49. Integer Rank() const;
  50. Integer Size() const;
  51. void Barrier() const;
  52. template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
  53. template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
  54. void Wait(void* req_ptr) const;
  55. template <class Type> void Bcast(Iterator<Type> buf, Long count, Long root) const;
  56. template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  57. template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  58. template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  59. template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
  60. template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  61. template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
  62. template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
  63. template <class Type> void PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_ = nullptr) const;
  64. template <class Type> void PartitionN(Vector<Type>& v, Long N) const;
  65. template <class Type, class Compare> void PartitionS(Vector<Type>& nodeList, const Type& splitter, Compare comp) const;
  66. template <class Type> void PartitionS(Vector<Type>& nodeList, const Type& splitter) const {
  67. PartitionS(nodeList, splitter, std::less<Type>());
  68. }
  69. template <class Type, class Compare> void HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem, Compare comp) const;
  70. template <class Type> void HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const {
  71. HyperQuickSort(arr_, SortedElem, std::less<Type>());
  72. }
  73. template <class Type> void SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_ = nullptr) const;
  74. template <class Type> void ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const;
  75. template <class Type> void ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_ = 0) const;
  76. private:
  77. template <typename A, typename B> struct SortPair {
  78. int operator<(const SortPair<A, B>& p1) const { return key < p1.key; }
  79. A key;
  80. B data;
  81. };
  82. #ifdef SCTL_HAVE_MPI
  83. void Init(const MPI_Comm mpi_comm);
  84. Vector<MPI_Request>* NewReq() const;
  85. void DelReq(Vector<MPI_Request>* req_ptr) const;
  86. mutable std::stack<void*> req;
  87. int mpi_rank_;
  88. int mpi_size_;
  89. MPI_Comm mpi_comm_;
  90. /**
  91. * \class CommDatatype
  92. * \brief An abstract class used for communicating messages using user-defined
  93. * datatypes. The user must implement the static member function "value()" that
  94. * returns the MPI_Datatype corresponding to this user-defined datatype.
  95. * \author Hari Sundar, hsundar@gmail.com
  96. */
  97. template <class Type> class CommDatatype {
  98. public:
  99. static MPI_Datatype value() {
  100. static bool first = true;
  101. static MPI_Datatype datatype;
  102. if (first) {
  103. first = false;
  104. MPI_Type_contiguous(sizeof(Type), MPI_BYTE, &datatype);
  105. MPI_Type_commit(&datatype);
  106. }
  107. return datatype;
  108. }
  109. static MPI_Op sum() {
  110. static bool first = true;
  111. static MPI_Op myop;
  112. if (first) {
  113. first = false;
  114. int commune = 1;
  115. MPI_Op_create(sum_fn, commune, &myop);
  116. }
  117. return myop;
  118. }
  119. static MPI_Op min() {
  120. static bool first = true;
  121. static MPI_Op myop;
  122. if (first) {
  123. first = false;
  124. int commune = 1;
  125. MPI_Op_create(min_fn, commune, &myop);
  126. }
  127. return myop;
  128. }
  129. static MPI_Op max() {
  130. static bool first = true;
  131. static MPI_Op myop;
  132. if (first) {
  133. first = false;
  134. int commune = 1;
  135. MPI_Op_create(max_fn, commune, &myop);
  136. }
  137. return myop;
  138. }
  139. private:
  140. static void sum_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  141. Type* a = (Type*)a_;
  142. Type* b = (Type*)b_;
  143. int len = *len_;
  144. for (int i = 0; i < len; i++) {
  145. b[i] = a[i] + b[i];
  146. }
  147. }
  148. static void min_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  149. Type* a = (Type*)a_;
  150. Type* b = (Type*)b_;
  151. int len = *len_;
  152. for (int i = 0; i < len; i++) {
  153. if (a[i] < b[i]) b[i] = a[i];
  154. }
  155. }
  156. static void max_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  157. Type* a = (Type*)a_;
  158. Type* b = (Type*)b_;
  159. int len = *len_;
  160. for (int i = 0; i < len; i++) {
  161. if (a[i] > b[i]) b[i] = a[i];
  162. }
  163. }
  164. };
  165. #else
  166. mutable std::multimap<Integer, ConstIterator<char>> send_req;
  167. mutable std::multimap<Integer, Iterator<char>> recv_req;
  168. #endif
  169. };
  170. } // end namespace
  171. #include SCTL_INCLUDE(comm.txx)
  172. #endif //_SCTL_COMM_HPP_