#include <type_traits>

#include SCTL_INCLUDE(ompUtils.hpp)
#include SCTL_INCLUDE(vector.hpp)

namespace SCTL_NAMESPACE {

inline Comm::Comm() {
#ifdef SCTL_HAVE_MPI
  Init(MPI_COMM_SELF);
#endif
}

inline Comm::Comm(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  Init(c.mpi_comm_);
#endif
}

inline Comm Comm::Self() {
#ifdef SCTL_HAVE_MPI
  Comm comm_self(MPI_COMM_SELF);
  return comm_self;
#else
  Comm comm_self;
  return comm_self;
#endif
}

inline Comm Comm::World() {
#ifdef SCTL_HAVE_MPI
  Comm comm_world(MPI_COMM_WORLD);
  return comm_world;
#else
  Comm comm_self;
  return comm_self;
#endif
}
inline Comm& Comm::operator=(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  if (this != &c) {  // guard against self-assignment before freeing the old communicator
    MPI_Comm_free(&mpi_comm_);
    Init(c.mpi_comm_);
  }
#endif
  return *this;
}
inline Comm::~Comm() {
#ifdef SCTL_HAVE_MPI
  while (!req.empty()) {
    delete (Vector<MPI_Request>*)req.top();
    req.pop();
  }
  MPI_Comm_free(&mpi_comm_);
#endif
}

inline Comm Comm::Split(Integer clr) const {
#ifdef SCTL_HAVE_MPI
  MPI_Comm new_comm;
  MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
  Comm c(new_comm);
  MPI_Comm_free(&new_comm);
  return c;
#else
  Comm c;
  return c;
#endif
}

inline Integer Comm::Rank() const {
#ifdef SCTL_HAVE_MPI
  return mpi_rank_;
#else
  return 0;
#endif
}

inline Integer Comm::Size() const {
#ifdef SCTL_HAVE_MPI
  return mpi_size_;
#else
  return 1;
#endif
}

inline void Comm::Barrier() const {
#ifdef SCTL_HAVE_MPI
  MPI_Barrier(mpi_comm_);
#endif
}
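
// Usage sketch (illustrative only; `world` and `half` are hypothetical names):
// basic communicator handling with the wrappers above. Comm duplicates the
// underlying MPI communicator in Init(), so each wrapper owns an independent handle.
//
//   Comm world = Comm::World();
//   Integer rank = world.Rank(), np = world.Size();
//   Comm half = world.Split(rank < np / 2 ? 0 : 1);  // two sub-communicators by color
//   half.Barrier();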
template <class SType> void* Comm::Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!scount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);

  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[scount - 1]);
#ifndef NDEBUG
  MPI_Issend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#else
  MPI_Isend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#endif
  return &request;
#else
  auto it = recv_req.find(tag);
  if (it == recv_req.end())
    send_req.insert(std::pair<Integer, ConstIterator<char>>(tag, (ConstIterator<char>)sbuf));
  else
    memcopy(it->second, (ConstIterator<char>)sbuf, scount * sizeof(SType));
  return nullptr;
#endif
}

template <class RType> void* Comm::Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag) const {
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!rcount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);

  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[rcount - 1]);
  MPI_Irecv(&rbuf[0], rcount, CommDatatype<RType>::value(), source, tag, mpi_comm_, &request[0]);
  return &request;
#else
  auto it = send_req.find(tag);
  if (it == send_req.end())
    recv_req.insert(std::pair<Integer, Iterator<char>>(tag, (Iterator<char>)rbuf));
  else
    memcopy((Iterator<char>)rbuf, it->second, rcount * sizeof(RType));
  return nullptr;
#endif
}

inline void Comm::Wait(void* req_ptr) const {
#ifdef SCTL_HAVE_MPI
  if (req_ptr == nullptr) return;
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req_ptr;
  // std::vector<MPI_Status> status(request.Dim());
  if (request.Dim()) MPI_Waitall(request.Dim(), &request[0], MPI_STATUSES_IGNORE);  //&status[0]);
  DelReq(&request);
#endif
}
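
// Usage sketch (illustrative only; buffer names, sizes, and tag are hypothetical):
// a non-blocking ring exchange built from Irecv/Isend, completed with Wait on the
// opaque request handles returned above.
//
//   Comm comm = Comm::World();
//   Vector<double> sbuf(100), rbuf(100);
//   Integer next = (comm.Rank() + 1) % comm.Size();
//   Integer prev = (comm.Rank() + comm.Size() - 1) % comm.Size();
//   void* recv_req = comm.Irecv(rbuf.begin(), rbuf.Dim(), prev, /*tag*/ 7);
//   void* send_req = comm.Isend(sbuf.begin(), sbuf.Dim(), next, /*tag*/ 7);
//   comm.Wait(recv_req);
//   comm.Wait(send_req);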
template <class SType, class RType> void Comm::Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
  MPI_Allgather((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Vector<int> rcounts_(mpi_size_), rdispls_(mpi_size_);
  Long rcount_sum = 0;
#pragma omp parallel for schedule(static) reduction(+ : rcount_sum)
  for (Integer i = 0; i < mpi_size_; i++) {
    rcounts_[i] = rcounts[i];
    rdispls_[i] = rdispls[i];
    rcount_sum += rcounts[i];
  }
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount_sum) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount_sum - 1]);
  }
  MPI_Allgatherv((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount_sum ? &rbuf[0] : nullptr), &rcounts_.begin()[0], &rdispls_.begin()[0], CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount * Size() - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
  MPI_Alltoall((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}
template <class SType, class RType> void* Comm::Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer request_count = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) request_count++;
    if (scounts[i]) request_count++;
  }
  if (!request_count) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(request_count);
  Integer request_iter = 0;

  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) {
      rbuf[rdispls[i]];                    // touch range end-points (bounds-check in debug iterators)
      rbuf[rdispls[i] + rcounts[i] - 1];
      MPI_Irecv(&rbuf[rdispls[i]], rcounts[i], CommDatatype<RType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  for (Integer i = 0; i < mpi_size_; i++) {
    if (scounts[i]) {
      sbuf[sdispls[i]];                    // touch range end-points (bounds-check in debug iterators)
      sbuf[sdispls[i] + scounts[i] - 1];
      MPI_Isend(&sbuf[sdispls[i]], scounts[i], CommDatatype<SType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  return &request;
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(SType));
  return nullptr;
#endif
}
template <class Type> void Comm::Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  { // Use Ialltoallv_sparse if average connectivity < 64
    Long connectivity = 0, glb_connectivity = 0;
#pragma omp parallel for schedule(static) reduction(+ : connectivity)
    for (Integer i = 0; i < mpi_size_; i++) {
      if (rcounts[i]) connectivity++;
    }
    Allreduce(Ptr2ConstItr<Long>(&connectivity, 1), Ptr2Itr<Long>(&glb_connectivity, 1), 1, CommOp::SUM);
    if (glb_connectivity < 64 * Size()) {
      void* mpi_req = Ialltoallv_sparse(sbuf, scounts, sdispls, rbuf, rcounts, rdispls, 0);
      Wait(mpi_req);
      return;
    }
  }
  { // Use vendor MPI_Alltoallv
    //#ifndef ALLTOALLV_FIX
    Vector<int> scnt, sdsp, rcnt, rdsp;
    scnt.ReInit(mpi_size_);
    sdsp.ReInit(mpi_size_);
    rcnt.ReInit(mpi_size_);
    rdsp.ReInit(mpi_size_);
    Long stotal = 0, rtotal = 0;
#pragma omp parallel for schedule(static) reduction(+ : stotal, rtotal)
    for (Integer i = 0; i < mpi_size_; i++) {
      scnt[i] = scounts[i];
      sdsp[i] = sdispls[i];
      rcnt[i] = rcounts[i];
      rdsp[i] = rdispls[i];
      stotal += scounts[i];
      rtotal += rcounts[i];
    }
    MPI_Alltoallv((stotal ? &sbuf[0] : nullptr), &scnt[0], &sdsp[0], CommDatatype<Type>::value(), (rtotal ? &rbuf[0] : nullptr), &rcnt[0], &rdsp[0], CommDatatype<Type>::value(), mpi_comm_);
    return;
    //#endif
  }
  // TODO: implement hypercube scheme
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(Type));
#endif
}
template <class Type> void Comm::Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Allreduce(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

template <class Type> void Comm::Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Scan(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}
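
// Usage sketch (illustrative only; `loc` is a hypothetical local value): a global
// sum and an inclusive prefix sum over one Long per rank, matching the pattern used
// by PartitionW below. Ptr2ConstItr/Ptr2Itr wrap raw pointers as SCTL iterators.
//
//   Comm comm = Comm::World();
//   Long loc = 42, glb = 0, scan = 0;
//   comm.Allreduce<Long>(Ptr2ConstItr<Long>(&loc, 1), Ptr2Itr<Long>(&glb, 1), 1, CommOp::SUM);
//   comm.Scan<Long>(Ptr2ConstItr<Long>(&loc, 1), Ptr2Itr<Long>(&scan, 1), 1, CommOp::SUM);
//   Long offset = scan - loc;  // exclusive prefix sum (this rank's starting offset)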
template <class Type> void Comm::PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;
  Long nlSize = nodeList.Dim();

  Vector<Long> wts;
  Long localWt = 0;
  if (wts_ == nullptr) { // Construct arrays of wts.
    wts.ReInit(nlSize);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < nlSize; i++) {
      wts[i] = 1;
    }
    localWt = nlSize;
  } else {
    wts.ReInit(nlSize, (Iterator<Long>)wts_->begin(), false);
#pragma omp parallel for reduction(+ : localWt)
    for (Long i = 0; i < nlSize; i++) {
      localWt += wts[i];
    }
  }

  Long off1 = 0, off2 = 0, totalWt = 0;
  { // compute the total weight of the problem ...
    Allreduce<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&totalWt, 1), 1, CommOp::SUM);
    Scan<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&off2, 1), 1, CommOp::SUM);
    off1 = off2 - localWt;
  }

  Vector<Long> lscn;
  if (nlSize) { // perform a local scan on the weights first ...
    lscn.ReInit(nlSize);
    lscn[0] = off1;
    omp_par::scan(wts.begin(), lscn.begin(), nlSize);
  }

  Vector<Long> sendSz, recvSz, sendOff, recvOff;
  sendSz.ReInit(npes);
  recvSz.ReInit(npes);
  sendOff.ReInit(npes);
  recvOff.ReInit(npes);
  sendSz.SetZero();

  if (nlSize > 0 && totalWt > 0) { // Compute sendSz
    Long pid1 = (off1 * npes) / totalWt;
    Long pid2 = ((off2 + 1) * npes) / totalWt + 1;
    assert((totalWt * pid2) / npes >= off2);
    pid1 = (pid1 < 0 ? 0 : pid1);
    pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
    for (Integer i = pid1; i < pid2; i++) {
      Long wt1 = (totalWt * (i)) / npes;
      Long wt2 = (totalWt * (i + 1)) / npes;
      Long start = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt1, std::less<Long>()) - lscn.begin();
      Long end = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt2, std::less<Long>()) - lscn.begin();
      if (i == 0) start = 0;
      if (i == npes - 1) end = nlSize;
      sendSz[i] = end - start;
    }
  } else {
    sendSz[0] = nlSize;
  }

  // Exchange sendSz, recvSz
  Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

  { // Compute sendOff, recvOff
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
  }

  // perform All2All ...
  Vector<Type> newNodes;
  newNodes.ReInit(recvSz[npes - 1] + recvOff[npes - 1]);
  void* mpi_req = Ialltoallv_sparse<Type>(nodeList.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
  Wait(mpi_req);

  // reset the pointer ...
  nodeList.Swap(newNodes);
}
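
// Usage sketch (illustrative only; `pts` and `w` are hypothetical and assumed to be
// filled elsewhere): repartition a distributed vector so that the total element
// weight is balanced across ranks; with null weights every element counts as 1.
//
//   Comm comm = Comm::World();
//   Vector<double> pts;                  // locally held elements
//   Vector<Long> w(pts.Dim());           // one weight per element
//   comm.PartitionW(pts, &w);            // balance total weight across ranks
//   // comm.PartitionW(pts);             // or: unit weights, balance element counts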
template <class Type> void Comm::PartitionN(Vector<Type>& v, Long N) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer rank = Rank();
  Integer np = Size();
  if (np == 1) return;

  Vector<Long> v_cnt(np), v_dsp(np + 1);
  Vector<Long> N_cnt(np), N_dsp(np + 1);
  { // Set v_cnt, v_dsp
    v_dsp[0] = 0;
    Long cnt = v.Dim();
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, v_cnt.begin(), 1);
    omp_par::scan(v_cnt.begin(), v_dsp.begin(), np);
    v_dsp[np] = v_cnt[np - 1] + v_dsp[np - 1];
  }
  { // Set N_cnt, N_dsp
    N_dsp[0] = 0;
    Long cnt = N;
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, N_cnt.begin(), 1);
    omp_par::scan(N_cnt.begin(), N_dsp.begin(), np);
    N_dsp[np] = N_cnt[np - 1] + N_dsp[np - 1];
  }
  { // Adjust for dof
    Long dof = (N_dsp[np] ? v_dsp[np] / N_dsp[np] : 0);
    assert(dof * N_dsp[np] == v_dsp[np]);
    if (dof == 0) return;
    if (dof != 1) {
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i < np; i++) N_cnt[i] *= dof;
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= np; i++) N_dsp[i] *= dof;
    }
  }

  Vector<Type> v_(N_cnt[rank]);
  { // Set v_
    Vector<Long> scnt(np), sdsp(np);
    Vector<Long> rcnt(np), rdsp(np);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < np; i++) {
      { // Set scnt
        Long n0 = N_dsp[i + 0];
        Long n1 = N_dsp[i + 1];
        if (n0 < v_dsp[rank + 0]) n0 = v_dsp[rank + 0];
        if (n1 < v_dsp[rank + 0]) n1 = v_dsp[rank + 0];
        if (n0 > v_dsp[rank + 1]) n0 = v_dsp[rank + 1];
        if (n1 > v_dsp[rank + 1]) n1 = v_dsp[rank + 1];
        scnt[i] = n1 - n0;
      }
      { // Set rcnt
        Long n0 = v_dsp[i + 0];
        Long n1 = v_dsp[i + 1];
        if (n0 < N_dsp[rank + 0]) n0 = N_dsp[rank + 0];
        if (n1 < N_dsp[rank + 0]) n1 = N_dsp[rank + 0];
        if (n0 > N_dsp[rank + 1]) n0 = N_dsp[rank + 1];
        if (n1 > N_dsp[rank + 1]) n1 = N_dsp[rank + 1];
        rcnt[i] = n1 - n0;
      }
    }
    sdsp[0] = 0;
    omp_par::scan(scnt.begin(), sdsp.begin(), np);
    rdsp[0] = 0;
    omp_par::scan(rcnt.begin(), rdsp.begin(), np);
    void* mpi_request = Ialltoallv_sparse(v.begin(), scnt.begin(), sdsp.begin(), v_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
  }
  v.Swap(v_);
}
template <class Type> void Comm::PartitionS(Vector<Type>& nodeList, const Type& splitter) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;

  Vector<Type> mins(npes);
  Allgather(Ptr2ConstItr<Type>(&splitter, 1), 1, mins.begin(), 1);

  Vector<Long> scnt(npes), sdsp(npes);
  Vector<Long> rcnt(npes), rdsp(npes);
  { // Compute scnt, sdsp
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sdsp[i] = std::lower_bound(nodeList.begin(), nodeList.begin() + nodeList.Dim(), mins[i]) - nodeList.begin();
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes - 1; i++) {
      scnt[i] = sdsp[i + 1] - sdsp[i];
    }
    scnt[npes - 1] = nodeList.Dim() - sdsp[npes - 1];
  }
  { // Compute rcnt, rdsp
    rdsp[0] = 0;
    Alltoall(scnt.begin(), 1, rcnt.begin(), 1);
    omp_par::scan(rcnt.begin(), rdsp.begin(), npes);
  }
  { // Redistribute nodeList
    Vector<Type> nodeList_(rdsp[npes - 1] + rcnt[npes - 1]);
    void* mpi_request = Ialltoallv_sparse(nodeList.begin(), scnt.begin(), sdsp.begin(), nodeList_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
    nodeList.Swap(nodeList_);
  }
}
template <class Type> void Comm::SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Type, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Vector<Pair_t> parray(key.Dim());
  { // Build global index.
    Long glb_dsp = 0;
    Long loc_size = key.Dim();
    Scan(Ptr2ConstItr<Long>(&loc_size, 1), Ptr2Itr<Long>(&glb_dsp, 1), 1, CommOp::SUM);
    glb_dsp -= loc_size;
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < loc_size; i++) {
      parray[i].key = key[i];
      parray[i].data = glb_dsp + i;
    }
  }

  Vector<Pair_t> psorted;
  HyperQuickSort(parray, psorted);

  if (npes > 1 && split_key_ != nullptr) { // Partition data
    Vector<Type> split_key(npes);
    Allgather(Ptr2ConstItr<Type>(split_key_, 1), 1, split_key.begin(), 1);

    Vector<Long> sendSz(npes);
    Vector<Long> recvSz(npes);
    Vector<Long> sendOff(npes);
    Vector<Long> recvOff(npes);
    Long nlSize = psorted.Dim();
    sendSz.SetZero();

    if (nlSize > 0) { // Compute sendSz
      // Determine processor range.
      Long pid1 = std::lower_bound(split_key.begin(), split_key.begin() + npes, psorted[0].key) - split_key.begin() - 1;
      Long pid2 = std::upper_bound(split_key.begin(), split_key.begin() + npes, psorted[nlSize - 1].key) - split_key.begin() + 1;
      pid1 = (pid1 < 0 ? 0 : pid1);
      pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
      for (Integer i = pid1; i < pid2; i++) {
        Pair_t p1;
        p1.key = split_key[i];
        Pair_t p2;
        p2.key = split_key[i + 1 < npes ? i + 1 : i];
        Long start = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p1, std::less<Pair_t>()) - psorted.begin();
        Long end = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p2, std::less<Pair_t>()) - psorted.begin();
        if (i == 0) start = 0;
        if (i == npes - 1) end = nlSize;
        sendSz[i] = end - start;
      }
    }

    // Exchange sendSz, recvSz
    Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

    // compute offsets ...
    { // Compute sendOff, recvOff
      sendOff[0] = 0;
      omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
      recvOff[0] = 0;
      omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
      assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
    }

    // perform All2All ...
    Vector<Pair_t> newNodes(recvSz[npes - 1] + recvOff[npes - 1]);
    void* mpi_req = Ialltoallv_sparse<Pair_t>(psorted.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
    Wait(mpi_req);

    // reset the pointer ...
    psorted.Swap(newNodes);
  }

  scatter_index.ReInit(psorted.Dim());
#pragma omp parallel for schedule(static)
  for (Long i = 0; i < psorted.Dim(); i++) {
    scatter_index[i] = psorted[i].data;
  }
}
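
// Usage sketch (illustrative only; `keys` and `vals` are hypothetical, filled
// elsewhere, and have equal local length): compute a scatter index that orders
// elements by key globally, apply it to a value array with ScatterForward (defined
// below), and undo it with ScatterReverse.
//
//   Comm comm = Comm::World();
//   Vector<Long> keys;                                      // local keys
//   Vector<double> vals;                                    // local values, one per key
//   Vector<Long> scatter_index;
//   comm.SortScatterIndex<Long>(keys, scatter_index, nullptr);
//   comm.ScatterForward(vals, scatter_index);               // vals now in sorted-key order
//   comm.ScatterReverse(vals, scatter_index, keys.Dim());   // restore original ordering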
template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = scatter_index.Dim();
    StaticArray<Long, 2> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    data_dim = glb_size[0] / glb_size[1];
    SCTL_ASSERT(glb_size[0] == data_dim * glb_size[1]);
    send_size = data_.Dim() / data_dim;
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = scatter_index[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> glb_scan;
  { // Global scan of data size.
    glb_scan.ReInit(npes);
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&send_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= send_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted;
  { // Sort scatter_index.
    psorted.ReInit(recv_size);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + recv_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      recv_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i]) - recv_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i + 1]) - recv_indx.begin() : recv_size);
      recvSz[i] = end - start;
      recvOff[i] = start;
    }
    Alltoall(recvSz.begin(), 1, sendSz.begin(), 1);
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == send_size);

    Alltoallv(recv_indx.begin(), recvSz.begin(), recvOff.begin(), send_indx.begin(), sendSz.begin(), sendOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      assert(send_indx[i] >= glb_scan[rank]);
      send_indx[i] -= glb_scan[rank];
      assert(send_indx[i] < send_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = send_indx[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = psorted[i].data * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = loc_size_;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 3> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = scatter_index_.Dim();
    loc_size[2] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 3, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.

    SCTL_ASSERT(glb_size[0] % glb_size[1] == 0);
    data_dim = glb_size[0] / glb_size[1];

    SCTL_ASSERT(loc_size[0] % data_dim == 0);
    send_size = loc_size[0] / data_dim;

    if (glb_size[0] != glb_size[2] * data_dim) {
      recv_size = (((rank + 1) * (glb_size[0] / data_dim)) / npes) - ((rank * (glb_size[0] / data_dim)) / npes);
    }
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = scatter_index_[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> scatter_index;
  {
    StaticArray<Long, 2> glb_rank;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim() / data_dim;
    loc_size[1] = scatter_index_.Dim();
    Scan<Long>(loc_size, glb_rank, 2, CommOp::SUM);
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    SCTL_ASSERT(glb_size[0] == glb_size[1]);
    glb_rank[0] -= loc_size[0];
    glb_rank[1] -= loc_size[1];

    Vector<Long> glb_scan0(npes + 1);
    Vector<Long> glb_scan1(npes + 1);
    Allgather<Long>(glb_rank + 0, 1, glb_scan0.begin(), 1);
    Allgather<Long>(glb_rank + 1, 1, glb_scan1.begin(), 1);
    glb_scan0[npes] = glb_size[0];
    glb_scan1[npes] = glb_size[1];

    if (loc_size[0] != loc_size[1] || glb_rank[0] != glb_rank[1]) { // Repartition scatter_index
      scatter_index.ReInit(loc_size[0]);

      Vector<Long> send_dsp(npes + 1);
      Vector<Long> recv_dsp(npes + 1);
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= npes; i++) {
        send_dsp[i] = std::min(std::max(glb_scan0[i], glb_rank[1]), glb_rank[1] + loc_size[1]) - glb_rank[1];
        recv_dsp[i] = std::min(std::max(glb_scan1[i], glb_rank[0]), glb_rank[0] + loc_size[0]) - glb_rank[0];
      }

      // Long commCnt=0;
      Vector<Long> send_cnt(npes + 0);
      Vector<Long> recv_cnt(npes + 0);
#pragma omp parallel for schedule(static)  // reduction(+:commCnt)
      for (Integer i = 0; i < npes; i++) {
        send_cnt[i] = send_dsp[i + 1] - send_dsp[i];
        recv_cnt[i] = recv_dsp[i + 1] - recv_dsp[i];
        // if(send_cnt[i] && i!=rank) commCnt++;
        // if(recv_cnt[i] && i!=rank) commCnt++;
      }

      void* mpi_req = Ialltoallv_sparse<Long>(scatter_index_.begin(), send_cnt.begin(), send_dsp.begin(), scatter_index.begin(), recv_cnt.begin(), recv_dsp.begin(), 0);
      Wait(mpi_req);
    } else {
      scatter_index.ReInit(scatter_index_.Dim(), (Iterator<Long>)scatter_index_.begin(), false);
    }
  }

  Vector<Long> glb_scan(npes);
  { // Global data size.
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&recv_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= recv_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted(send_size);
  { // Sort scatter_index.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + send_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      send_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i]) - send_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i + 1]) - send_indx.begin() : send_size);
      sendSz[i] = end - start;
      sendOff[i] = start;
    }
    Alltoall(sendSz.begin(), 1, recvSz.begin(), 1);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(recvOff[npes - 1] + recvSz[npes - 1] == recv_size);

    Alltoallv(send_indx.begin(), sendSz.begin(), sendOff.begin(), recv_indx.begin(), recvSz.begin(), recvOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      assert(recv_indx[i] >= glb_scan[rank]);
      recv_indx[i] -= glb_scan[rank];
      assert(recv_indx[i] < recv_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = psorted[i].data * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = recv_indx[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
#ifdef SCTL_HAVE_MPI
inline Vector<MPI_Request>* Comm::NewReq() const {
  if (req.empty()) req.push(new Vector<MPI_Request>);
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req.top();
  req.pop();
  return &request;
}

inline void Comm::Init(const MPI_Comm mpi_comm) {
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_dup(mpi_comm, &mpi_comm_);
  MPI_Comm_rank(mpi_comm_, &mpi_rank_);
  MPI_Comm_size(mpi_comm_, &mpi_size_);
}

inline void Comm::DelReq(Vector<MPI_Request>* req_ptr) const {
  if (req_ptr) req.push(req_ptr);
}

#define HS_MPIDATATYPE(CTYPE, MPITYPE)                \
  template <> class Comm::CommDatatype<CTYPE> {       \
   public:                                            \
    static MPI_Datatype value() { return MPITYPE; }   \
    static MPI_Op sum() { return MPI_SUM; }           \
    static MPI_Op min() { return MPI_MIN; }           \
    static MPI_Op max() { return MPI_MAX; }           \
  }

HS_MPIDATATYPE(short, MPI_SHORT);
HS_MPIDATATYPE(int, MPI_INT);
HS_MPIDATATYPE(long, MPI_LONG);
HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT);
HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED);
HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG);
HS_MPIDATATYPE(float, MPI_FLOAT);
HS_MPIDATATYPE(double, MPI_DOUBLE);
HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE);
HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT);
HS_MPIDATATYPE(char, MPI_CHAR);
HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR);
#undef HS_MPIDATATYPE
#endif
template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const {  // O( ((N/p)+log(p))*(log(N/p)+log(p)) )
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer npes, myrank, omp_p;
  { // Get comm size and rank.
    npes = Size();
    myrank = Rank();
    omp_p = omp_get_max_threads();
  }
  srand(myrank);

  Long totSize, nelem = arr_.Dim();
  { // Local and global sizes. O(log p)
    Allreduce<Long>(Ptr2ConstItr<Long>(&nelem, 1), Ptr2Itr<Long>(&totSize, 1), 1, CommOp::SUM);
  }

  if (npes == 1) { // SortedElem <--- local_sort(arr_)
    SortedElem = arr_;
    omp_par::merge_sort(SortedElem.begin(), SortedElem.begin() + nelem);
    return;
  }

  Vector<Type> arr;
  { // arr <-- local_sort(arr_)
    arr = arr_;
    omp_par::merge_sort(arr.begin(), arr.begin() + nelem);
  }

  Vector<Type> nbuff, nbuff_ext, rbuff, rbuff_ext;  // Allocate memory.
  MPI_Comm comm = mpi_comm_;                        // Copy comm
  bool free_comm = false;                           // Flag to free comm.

  // Binary split and merge in each iteration.
  while (npes > 1 && totSize > 0) { // O(log p) iterations.
    Type split_key;
    Long totSize_new;
    { // Determine split_key. O( log(N/p) + log(p) )
      Integer glb_splt_count;
      Vector<Type> glb_splitters;
      { // Take random splitters. glb_splt_count = const = 100~1000
        Integer splt_count;
        { // Set splt_count. O( 1 ) -- Let p * splt_count = t
          splt_count = (100 * nelem) / totSize;
          if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
          if (splt_count > nelem) splt_count = nelem;
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          if (!glb_splt_count) splt_count = std::min<Long>(1, nelem);
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          SCTL_ASSERT(glb_splt_count);
        }
        Vector<Type> splitters(splt_count);
        for (Integer i = 0; i < splt_count; i++) {
          splitters[i] = arr[rand() % nelem];
        }
        Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
        { // Set glb_splt_cnts, glb_splt_disp
          MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
          glb_splt_disp[0] = 0;
          omp_par::scan(glb_splt_cnts.begin(), glb_splt_disp.begin(), npes);
          SCTL_ASSERT(glb_splt_count == glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1]);
        }
        { // Gather all splitters. O( log(p) )
          glb_splitters.ReInit(glb_splt_count);
          Vector<int> glb_splt_cnts_(npes), glb_splt_disp_(npes);
          for (Integer i = 0; i < npes; i++) {
            glb_splt_cnts_[i] = glb_splt_cnts[i];
            glb_splt_disp_[i] = glb_splt_disp[i];
          }
          MPI_Allgatherv((splt_count ? &splitters[0] : nullptr), splt_count, CommDatatype<Type>::value(), &glb_splitters[0], &glb_splt_cnts_[0], &glb_splt_disp_[0], CommDatatype<Type>::value(), comm);
        }
      }

      // Determine split key. O( log(N/p) + log(p) )
      Vector<Long> lrank(glb_splt_count);
      { // Compute local rank
#pragma omp parallel for schedule(static)
        for (Integer i = 0; i < glb_splt_count; i++) {
          lrank[i] = std::lower_bound(arr.begin(), arr.begin() + nelem, glb_splitters[i]) - arr.begin();
        }
      }

      Vector<Long> grank(glb_splt_count);
      { // Compute global rank
        MPI_Allreduce(&lrank[0], &grank[0], glb_splt_count, CommDatatype<Long>::value(), CommDatatype<Long>::sum(), comm);
      }

      { // Determine split_key, totSize_new
        ConstIterator<Long> split_disp = grank.begin();
        for (Integer i = 0; i < glb_splt_count; i++) {
          if (labs(grank[i] - totSize / 2) < labs(*split_disp - totSize / 2)) {
            split_disp = grank.begin() + i;
          }
        }
        split_key = glb_splitters[split_disp - grank.begin()];

        if (myrank <= (npes - 1) / 2)
          totSize_new = split_disp[0];
        else
          totSize_new = totSize - split_disp[0];

        // double err=(((double)*split_disp)/(totSize/2))-1.0;
        // if(fabs<double>(err)<0.01 || npes<=16) break;
        // else if(!myrank) std::cout<<err<<'\n';
      }
    }

    Integer split_id = (npes - 1) / 2;
    { // Split problem into two. O( N/p )
      Integer new_p0 = (myrank <= split_id ? 0 : split_id + 1);
      Integer cmp_p0 = (myrank > split_id ? 0 : split_id + 1);

      Integer partner;
      { // Set partner
        partner = myrank + cmp_p0 - new_p0;
        if (partner >= npes) partner = npes - 1;
        assert(partner >= 0);
      }
      bool extra_partner = (npes % 2 == 1 && npes - 1 == myrank);

      Long ssize = 0, lsize = 0;
      ConstIterator<Type> sbuff, lbuff;
      { // Set ssize, lsize, sbuff, lbuff
        Long split_indx = std::lower_bound(arr.begin(), arr.begin() + nelem, split_key) - arr.begin();
        ssize = (myrank > split_id ? split_indx : nelem - split_indx);
        sbuff = (myrank > split_id ? arr.begin() : arr.begin() + split_indx);
        lsize = (myrank <= split_id ? split_indx : nelem - split_indx);
        lbuff = (myrank <= split_id ? arr.begin() : arr.begin() + split_indx);
      }

      Long rsize = 0, ext_rsize = 0;
      { // Get rsize, ext_rsize
        Long ext_ssize = 0;
        MPI_Status status;
        MPI_Sendrecv(&ssize, 1, CommDatatype<Long>::value(), partner, 0, &rsize, 1, CommDatatype<Long>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(&ext_ssize, 1, CommDatatype<Long>::value(), split_id, 0, &ext_rsize, 1, CommDatatype<Long>::value(), split_id, 0, comm, &status);
      }

      { // Exchange data.
        rbuff.ReInit(rsize);
        rbuff_ext.ReInit(ext_rsize);
        MPI_Status status;
        MPI_Sendrecv((ssize ? &sbuff[0] : nullptr), ssize, CommDatatype<Type>::value(), partner, 0, (rsize ? &rbuff[0] : nullptr), rsize, CommDatatype<Type>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(nullptr, 0, CommDatatype<Type>::value(), split_id, 0, (ext_rsize ? &rbuff_ext[0] : nullptr), ext_rsize, CommDatatype<Type>::value(), split_id, 0, comm, &status);
      }

      Long nbuff_size = lsize + rsize + ext_rsize;
      { // nbuff <-- merge(lbuff, rbuff, rbuff_ext)
        nbuff.ReInit(lsize + rsize);
        omp_par::merge<ConstIterator<Type>>(lbuff, (lbuff + lsize), rbuff.begin(), rbuff.begin() + rsize, nbuff.begin(), omp_p, std::less<Type>());
        if (ext_rsize > 0 && nbuff.Dim() > 0) {
          nbuff_ext.ReInit(nbuff_size);
          omp_par::merge(nbuff.begin(), nbuff.begin() + (lsize + rsize), rbuff_ext.begin(), rbuff_ext.begin() + ext_rsize, nbuff_ext.begin(), omp_p, std::less<Type>());
          nbuff.Swap(nbuff_ext);
          nbuff_ext.ReInit(0);
        }
      }

      // Copy new data.
      totSize = totSize_new;
      nelem = nbuff_size;
      arr.Swap(nbuff);
      nbuff.ReInit(0);
    }

    { // Split comm. O( log(p) ) ??
      MPI_Comm scomm;
      MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
      if (free_comm) MPI_Comm_free(&comm);
      comm = scomm;
      free_comm = true;
      npes = (myrank <= split_id ? split_id + 1 : npes - split_id - 1);
      myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
    }
  }
  if (free_comm) MPI_Comm_free(&comm);

  SortedElem = arr;
  PartitionW<Type>(SortedElem);
#else
  SortedElem = arr_;
  std::sort(SortedElem.begin(), SortedElem.begin() + SortedElem.Dim());
#endif
}
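
// Usage sketch (illustrative only; `vals` is hypothetical and filled elsewhere):
// globally sort locally held values. On return SortedElem holds this rank's share
// of the globally sorted sequence, re-balanced by PartitionW in the MPI build.
//
//   Comm comm = Comm::World();
//   Vector<double> vals;                 // local, unsorted input
//   Vector<double> sorted;
//   comm.HyperQuickSort(vals, sorted);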

}  // end namespace