#ifndef _SCTL_COMM_HPP_
#define _SCTL_COMM_HPP_

// NOTE(review): in this copy of the header, the targets of the #include
// directives below and every template parameter/argument list appear to have
// been stripped (likely lost `<...>` text). Examples:
//   - the bare `template class Vector;` and `template void* Isend(...)`;
//   - `std::stack req` and `std::multimap> send_req`.
// The header cannot compile as shown; restore the missing `<...>` text from the
// upstream source before building.
#include
#include
#include
#ifdef SCTL_HAVE_MPI
#include
#endif
#ifdef SCTL_HAVE_PETSC
#include
#endif

namespace SCTL_NAMESPACE {

// Forward declaration of the library's vector container (defined elsewhere).
template class Vector;

/**
 * \brief Communicator wrapper used by the library's parallel routines.
 *
 * When SCTL_HAVE_MPI is defined, the object stores an MPI_Comm together with
 * its cached rank and size.
 *
 * Otherwise, the only private state is the send_req / recv_req multimaps.
 * These presumably buffer point-to-point messages for a single-process build;
 * confirm in comm.txx.
 *
 * Most member functions are only declared here. Their definitions presumably
 * live in comm.txx, which is included at the end of this header.
 */
class Comm {
 public:
  /// Reduction operation selector used by Allreduce() and Scan().
  enum class CommOp { SUM, MIN, MAX };

  /**
   * \brief Initialize the parallel runtime.
   *
   * With SCTL_HAVE_PETSC defined, this calls PetscInitialize(argc, argv, NULL, NULL).
   * Otherwise, with SCTL_HAVE_MPI defined, it calls ::MPI_Init(argc, argv).
   * With neither macro defined, it is a no-op.
   * Return codes from the underlying calls are ignored.
   *
   * \param argc Pointer to the program's argument count, forwarded unchanged.
   * \param argv Pointer to the program's argument vector, forwarded unchanged.
   */
  static void MPI_Init(int* argc, char*** argv) {
#ifdef SCTL_HAVE_PETSC
    PetscInitialize(argc, argv, NULL, NULL);
#elif defined(SCTL_HAVE_MPI)
    ::MPI_Init(argc, argv);
#endif
  }

  /**
   * \brief Finalize the parallel runtime.
   *
   * Calls PetscFinalize() when SCTL_HAVE_PETSC is defined, otherwise
   * ::MPI_Finalize() when SCTL_HAVE_MPI is defined; otherwise it is a no-op.
   * Return codes are ignored.
   */
  static void MPI_Finalize() {
#ifdef SCTL_HAVE_PETSC
    PetscFinalize();
#elif defined(SCTL_HAVE_MPI)
    ::MPI_Finalize();
#endif
  }

  Comm();
#ifdef SCTL_HAVE_MPI
  /// Wrap an existing MPI communicator; all setup is delegated to the private Init().
  explicit Comm(const MPI_Comm mpi_comm) { Init(mpi_comm); }
#endif
  Comm(const Comm& c);
  /// Communicator containing only the calling process (presumably built from MPI_COMM_SELF).
  static Comm Self();
  /// Communicator containing all processes (presumably built from MPI_COMM_WORLD).
  static Comm World();
  Comm& operator=(const Comm& c);
  ~Comm();

#ifdef SCTL_HAVE_MPI
  /// Direct access to the wrapped MPI communicator handle.
  MPI_Comm& GetMPI_Comm() { return mpi_comm_; }
  const MPI_Comm& GetMPI_Comm() const { return mpi_comm_; }
#endif

  /// Split into sub-communicators by color `clr` (presumably MPI_Comm_split-like; confirm in comm.txx).
  Comm Split(Integer clr) const;
  Integer Rank() const;
  Integer Size() const;
  void Barrier() const;

  // Non-blocking point-to-point messages. The returned void* is presumably an
  // opaque request handle that must later be passed to Wait().
  template void* Isend(ConstIterator sbuf, Long scount, Integer dest, Integer tag = 0) const;
  template void* Irecv(Iterator rbuf, Long rcount, Integer source, Integer tag = 0) const;
  void Wait(void* req_ptr) const;

  // Collective operations.
  // NOTE(review): Bcast takes `Long root`, although ranks are Integer elsewhere
  // in this interface.
  template void Bcast(Iterator buf, Long count, Long root) const;
  template void Allgather(ConstIterator sbuf, Long scount, Iterator rbuf, Long rcount) const;
  template void Allgatherv(ConstIterator sbuf, Long scount, Iterator rbuf, ConstIterator rcounts, ConstIterator rdispls) const;
  template void Alltoall(ConstIterator sbuf, Long scount, Iterator rbuf, Long rcount) const;
  // Non-blocking all-to-all-v; returns a handle (presumably for Wait()). The
  // "_sparse" name suggests it targets mostly-empty count arrays; confirm in comm.txx.
  template void* Ialltoallv_sparse(ConstIterator sbuf, ConstIterator scounts, ConstIterator sdispls, Iterator rbuf, ConstIterator rcounts, ConstIterator rdispls, Integer tag = 0) const;
  template void Alltoallv(ConstIterator sbuf, ConstIterator scounts, ConstIterator sdispls, Iterator rbuf, ConstIterator rcounts, ConstIterator rdispls) const;
  /// Element-wise reduction of `count` entries using `op` (SUM, MIN or MAX).
  template void Allreduce(ConstIterator sbuf, Iterator rbuf, Long count, CommOp op) const;
  /// Prefix reduction using `op`.
  /// NOTE(review): `count` is `int` here, while the other collectives take `Long`.
  template void Scan(ConstIterator sbuf, Iterator rbuf, int count, CommOp op) const;

