// comm.txx

#include <algorithm>  // std::lower_bound, std::upper_bound, std::sort
#include <cstdlib>    // rand, srand, drand48, labs
#include <utility>    // std::pair
#include <type_traits>

#include SCTL_INCLUDE(ompUtils.hpp)
#include SCTL_INCLUDE(vector.hpp)

namespace SCTL_NAMESPACE {

inline Comm::Comm() {
#ifdef SCTL_HAVE_MPI
  Init(MPI_COMM_SELF);
#endif
}

inline Comm::Comm(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  Init(c.mpi_comm_);
#endif
}

inline Comm Comm::Self() {
#ifdef SCTL_HAVE_MPI
  Comm comm_self(MPI_COMM_SELF);
  return comm_self;
#else
  Comm comm_self;
  return comm_self;
#endif
}

inline Comm Comm::World() {
#ifdef SCTL_HAVE_MPI
  Comm comm_world(MPI_COMM_WORLD);
  return comm_world;
#else
  Comm comm_world;
  return comm_world;
#endif
}

inline Comm& Comm::operator=(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  #pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&mpi_comm_);
  Init(c.mpi_comm_);
#endif
  return *this;
}

inline Comm::~Comm() {
#ifdef SCTL_HAVE_MPI
  while (!req.empty()) {
    delete (Vector<MPI_Request>*)req.top();
    req.pop();
  }
  #pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&mpi_comm_);
#endif
}

inline Comm Comm::Split(Integer clr) const {
#ifdef SCTL_HAVE_MPI
  MPI_Comm new_comm;
  #pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
  Comm c(new_comm);
  #pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&new_comm);
  return c;
#else
  Comm c;
  return c;
#endif
}

inline Integer Comm::Rank() const {
#ifdef SCTL_HAVE_MPI
  return mpi_rank_;
#else
  return 0;
#endif
}

inline Integer Comm::Size() const {
#ifdef SCTL_HAVE_MPI
  return mpi_size_;
#else
  return 1;
#endif
}

inline void Comm::Barrier() const {
#ifdef SCTL_HAVE_MPI
  MPI_Barrier(mpi_comm_);
#endif
}

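// Non-blocking point-to-point send; the matching receive is Irecv below. With
// MPI, the returned handle points to an internally managed Vector<MPI_Request>
// and must be passed to Wait(); debug builds (NDEBUG not defined) use the
// synchronous MPI_Issend. Without MPI, messages are matched by tag through the
// send_req/recv_req maps and copied directly between the two buffers.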
template <class SType> void* Comm::Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!scount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[scount - 1]);
#ifndef NDEBUG
  MPI_Issend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#else
  MPI_Isend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#endif
  return &request;
#else
  auto it = recv_req.find(tag);
  if (it == recv_req.end()) {
    send_req.insert(std::pair<Integer, ConstIterator<char>>(tag, (ConstIterator<char>)sbuf));
  } else {
    memcopy(it->second, (ConstIterator<char>)sbuf, scount * sizeof(SType));
    recv_req.erase(it);
  }
  return nullptr;
#endif
}

template <class RType> void* Comm::Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag) const {
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!rcount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[rcount - 1]);
  MPI_Irecv(&rbuf[0], rcount, CommDatatype<RType>::value(), source, tag, mpi_comm_, &request[0]);
  return &request;
#else
  auto it = send_req.find(tag);
  if (it == send_req.end()) {
    recv_req.insert(std::pair<Integer, Iterator<char>>(tag, (Iterator<char>)rbuf));
  } else {
    memcopy((Iterator<char>)rbuf, it->second, rcount * sizeof(RType));
    send_req.erase(it);
  }
  return nullptr;
#endif
}

inline void Comm::Wait(void* req_ptr) const {
#ifdef SCTL_HAVE_MPI
  if (req_ptr == nullptr) return;
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req_ptr;
  // std::vector<MPI_Status> status(request.Dim());
  if (request.Dim()) MPI_Waitall(request.Dim(), &request[0], MPI_STATUSES_IGNORE);  //&status[0]);
  DelReq(&request);
#endif
}

template <class Type> void Comm::Bcast(Iterator<Type> buf, Long count, Long root) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  SCTL_UNUSED(buf[0]);
  SCTL_UNUSED(buf[count - 1]);
  MPI_Bcast(&buf[0], count, CommDatatype<Type>::value(), root, mpi_comm_);
#endif
}

template <class SType, class RType> void Comm::Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
#ifdef SCTL_HAVE_MPI
  MPI_Allgather((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Vector<int> rcounts_(mpi_size_), rdispls_(mpi_size_);
  Long rcount_sum = 0;
  #pragma omp parallel for schedule(static) reduction(+ : rcount_sum)
  for (Integer i = 0; i < mpi_size_; i++) {
    rcounts_[i] = rcounts[i];
    rdispls_[i] = rdispls[i];
    rcount_sum += rcounts[i];
  }
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount_sum) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount_sum - 1]);
  }
  MPI_Allgatherv((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount_sum ? &rbuf[0] : nullptr), &rcounts_.begin()[0], &rdispls_.begin()[0], CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount * Size() - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
  MPI_Alltoall((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

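// Non-blocking personalized all-to-all for sparse communication patterns: one
// MPI_Irecv/MPI_Isend pair is posted per rank with a non-zero count instead of
// calling MPI_Alltoallv. Returns a request handle for Wait(), or nullptr if
// there is nothing to exchange.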
template <class SType, class RType> void* Comm::Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer request_count = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) request_count++;
    if (scounts[i]) request_count++;
  }
  if (!request_count) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(request_count);
  Integer request_iter = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) {
      SCTL_UNUSED(rbuf[rdispls[i]]);
      SCTL_UNUSED(rbuf[rdispls[i] + rcounts[i] - 1]);
      MPI_Irecv(&rbuf[rdispls[i]], rcounts[i], CommDatatype<RType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  for (Integer i = 0; i < mpi_size_; i++) {
    if (scounts[i]) {
      SCTL_UNUSED(sbuf[sdispls[i]]);
      SCTL_UNUSED(sbuf[sdispls[i] + scounts[i] - 1]);
      MPI_Isend(&sbuf[sdispls[i]], scounts[i], CommDatatype<SType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  return &request;
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(SType));
  return nullptr;
#endif
}

template <class Type> void Comm::Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  { // Use Ialltoallv_sparse if the average connectivity is less than 64
    Long connectivity = 0, glb_connectivity = 0;
    #pragma omp parallel for schedule(static) reduction(+ : connectivity)
    for (Integer i = 0; i < mpi_size_; i++) {
      if (rcounts[i]) connectivity++;
    }
    Allreduce(Ptr2ConstItr<Long>(&connectivity, 1), Ptr2Itr<Long>(&glb_connectivity, 1), 1, CommOp::SUM);
    if (glb_connectivity < 64 * Size()) {
      void* mpi_req = Ialltoallv_sparse(sbuf, scounts, sdispls, rbuf, rcounts, rdispls, 0);
      Wait(mpi_req);
      return;
    }
  }
  { // Use vendor MPI_Alltoallv
    //#ifndef ALLTOALLV_FIX
    Vector<int> scnt, sdsp, rcnt, rdsp;
    scnt.ReInit(mpi_size_);
    sdsp.ReInit(mpi_size_);
    rcnt.ReInit(mpi_size_);
    rdsp.ReInit(mpi_size_);
    Long stotal = 0, rtotal = 0;
    #pragma omp parallel for schedule(static) reduction(+ : stotal, rtotal)
    for (Integer i = 0; i < mpi_size_; i++) {
      scnt[i] = scounts[i];
      sdsp[i] = sdispls[i];
      rcnt[i] = rcounts[i];
      rdsp[i] = rdispls[i];
      stotal += scounts[i];
      rtotal += rcounts[i];
    }
    MPI_Alltoallv((stotal ? &sbuf[0] : nullptr), &scnt[0], &sdsp[0], CommDatatype<Type>::value(), (rtotal ? &rbuf[0] : nullptr), &rcnt[0], &rdsp[0], CommDatatype<Type>::value(), mpi_comm_);
    return;
    //#endif
  }
  // TODO: implement hypercube scheme
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(Type));
#endif
}

template <class Type> void Comm::Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Allreduce(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

template <class Type> void Comm::Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Scan(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

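// Repartition nodeList across ranks so that the total weight (wts_, or unit
// weights when wts_ is nullptr) is distributed roughly equally, preserving the
// global order of elements.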
template <class Type> void Comm::PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;
  Long nlSize = nodeList.Dim();

  Vector<Long> wts;
  Long localWt = 0;
  if (wts_ == nullptr) { // Construct arrays of wts.
    wts.ReInit(nlSize);
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < nlSize; i++) {
      wts[i] = 1;
    }
    localWt = nlSize;
  } else {
    wts.ReInit(nlSize, (Iterator<Long>)wts_->begin(), false);
    #pragma omp parallel for reduction(+ : localWt)
    for (Long i = 0; i < nlSize; i++) {
      localWt += wts[i];
    }
  }

  Long off1 = 0, off2 = 0, totalWt = 0;
  { // compute the total weight of the problem ...
    Allreduce<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&totalWt, 1), 1, CommOp::SUM);
    Scan<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&off2, 1), 1, CommOp::SUM);
    off1 = off2 - localWt;
  }

  Vector<Long> lscn;
  if (nlSize) { // perform a local scan on the weights first ...
    lscn.ReInit(nlSize);
    lscn[0] = off1;
    omp_par::scan(wts.begin(), lscn.begin(), nlSize);
  }

  Vector<Long> sendSz, recvSz, sendOff, recvOff;
  sendSz.ReInit(npes);
  recvSz.ReInit(npes);
  sendOff.ReInit(npes);
  recvOff.ReInit(npes);
  sendSz.SetZero();

  if (nlSize > 0 && totalWt > 0) { // Compute sendSz
    Long pid1 = (off1 * npes) / totalWt;
    Long pid2 = ((off2 + 1) * npes) / totalWt + 1;
    assert((totalWt * pid2) / npes >= off2);
    pid1 = (pid1 < 0 ? 0 : pid1);
    pid2 = (pid2 > npes ? npes : pid2);
    #pragma omp parallel for schedule(static)
    for (Integer i = pid1; i < pid2; i++) {
      Long wt1 = (totalWt * (i)) / npes;
      Long wt2 = (totalWt * (i + 1)) / npes;
      Long start = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt1, std::less<Long>()) - lscn.begin();
      Long end = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt2, std::less<Long>()) - lscn.begin();
      if (i == 0) start = 0;
      if (i == npes - 1) end = nlSize;
      sendSz[i] = end - start;
    }
  } else {
    sendSz[0] = nlSize;
  }

  // Exchange sendSz, recvSz
  Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

  { // Compute sendOff, recvOff
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
  }

  // perform All2All ...
  Vector<Type> newNodes;
  newNodes.ReInit(recvSz[npes - 1] + recvOff[npes - 1]);
  void* mpi_req = Ialltoallv_sparse<Type>(nodeList.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
  Wait(mpi_req);

  // reset the pointer ...
  nodeList.Swap(newNodes);
}

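// Repartition v so that this rank ends up with N elements (scaled by dof, the
// number of stored values per logical element), preserving the global order.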
template <class Type> void Comm::PartitionN(Vector<Type>& v, Long N) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer rank = Rank();
  Integer np = Size();
  if (np == 1) return;

  Vector<Long> v_cnt(np), v_dsp(np + 1);
  Vector<Long> N_cnt(np), N_dsp(np + 1);
  { // Set v_cnt, v_dsp
    v_dsp[0] = 0;
    Long cnt = v.Dim();
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, v_cnt.begin(), 1);
    omp_par::scan(v_cnt.begin(), v_dsp.begin(), np);
    v_dsp[np] = v_cnt[np - 1] + v_dsp[np - 1];
  }
  { // Set N_cnt, N_dsp
    N_dsp[0] = 0;
    Long cnt = N;
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, N_cnt.begin(), 1);
    omp_par::scan(N_cnt.begin(), N_dsp.begin(), np);
    N_dsp[np] = N_cnt[np - 1] + N_dsp[np - 1];
  }
  { // Adjust for dof
    Long dof = (N_dsp[np] ? v_dsp[np] / N_dsp[np] : 0);
    assert(dof * N_dsp[np] == v_dsp[np]);
    if (dof == 0) return;
    if (dof != 1) {
      #pragma omp parallel for schedule(static)
      for (Integer i = 0; i < np; i++) N_cnt[i] *= dof;
      #pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= np; i++) N_dsp[i] *= dof;
    }
  }

  Vector<Type> v_(N_cnt[rank]);
  { // Set v_
    Vector<Long> scnt(np), sdsp(np);
    Vector<Long> rcnt(np), rdsp(np);
    #pragma omp parallel for schedule(static)
    for (Integer i = 0; i < np; i++) {
      { // Set scnt
        Long n0 = N_dsp[i + 0];
        Long n1 = N_dsp[i + 1];
        if (n0 < v_dsp[rank + 0]) n0 = v_dsp[rank + 0];
        if (n1 < v_dsp[rank + 0]) n1 = v_dsp[rank + 0];
        if (n0 > v_dsp[rank + 1]) n0 = v_dsp[rank + 1];
        if (n1 > v_dsp[rank + 1]) n1 = v_dsp[rank + 1];
        scnt[i] = n1 - n0;
      }
      { // Set rcnt
        Long n0 = v_dsp[i + 0];
        Long n1 = v_dsp[i + 1];
        if (n0 < N_dsp[rank + 0]) n0 = N_dsp[rank + 0];
        if (n1 < N_dsp[rank + 0]) n1 = N_dsp[rank + 0];
        if (n0 > N_dsp[rank + 1]) n0 = N_dsp[rank + 1];
        if (n1 > N_dsp[rank + 1]) n1 = N_dsp[rank + 1];
        rcnt[i] = n1 - n0;
      }
    }
    sdsp[0] = 0;
    omp_par::scan(scnt.begin(), sdsp.begin(), np);
    rdsp[0] = 0;
    omp_par::scan(rcnt.begin(), rdsp.begin(), np);
    void* mpi_request = Ialltoallv_sparse(v.begin(), scnt.begin(), sdsp.begin(), v_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
  }
  v.Swap(v_);
}

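// Repartition a locally sorted nodeList by value: each rank contributes its
// splitter, and rank i receives the elements in [splitter_i, splitter_{i+1})
// under the ordering given by comp.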
template <class Type, class Compare> void Comm::PartitionS(Vector<Type>& nodeList, const Type& splitter, Compare comp) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;

  Vector<Type> mins(npes);
  Allgather(Ptr2ConstItr<Type>(&splitter, 1), 1, mins.begin(), 1);

  Vector<Long> scnt(npes), sdsp(npes);
  Vector<Long> rcnt(npes), rdsp(npes);
  { // Compute scnt, sdsp
    #pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sdsp[i] = std::lower_bound(nodeList.begin(), nodeList.begin() + nodeList.Dim(), mins[i], comp) - nodeList.begin();
    }
    #pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes - 1; i++) {
      scnt[i] = sdsp[i + 1] - sdsp[i];
    }
    scnt[npes - 1] = nodeList.Dim() - sdsp[npes - 1];
  }
  { // Compute rcnt, rdsp
    rdsp[0] = 0;
    Alltoall(scnt.begin(), 1, rcnt.begin(), 1);
    omp_par::scan(rcnt.begin(), rdsp.begin(), npes);
  }
  { // Redistribute nodeList
    Vector<Type> nodeList_(rdsp[npes - 1] + rcnt[npes - 1]);
    void* mpi_request = Ialltoallv_sparse(nodeList.begin(), scnt.begin(), sdsp.begin(), nodeList_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
    nodeList.Swap(nodeList_);
  }
}

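// Compute the scatter index that sorts 'key' globally: on return,
// scatter_index[i] is the original global index of the i-th element in globally
// sorted order. If split_key_ is given, the result is additionally partitioned
// so that rank boundaries follow the gathered split keys. The output can be
// passed to ScatterForward/ScatterReverse.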
template <class Type> void Comm::SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Type, Long> Pair_t;
  Integer npes = Size();

  Vector<Pair_t> parray(key.Dim());
  { // Build global index.
    Long glb_dsp = 0;
    Long loc_size = key.Dim();
    Scan(Ptr2ConstItr<Long>(&loc_size, 1), Ptr2Itr<Long>(&glb_dsp, 1), 1, CommOp::SUM);
    glb_dsp -= loc_size;
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < loc_size; i++) {
      parray[i].key = key[i];
      parray[i].data = glb_dsp + i;
    }
  }

  Vector<Pair_t> psorted;
  HyperQuickSort(parray, psorted);

  if (npes > 1 && split_key_ != nullptr) { // Partition data
    Vector<Type> split_key(npes);
    Allgather(Ptr2ConstItr<Type>(split_key_, 1), 1, split_key.begin(), 1);

    Vector<Long> sendSz(npes);
    Vector<Long> recvSz(npes);
    Vector<Long> sendOff(npes);
    Vector<Long> recvOff(npes);
    Long nlSize = psorted.Dim();
    sendSz.SetZero();

    if (nlSize > 0) { // Compute sendSz
      // Determine processor range.
      Long pid1 = std::lower_bound(split_key.begin(), split_key.begin() + npes, psorted[0].key) - split_key.begin() - 1;
      Long pid2 = std::upper_bound(split_key.begin(), split_key.begin() + npes, psorted[nlSize - 1].key) - split_key.begin() + 1;
      pid1 = (pid1 < 0 ? 0 : pid1);
      pid2 = (pid2 > npes ? npes : pid2);
      #pragma omp parallel for schedule(static)
      for (Integer i = pid1; i < pid2; i++) {
        Pair_t p1;
        p1.key = split_key[i];
        Pair_t p2;
        p2.key = split_key[i + 1 < npes ? i + 1 : i];
        Long start = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p1, std::less<Pair_t>()) - psorted.begin();
        Long end = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p2, std::less<Pair_t>()) - psorted.begin();
        if (i == 0) start = 0;
        if (i == npes - 1) end = nlSize;
        sendSz[i] = end - start;
      }
    }

    // Exchange sendSz, recvSz
    Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

    // compute offsets ...
    { // Compute sendOff, recvOff
      sendOff[0] = 0;
      omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
      recvOff[0] = 0;
      omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
      assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
    }

    // perform All2All ...
    Vector<Pair_t> newNodes(recvSz[npes - 1] + recvOff[npes - 1]);
    void* mpi_req = Ialltoallv_sparse<Pair_t>(psorted.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
    Wait(mpi_req);

    // reset the pointer ...
    psorted.Swap(newNodes);
  }

  scatter_index.ReInit(psorted.Dim());
  #pragma omp parallel for schedule(static)
  for (Long i = 0; i < psorted.Dim(); i++) {
    scatter_index[i] = psorted[i].data;
  }
}

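// Forward scatter: rearrange data_ so that local output element i is the input
// element with global index scatter_index[i]. Each logical element may consist
// of data_dim values; data_dim is inferred from the global sizes of data_ and
// scatter_index.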
template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = scatter_index.Dim();
    StaticArray<Long, 2> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    data_dim = glb_size[0] / glb_size[1];
    SCTL_ASSERT(glb_size[0] == data_dim * glb_size[1]);
    send_size = data_.Dim() / data_dim;
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = scatter_index[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> glb_scan;
  { // Global scan of data size.
    glb_scan.ReInit(npes);
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&send_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= send_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted;
  { // Sort scatter_index.
    psorted.ReInit(recv_size);
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + recv_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      recv_indx[i] = psorted[i].key;
    }
    #pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i]) - recv_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i + 1]) - recv_indx.begin() : recv_size);
      recvSz[i] = end - start;
      recvOff[i] = start;
    }
    Alltoall(recvSz.begin(), 1, sendSz.begin(), 1);
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == send_size);
    Alltoallv(recv_indx.begin(), recvSz.begin(), recvOff.begin(), send_indx.begin(), sendSz.begin(), sendOff.begin());
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      assert(send_indx[i] >= glb_scan[rank]);
      send_indx[i] -= glb_scan[rank];
      assert(send_indx[i] < send_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = send_indx[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
    #pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = psorted[i].data * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}

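// Reverse scatter: undo ScatterForward by sending local element i back to
// global position scatter_index_[i]. loc_size_ is the desired local output
// size; if the supplied sizes do not sum to the global element count, a uniform
// partition is used instead.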
template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = loc_size_;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 3> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = scatter_index_.Dim();
    loc_size[2] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 3, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.

    SCTL_ASSERT(glb_size[0] % glb_size[1] == 0);
    data_dim = glb_size[0] / glb_size[1];

    SCTL_ASSERT(loc_size[0] % data_dim == 0);
    send_size = loc_size[0] / data_dim;

    if (glb_size[0] != glb_size[2] * data_dim) {
      recv_size = (((rank + 1) * (glb_size[0] / data_dim)) / npes) - ((rank * (glb_size[0] / data_dim)) / npes);
    }
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = scatter_index_[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> scatter_index;
  {
    StaticArray<Long, 2> glb_rank;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim() / data_dim;
    loc_size[1] = scatter_index_.Dim();
    Scan<Long>(loc_size, glb_rank, 2, CommOp::SUM);
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    SCTL_ASSERT(glb_size[0] == glb_size[1]);
    glb_rank[0] -= loc_size[0];
    glb_rank[1] -= loc_size[1];

    Vector<Long> glb_scan0(npes + 1);
    Vector<Long> glb_scan1(npes + 1);
    Allgather<Long>(glb_rank + 0, 1, glb_scan0.begin(), 1);
    Allgather<Long>(glb_rank + 1, 1, glb_scan1.begin(), 1);
    glb_scan0[npes] = glb_size[0];
    glb_scan1[npes] = glb_size[1];

    if (loc_size[0] != loc_size[1] || glb_rank[0] != glb_rank[1]) { // Repartition scatter_index
      scatter_index.ReInit(loc_size[0]);

      Vector<Long> send_dsp(npes + 1);
      Vector<Long> recv_dsp(npes + 1);
      #pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= npes; i++) {
        send_dsp[i] = std::min(std::max(glb_scan0[i], glb_rank[1]), glb_rank[1] + loc_size[1]) - glb_rank[1];
        recv_dsp[i] = std::min(std::max(glb_scan1[i], glb_rank[0]), glb_rank[0] + loc_size[0]) - glb_rank[0];
      }

      // Long commCnt=0;
      Vector<Long> send_cnt(npes + 0);
      Vector<Long> recv_cnt(npes + 0);
      #pragma omp parallel for schedule(static)  // reduction(+:commCnt)
      for (Integer i = 0; i < npes; i++) {
        send_cnt[i] = send_dsp[i + 1] - send_dsp[i];
        recv_cnt[i] = recv_dsp[i + 1] - recv_dsp[i];
        // if(send_cnt[i] && i!=rank) commCnt++;
        // if(recv_cnt[i] && i!=rank) commCnt++;
      }

      void* mpi_req = Ialltoallv_sparse<Long>(scatter_index_.begin(), send_cnt.begin(), send_dsp.begin(), scatter_index.begin(), recv_cnt.begin(), recv_dsp.begin(), 0);
      Wait(mpi_req);
    } else {
      scatter_index.ReInit(scatter_index_.Dim(), (Iterator<Long>)scatter_index_.begin(), false);
    }
  }

  Vector<Long> glb_scan(npes);
  { // Global data size.
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&recv_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= recv_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted(send_size);
  { // Sort scatter_index.
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + send_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      send_indx[i] = psorted[i].key;
    }
    #pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i]) - send_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i + 1]) - send_indx.begin() : send_size);
      sendSz[i] = end - start;
      sendOff[i] = start;
    }
    Alltoall(sendSz.begin(), 1, recvSz.begin(), 1);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(recvOff[npes - 1] + recvSz[npes - 1] == recv_size);
    Alltoallv(send_indx.begin(), sendSz.begin(), sendOff.begin(), recv_indx.begin(), recvSz.begin(), recvOff.begin());
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      assert(recv_indx[i] >= glb_scan[rank]);
      recv_indx[i] -= glb_scan[rank];
      assert(recv_indx[i] < recv_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = psorted[i].data * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
    #pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
    #pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = recv_indx[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}

#ifdef SCTL_HAVE_MPI
inline Vector<MPI_Request>* Comm::NewReq() const {
  if (req.empty()) req.push(new Vector<MPI_Request>);
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req.top();
  req.pop();
  return &request;
}

inline void Comm::Init(const MPI_Comm mpi_comm) {
  #pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_dup(mpi_comm, &mpi_comm_);
  MPI_Comm_rank(mpi_comm_, &mpi_rank_);
  MPI_Comm_size(mpi_comm_, &mpi_size_);
}

inline void Comm::DelReq(Vector<MPI_Request>* req_ptr) const {
  if (req_ptr) req.push(req_ptr);
}

#define SCTL_HS_MPIDATATYPE(CTYPE, MPITYPE)           \
  template <> class Comm::CommDatatype<CTYPE> {       \
   public:                                            \
    static MPI_Datatype value() { return MPITYPE; }   \
    static MPI_Op sum() { return MPI_SUM; }           \
    static MPI_Op min() { return MPI_MIN; }           \
    static MPI_Op max() { return MPI_MAX; }           \
  }

SCTL_HS_MPIDATATYPE(short, MPI_SHORT);
SCTL_HS_MPIDATATYPE(int, MPI_INT);
SCTL_HS_MPIDATATYPE(long, MPI_LONG);
SCTL_HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT);
SCTL_HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED);
SCTL_HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG);
SCTL_HS_MPIDATATYPE(float, MPI_FLOAT);
SCTL_HS_MPIDATATYPE(double, MPI_DOUBLE);
SCTL_HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE);
SCTL_HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT);
SCTL_HS_MPIDATATYPE(char, MPI_CHAR);
SCTL_HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR);
#undef SCTL_HS_MPIDATATYPE
#endif

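// Distributed hyperquicksort: each iteration samples candidate splitters, picks
// the one whose global rank is closest to the median, exchanges the low/high
// halves of the locally sorted data with a partner rank, merges, and recurses on
// a split communicator. The sorted result is finally rebalanced with PartitionW.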
template <class Type, class Compare> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem, Compare comp) const {  // O( ((N/p)+log(p))*(log(N/p)+log(p)) )
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer npes, myrank, omp_p;
  { // Get comm size and rank.
    npes = Size();
    myrank = Rank();
    omp_p = omp_get_max_threads();
  }
  srand(myrank);

  Long totSize;
  { // Local and global sizes. O(log p)
    Long nelem = arr_.Dim();
    Allreduce<Long>(Ptr2ConstItr<Long>(&nelem, 1), Ptr2Itr<Long>(&totSize, 1), 1, CommOp::SUM);
  }

  if (npes == 1) { // SortedElem <--- local_sort(arr_)
    SortedElem = arr_;
    omp_par::merge_sort(SortedElem.begin(), SortedElem.end(), comp);
    return;
  }

  Vector<Type> arr;
  { // arr <-- local_sort(arr_)
    arr = arr_;
    omp_par::merge_sort(arr.begin(), arr.end(), comp);
  }

  Vector<Type> nbuff, nbuff_ext, rbuff, rbuff_ext;  // Allocate memory.
  MPI_Comm comm = mpi_comm_;  // Copy comm
  bool free_comm = false;  // Flag to free comm.

  // Binary split and merge in each iteration.
  while (npes > 1 && totSize > 0) { // O(log p) iterations.
    Type split_key;
    Long totSize_new;
    { // Determine split_key. O( log(N/p) + log(p) )
      Integer glb_splt_count;
      Vector<Type> glb_splitters;
      { // Take random splitters. glb_splt_count = const = 100~1000
        Integer splt_count;
        Long nelem = arr.Dim();
        { // Set splt_count. O( 1 ) -- Let p * splt_count = t
          splt_count = (100 * nelem) / totSize;
          if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
          if (splt_count > nelem) splt_count = nelem;
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          if (!glb_splt_count) splt_count = std::min<Long>(1, nelem);
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          SCTL_ASSERT(glb_splt_count);
        }

        Vector<Type> splitters(splt_count);
        for (Integer i = 0; i < splt_count; i++) {
          splitters[i] = arr[rand() % nelem];
        }

        Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
        { // Set glb_splt_cnts, glb_splt_disp
          MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
          glb_splt_disp[0] = 0;
          omp_par::scan(glb_splt_cnts.begin(), glb_splt_disp.begin(), npes);
          SCTL_ASSERT(glb_splt_count == glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1]);
        }

        { // Gather all splitters. O( log(p) )
          glb_splitters.ReInit(glb_splt_count);
          Vector<int> glb_splt_cnts_(npes), glb_splt_disp_(npes);
          for (Integer i = 0; i < npes; i++) {
            glb_splt_cnts_[i] = glb_splt_cnts[i];
            glb_splt_disp_[i] = glb_splt_disp[i];
          }
          MPI_Allgatherv((splt_count ? &splitters[0] : nullptr), splt_count, CommDatatype<Type>::value(), &glb_splitters[0], &glb_splt_cnts_[0], &glb_splt_disp_[0], CommDatatype<Type>::value(), comm);
        }
      }

      // Determine split key. O( log(N/p) + log(p) )
      Vector<Long> lrank(glb_splt_count);
      { // Compute local rank
        #pragma omp parallel for schedule(static)
        for (Integer i = 0; i < glb_splt_count; i++) {
          lrank[i] = std::lower_bound(arr.begin(), arr.end(), glb_splitters[i], comp) - arr.begin();
        }
      }

      Vector<Long> grank(glb_splt_count);
      { // Compute global rank
        MPI_Allreduce(&lrank[0], &grank[0], glb_splt_count, CommDatatype<Long>::value(), CommDatatype<Long>::sum(), comm);
      }

      { // Determine split_key, totSize_new
        Integer splitter_idx = 0;
        for (Integer i = 0; i < glb_splt_count; i++) {
          if (labs(grank[i] - totSize / 2) < labs(grank[splitter_idx] - totSize / 2)) {
            splitter_idx = i;
          }
        }
        split_key = glb_splitters[splitter_idx];

        if (myrank <= (npes - 1) / 2)
          totSize_new = grank[splitter_idx];
        else
          totSize_new = totSize - grank[splitter_idx];

        // double err=(((double)grank[splitter_idx])/(totSize/2))-1.0;
        // if(fabs<double>(err)<0.01 || npes<=16) break;
        // else if(!myrank) std::cout<<err<<'\n';
      }
    }

    Integer split_id = (npes - 1) / 2;
    { // Split problem into two. O( N/p )
      Integer partner;
      { // Set partner
        partner = myrank + (split_id + 1) * (myrank <= split_id ? 1 : -1);
        if (partner >= npes) partner = npes - 1;
        assert(partner >= 0);
      }
      bool extra_partner = (npes % 2 == 1 && myrank == npes - 1);

      Long ssize = 0, lsize = 0;
      ConstIterator<Type> sbuff, lbuff;
      { // Set ssize, lsize, sbuff, lbuff
        Long split_indx = std::lower_bound(arr.begin(), arr.end(), split_key, comp) - arr.begin();
        ssize = (myrank > split_id ? split_indx : arr.Dim() - split_indx);
        sbuff = (myrank > split_id ? arr.begin() : arr.begin() + split_indx);
        lsize = (myrank <= split_id ? split_indx : arr.Dim() - split_indx);
        lbuff = (myrank <= split_id ? arr.begin() : arr.begin() + split_indx);
      }

      Long rsize = 0, ext_rsize = 0;
      { // Get rsize, ext_rsize
        Long ext_ssize = 0;
        MPI_Status status;
        MPI_Sendrecv(&ssize, 1, CommDatatype<Long>::value(), partner, 0, &rsize, 1, CommDatatype<Long>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(&ext_ssize, 1, CommDatatype<Long>::value(), split_id, 0, &ext_rsize, 1, CommDatatype<Long>::value(), split_id, 0, comm, &status);
      }

      { // Exchange data.
        rbuff.ReInit(rsize);
        rbuff_ext.ReInit(ext_rsize);
        MPI_Status status;
        MPI_Sendrecv((ssize ? &sbuff[0] : nullptr), ssize, CommDatatype<Type>::value(), partner, 0, (rsize ? &rbuff[0] : nullptr), rsize, CommDatatype<Type>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(nullptr, 0, CommDatatype<Type>::value(), split_id, 0, (ext_rsize ? &rbuff_ext[0] : nullptr), ext_rsize, CommDatatype<Type>::value(), split_id, 0, comm, &status);
      }

      { // nbuff <-- merge(lbuff, rbuff, rbuff_ext)
        nbuff.ReInit(lsize + rsize);
        omp_par::merge<ConstIterator<Type>>(lbuff, (lbuff + lsize), rbuff.begin(), rbuff.begin() + rsize, nbuff.begin(), omp_p, comp);
        if (ext_rsize > 0) {
          if (nbuff.Dim() > 0) {
            nbuff_ext.ReInit(lsize + rsize + ext_rsize);
            omp_par::merge(nbuff.begin(), nbuff.begin() + (lsize + rsize), rbuff_ext.begin(), rbuff_ext.begin() + ext_rsize, nbuff_ext.begin(), omp_p, comp);
            nbuff.Swap(nbuff_ext);
            nbuff_ext.ReInit(0);
          } else {
            nbuff.Swap(rbuff_ext);
          }
        }
      }

      // Copy new data.
      totSize = totSize_new;
      arr.Swap(nbuff);
    }

    { // Split comm. O( log(p) ) ??
      MPI_Comm scomm;
      #pragma omp critical(SCTL_COMM_DUP)
      MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
      #pragma omp critical(SCTL_COMM_DUP)
      if (free_comm) MPI_Comm_free(&comm);
      comm = scomm;
      free_comm = true;

      npes = (myrank <= split_id ? split_id + 1 : npes - split_id - 1);
      myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
    }
  }
  #pragma omp critical(SCTL_COMM_DUP)
  if (free_comm) MPI_Comm_free(&comm);

  SortedElem = arr;
  PartitionW<Type>(SortedElem);
#else
  SortedElem = arr_;
  std::sort(SortedElem.begin(), SortedElem.begin() + SortedElem.Dim(), comp);
#endif
}

}  // end namespace