comm.hpp 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. #ifndef _SCTL_COMM_HPP_
  2. #define _SCTL_COMM_HPP_
  3. #include SCTL_INCLUDE(common.hpp)
  4. #include <map>
  5. #include <stack>
  6. #ifdef SCTL_HAVE_MPI
  7. #include <mpi.h>
  8. #endif
  9. namespace SCTL_NAMESPACE {
  10. template <class ValueType> class Vector;
  11. class Comm {
  12. public:
  13. enum class CommOp {
  14. SUM,
  15. MIN,
  16. MAX
  17. };
  18. static void MPI_Init(int* argc, char*** argv) {
  19. #ifdef SCTL_HAVE_PETSC
  20. PetscInitialize(argc, argv, NULL, NULL);
  21. #elif defined(SCTL_HAVE_MPI)
  22. ::MPI_Init(argc, argv);
  23. #endif
  24. }
  25. static void MPI_Finalize() {
  26. #ifdef SCTL_HAVE_PETSC
  27. PetscFinalize();
  28. #elif defined(SCTL_HAVE_MPI)
  29. ::MPI_Finalize();
  30. #endif
  31. }
  32. Comm();
  33. #ifdef SCTL_HAVE_MPI
  34. Comm(const MPI_Comm mpi_comm) { Init(mpi_comm); }
  35. #endif
  36. Comm(const Comm& c);
  37. static Comm Self();
  38. static Comm World();
  39. Comm& operator=(const Comm& c);
  40. ~Comm();
  41. #ifdef SCTL_HAVE_MPI
  42. MPI_Comm GetMPI_Comm() { return mpi_comm_; }
  43. #endif
  44. Comm Split(Integer clr) const;
  45. Integer Rank() const;
  46. Integer Size() const;
  47. void Barrier() const;
  48. template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
  49. template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
  50. void Wait(void* req_ptr) const;
  51. template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  52. template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  53. template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  54. template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
  55. template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  56. template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
  57. template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
  58. template <class Type> void PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_ = nullptr) const;
  59. template <class Type> void PartitionN(Vector<Type>& v, Long N) const;
  60. template <class Type> void PartitionS(Vector<Type>& nodeList, const Type& splitter) const;
  61. template <class Type> void HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const;
  62. template <class Type> void SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_ = nullptr) const;
  63. template <class Type> void ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const;
  64. template <class Type> void ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_ = 0) const;
  65. private:
  66. template <typename A, typename B> struct SortPair {
  67. int operator<(const SortPair<A, B>& p1) const { return key < p1.key; }
  68. A key;
  69. B data;
  70. };
  71. #ifdef SCTL_HAVE_MPI
  72. void Init(const MPI_Comm mpi_comm);
  73. Vector<MPI_Request>* NewReq() const;
  74. void DelReq(Vector<MPI_Request>* req_ptr) const;
  75. mutable std::stack<void*> req;
  76. int mpi_rank_;
  77. int mpi_size_;
  78. MPI_Comm mpi_comm_;
  79. /**
  80. * \class CommDatatype
  81. * \brief An abstract class used for communicating messages using user-defined
  82. * datatypes. The user must implement the static member function "value()" that
  83. * returns the MPI_Datatype corresponding to this user-defined datatype.
  84. * \author Hari Sundar, hsundar@gmail.com
  85. */
  86. template <class Type> class CommDatatype {
  87. public:
  88. static MPI_Datatype value() {
  89. static bool first = true;
  90. static MPI_Datatype datatype;
  91. if (first) {
  92. first = false;
  93. MPI_Type_contiguous(sizeof(Type), MPI_BYTE, &datatype);
  94. MPI_Type_commit(&datatype);
  95. }
  96. return datatype;
  97. }
  98. static MPI_Op sum() {
  99. static bool first = true;
  100. static MPI_Op myop;
  101. if (first) {
  102. first = false;
  103. int commune = 1;
  104. MPI_Op_create(sum_fn, commune, &myop);
  105. }
  106. return myop;
  107. }
  108. static MPI_Op min() {
  109. static bool first = true;
  110. static MPI_Op myop;
  111. if (first) {
  112. first = false;
  113. int commune = 1;
  114. MPI_Op_create(min_fn, commune, &myop);
  115. }
  116. return myop;
  117. }
  118. static MPI_Op max() {
  119. static bool first = true;
  120. static MPI_Op myop;
  121. if (first) {
  122. first = false;
  123. int commune = 1;
  124. MPI_Op_create(max_fn, commune, &myop);
  125. }
  126. return myop;
  127. }
  128. private:
  129. static void sum_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  130. Type* a = (Type*)a_;
  131. Type* b = (Type*)b_;
  132. int len = *len_;
  133. for (int i = 0; i < len; i++) {
  134. b[i] = a[i] + b[i];
  135. }
  136. }
  137. static void min_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  138. Type* a = (Type*)a_;
  139. Type* b = (Type*)b_;
  140. int len = *len_;
  141. for (int i = 0; i < len; i++) {
  142. if (a[i] < b[i]) b[i] = a[i];
  143. }
  144. }
  145. static void max_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
  146. Type* a = (Type*)a_;
  147. Type* b = (Type*)b_;
  148. int len = *len_;
  149. for (int i = 0; i < len; i++) {
  150. if (a[i] > b[i]) b[i] = a[i];
  151. }
  152. }
  153. };
  154. #else
  155. mutable std::multimap<Integer, ConstIterator<char>> send_req;
  156. mutable std::multimap<Integer, Iterator<char>> recv_req;
  157. #endif
  158. };
  159. } // end namespace
  160. #include SCTL_INCLUDE(comm.txx)
  161. #endif //_SCTL_COMM_HPP_