  // Distributed partitioning of a Vector across the communicator.
  // PartitionW presumably balances by the optional weights `wts_`
  // (nullptr = default weights).
  template void PartitionW(Vector& nodeList, const Vector* wts_ = nullptr) const;
  template void PartitionN(Vector& v, Long N) const;
  template void PartitionS(Vector& nodeList, const Type& splitter, Compare comp) const;
  /// PartitionS overload that uses std::less as the comparator.
  template void PartitionS(Vector& nodeList, const Type& splitter) const { PartitionS(nodeList, splitter, std::less()); }

  /// Sort arr_ into SortedElem using `comp` (presumably a distributed hyper-quicksort).
  template void HyperQuickSort(const Vector& arr_, Vector& SortedElem, Compare comp) const;
  /// HyperQuickSort overload that uses std::less as the comparator.
  template void HyperQuickSort(const Vector& arr_, Vector& SortedElem) const { HyperQuickSort(arr_, SortedElem, std::less()); }

  // Scatter-index helpers: compute an index from `key`, then move data forward
  // or in reverse with it (semantics defined in comm.txx).
  template void SortScatterIndex(const Vector& key, Vector& scatter_index, const Type* split_key_ = nullptr) const;
  template void ScatterForward(Vector& data_, const Vector& scatter_index) const;
  template void ScatterReverse(Vector& data_, const Vector& scatter_index_, Long loc_size_ = 0) const;

 private:
  /// Key/data pair ordered by `key` only; `data` does not take part in comparisons.
  /// NOTE(review): operator< returns int rather than bool.
  template struct SortPair {
    int operator<(const SortPair& p1) const { return key < p1.key; }
    A key;
    B data;
  };

#ifdef SCTL_HAVE_MPI
  void Init(const MPI_Comm mpi_comm);
  // Allocation and release of the buffers referenced by request handles. These
  // members are const and `req` is mutable, so the const communication methods
  // can use them; `req` presumably pools released buffers for reuse.
  Vector* NewReq() const;
  void DelReq(Vector* req_ptr) const;
  mutable std::stack req;
  int mpi_rank_;
  int mpi_size_;
  MPI_Comm mpi_comm_;

  /**
   * \class CommDatatype
   * \brief Supplies an MPI datatype and MPI reduction operations for an arbitrary Type.
   *
   * value() creates and commits, on its first call, a contiguous MPI datatype
   * of sizeof(Type) MPI_BYTEs. Type is therefore transferred as raw bytes.
   *
   * sum(), min() and max() each create, on their first call, a commutative
   * MPI_Op. These combine elements with Type's operator+, operator< and
   * operator>, respectively.
   *
   * NOTE(review): lazy initialization relies on an unsynchronized
   * `static bool first` flag, so concurrent first calls from several threads
   * would race. The created datatype and ops are never freed in this header.
   *
   * \author Hari Sundar, hsundar@gmail.com
   */
  template class CommDatatype {
   public:
    /// Committed byte-wise MPI datatype for Type, created on the first call.
    static MPI_Datatype value() {
      static bool first = true;
      static MPI_Datatype datatype;
      if (first) {
        first = false;
        MPI_Type_contiguous(sizeof(Type), MPI_BYTE, &datatype);
        MPI_Type_commit(&datatype);
      }
      return datatype;
    }
    /// Commutative MPI_Op computing element-wise a + b, created on the first call.
    static MPI_Op sum() {
      static bool first = true;
      static MPI_Op myop;
      if (first) {
        first = false;
        int commune = 1;  // MPI_Op_create's `commute` flag: 1 declares the op commutative
        MPI_Op_create(sum_fn, commune, &myop);
      }
      return myop;
    }
    /// Commutative MPI_Op computing the element-wise minimum, created on the first call.
    static MPI_Op min() {
      static bool first = true;
      static MPI_Op myop;
      if (first) {
        first = false;
        int commune = 1;  // MPI_Op_create's `commute` flag: 1 declares the op commutative
        MPI_Op_create(min_fn, commune, &myop);
      }
      return myop;
    }
    /// Commutative MPI_Op computing the element-wise maximum, created on the first call.
    static MPI_Op max() {
      static bool first = true;
      static MPI_Op myop;
      if (first) {
        first = false;
        int commune = 1;  // MPI_Op_create's `commute` flag: 1 declares the op commutative
        MPI_Op_create(max_fn, commune, &myop);
      }
      return myop;
    }

   private:
    // MPI user-function callbacks. Each combines the *len_ elements of a_ into
    // b_ in place (b_ holds the result). The `datatype` argument is unused.
    static void sum_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
      Type* a = (Type*)a_;
      Type* b = (Type*)b_;
      int len = *len_;
      for (int i = 0; i < len; i++) {
        b[i] = a[i] + b[i];
      }
    }
    static void min_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
      Type* a = (Type*)a_;
      Type* b = (Type*)b_;
      int len = *len_;
      for (int i = 0; i < len; i++) {
        if (a[i] < b[i]) b[i] = a[i];
      }
    }
    static void max_fn(void* a_, void* b_, int* len_, MPI_Datatype* datatype) {
      Type* a = (Type*)a_;
      Type* b = (Type*)b_;
      int len = *len_;
      for (int i = 0; i < len; i++) {
        if (a[i] > b[i]) b[i] = a[i];
      }
    }
  };
#else
  // Non-MPI build. These presumably hold pending sends and receives so that
  // point-to-point calls can be matched within a single process; confirm in
  // comm.txx.
  mutable std::multimap> send_req;
  mutable std::multimap> recv_req;
#endif
};

}  // end namespace

// Out-of-line definitions of the templates and members declared above.
#include SCTL_INCLUDE(comm.txx)

#endif //_SCTL_COMM_HPP_