// comm.txx
#include <type_traits>

#include SCTL_INCLUDE(ompUtils.hpp)
#include SCTL_INCLUDE(vector.hpp)

namespace SCTL_NAMESPACE {

inline Comm::Comm() {
#ifdef SCTL_HAVE_MPI
  Init(MPI_COMM_SELF);
#endif
}

inline Comm::Comm(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  Init(c.mpi_comm_);
#endif
}

inline Comm Comm::Self() {
#ifdef SCTL_HAVE_MPI
  Comm comm_self(MPI_COMM_SELF);
  return comm_self;
#else
  Comm comm_self;
  return comm_self;
#endif
}

inline Comm Comm::World() {
#ifdef SCTL_HAVE_MPI
  Comm comm_world(MPI_COMM_WORLD);
  return comm_world;
#else
  Comm comm_self;
  return comm_self;
#endif
}
inline Comm& Comm::operator=(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  if (this == &c) return *this;  // self-assignment guard: freeing mpi_comm_ would otherwise invalidate c.mpi_comm_
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&mpi_comm_);
  Init(c.mpi_comm_);
#endif
  return *this;
}
inline Comm::~Comm() {
#ifdef SCTL_HAVE_MPI
  while (!req.empty()) {
    delete (Vector<MPI_Request>*)req.top();
    req.pop();
  }
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&mpi_comm_);
#endif
}

inline Comm Comm::Split(Integer clr) const {
#ifdef SCTL_HAVE_MPI
  MPI_Comm new_comm;
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
  Comm c(new_comm);
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&new_comm);
  return c;
#else
  Comm c;
  return c;
#endif
}

inline Integer Comm::Rank() const {
#ifdef SCTL_HAVE_MPI
  return mpi_rank_;
#else
  return 0;
#endif
}

inline Integer Comm::Size() const {
#ifdef SCTL_HAVE_MPI
  return mpi_size_;
#else
  return 1;
#endif
}

inline void Comm::Barrier() const {
#ifdef SCTL_HAVE_MPI
  MPI_Barrier(mpi_comm_);
#endif
}

template <class SType> void* Comm::Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!scount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[scount - 1]);
#ifndef NDEBUG
  MPI_Issend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#else
  MPI_Isend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#endif
  return &request;
#else
  auto it = recv_req.find(tag);
  if (it == recv_req.end())
    send_req.insert(std::pair<Integer, ConstIterator<char>>(tag, (ConstIterator<char>)sbuf));
  else
    memcopy(it->second, (ConstIterator<char>)sbuf, scount * sizeof(SType));
  return nullptr;
#endif
}

template <class RType> void* Comm::Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag) const {
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!rcount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[rcount - 1]);
  MPI_Irecv(&rbuf[0], rcount, CommDatatype<RType>::value(), source, tag, mpi_comm_, &request[0]);
  return &request;
#else
  auto it = send_req.find(tag);
  if (it == send_req.end())
    recv_req.insert(std::pair<Integer, Iterator<char>>(tag, (Iterator<char>)rbuf));
  else
    memcopy((Iterator<char>)rbuf, it->second, rcount * sizeof(RType));
  return nullptr;
#endif
}

inline void Comm::Wait(void* req_ptr) const {
#ifdef SCTL_HAVE_MPI
  if (req_ptr == nullptr) return;
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req_ptr;
  // std::vector<MPI_Status> status(request.Dim());
  if (request.Dim()) MPI_Waitall(request.Dim(), &request[0], MPI_STATUSES_IGNORE);  //&status[0]);
  DelReq(&request);
#endif
}
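// A minimal usage sketch of the non-blocking point-to-point API above (buffer sizes,
// ranks, and tags are illustrative only, not part of the library):
//
//   Comm comm = Comm::World();
//   Long n = 100;
//   Vector<double> sbuf(n), rbuf(n);
//   Integer partner = (comm.Rank() + 1) % comm.Size();
//   void* rreq = comm.Irecv(rbuf.begin(), n, partner, /*tag*/ 7);
//   void* sreq = comm.Isend(sbuf.begin(), n, partner, /*tag*/ 7);
//   comm.Wait(rreq);  // every handle returned by Isend/Irecv must eventually be passed to Wait()
//   comm.Wait(sreq);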
template <class SType, class RType> void Comm::Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
  MPI_Allgather((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Vector<int> rcounts_(mpi_size_), rdispls_(mpi_size_);
  Long rcount_sum = 0;
#pragma omp parallel for schedule(static) reduction(+ : rcount_sum)
  for (Integer i = 0; i < mpi_size_; i++) {
    rcounts_[i] = rcounts[i];
    rdispls_[i] = rdispls[i];
    rcount_sum += rcounts[i];
  }
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount_sum) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount_sum - 1]);
  }
  MPI_Allgatherv((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount_sum ? &rbuf[0] : nullptr), &rcounts_.begin()[0], &rdispls_.begin()[0], CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount * Size() - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
  MPI_Alltoall((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void* Comm::Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer request_count = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) request_count++;
    if (scounts[i]) request_count++;
  }
  if (!request_count) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(request_count);
  Integer request_iter = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) {
      SCTL_UNUSED(rbuf[rdispls[i]]);
      SCTL_UNUSED(rbuf[rdispls[i] + rcounts[i] - 1]);
      MPI_Irecv(&rbuf[rdispls[i]], rcounts[i], CommDatatype<RType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  for (Integer i = 0; i < mpi_size_; i++) {
    if (scounts[i]) {
      SCTL_UNUSED(sbuf[sdispls[i]]);
      SCTL_UNUSED(sbuf[sdispls[i] + scounts[i] - 1]);
      MPI_Isend(&sbuf[sdispls[i]], scounts[i], CommDatatype<SType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  return &request;
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(SType));
  return nullptr;
#endif
}
template <class Type> void Comm::Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  { // Use Ialltoallv_sparse if the average connectivity is < 64
    Long connectivity = 0, glb_connectivity = 0;
#pragma omp parallel for schedule(static) reduction(+ : connectivity)
    for (Integer i = 0; i < mpi_size_; i++) {
      if (rcounts[i]) connectivity++;
    }
    Allreduce(Ptr2ConstItr<Long>(&connectivity, 1), Ptr2Itr<Long>(&glb_connectivity, 1), 1, CommOp::SUM);
    if (glb_connectivity < 64 * Size()) {
      void* mpi_req = Ialltoallv_sparse(sbuf, scounts, sdispls, rbuf, rcounts, rdispls, 0);
      Wait(mpi_req);
      return;
    }
  }
  { // Use vendor MPI_Alltoallv
    //#ifndef ALLTOALLV_FIX
    Vector<int> scnt, sdsp, rcnt, rdsp;
    scnt.ReInit(mpi_size_);
    sdsp.ReInit(mpi_size_);
    rcnt.ReInit(mpi_size_);
    rdsp.ReInit(mpi_size_);
    Long stotal = 0, rtotal = 0;
#pragma omp parallel for schedule(static) reduction(+ : stotal, rtotal)
    for (Integer i = 0; i < mpi_size_; i++) {
      scnt[i] = scounts[i];
      sdsp[i] = sdispls[i];
      rcnt[i] = rcounts[i];
      rdsp[i] = rdispls[i];
      stotal += scounts[i];
      rtotal += rcounts[i];
    }
    MPI_Alltoallv((stotal ? &sbuf[0] : nullptr), &scnt[0], &sdsp[0], CommDatatype<Type>::value(), (rtotal ? &rbuf[0] : nullptr), &rcnt[0], &rdsp[0], CommDatatype<Type>::value(), mpi_comm_);
    return;
    //#endif
  }
  // TODO: implement hypercube scheme
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(Type));
#endif
}
template <class Type> void Comm::Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Allreduce(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

template <class Type> void Comm::Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Scan(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}
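// A minimal sketch of the reductions above (variable names are illustrative; qualify
// CommOp from the calling scope as needed):
//
//   Long loc = local_count, tot = 0, scan = 0;
//   comm.Allreduce(Ptr2ConstItr<Long>(&loc, 1), Ptr2Itr<Long>(&tot, 1), 1, CommOp::SUM);
//   comm.Scan(Ptr2ConstItr<Long>(&loc, 1), Ptr2Itr<Long>(&scan, 1), 1, CommOp::SUM);
//   scan -= loc;  // convert the inclusive scan to an exclusive prefix sum, as PartitionW() does below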
template <class Type> void Comm::PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;
  Long nlSize = nodeList.Dim();

  Vector<Long> wts;
  Long localWt = 0;
  if (wts_ == nullptr) { // Construct arrays of wts.
    wts.ReInit(nlSize);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < nlSize; i++) {
      wts[i] = 1;
    }
    localWt = nlSize;
  } else {
    wts.ReInit(nlSize, (Iterator<Long>)wts_->begin(), false);
#pragma omp parallel for reduction(+ : localWt)
    for (Long i = 0; i < nlSize; i++) {
      localWt += wts[i];
    }
  }

  Long off1 = 0, off2 = 0, totalWt = 0;
  { // compute the total weight of the problem ...
    Allreduce<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&totalWt, 1), 1, CommOp::SUM);
    Scan<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&off2, 1), 1, CommOp::SUM);
    off1 = off2 - localWt;
  }

  Vector<Long> lscn;
  if (nlSize) { // perform a local scan on the weights first ...
    lscn.ReInit(nlSize);
    lscn[0] = off1;
    omp_par::scan(wts.begin(), lscn.begin(), nlSize);
  }

  Vector<Long> sendSz, recvSz, sendOff, recvOff;
  sendSz.ReInit(npes);
  recvSz.ReInit(npes);
  sendOff.ReInit(npes);
  recvOff.ReInit(npes);
  sendSz.SetZero();

  if (nlSize > 0 && totalWt > 0) { // Compute sendSz
    Long pid1 = (off1 * npes) / totalWt;
    Long pid2 = ((off2 + 1) * npes) / totalWt + 1;
    assert((totalWt * pid2) / npes >= off2);
    pid1 = (pid1 < 0 ? 0 : pid1);
    pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
    for (Integer i = pid1; i < pid2; i++) {
      Long wt1 = (totalWt * (i)) / npes;
      Long wt2 = (totalWt * (i + 1)) / npes;
      Long start = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt1, std::less<Long>()) - lscn.begin();
      Long end = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt2, std::less<Long>()) - lscn.begin();
      if (i == 0) start = 0;
      if (i == npes - 1) end = nlSize;
      sendSz[i] = end - start;
    }
  } else {
    sendSz[0] = nlSize;
  }

  // Exchange sendSz, recvSz
  Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

  { // Compute sendOff, recvOff
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
  }

  // perform All2All ...
  Vector<Type> newNodes;
  newNodes.ReInit(recvSz[npes - 1] + recvOff[npes - 1]);
  void* mpi_req = Ialltoallv_sparse<Type>(nodeList.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
  Wait(mpi_req);

  // reset the pointer ...
  nodeList.Swap(newNodes);
}
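// A minimal usage sketch of PartitionW (hypothetical data): redistribute a vector so
// that every rank owns roughly the same total weight; with wts_ == nullptr each element
// gets unit weight, i.e. element counts are equalized.
//
//   Vector<Particle> pts = /* locally held elements (Particle is illustrative) */;
//   Vector<Long> wts(pts.Dim());   // per-element work estimates
//   comm.PartitionW(pts, &wts);    // balance by weight
//   comm.PartitionW(pts);          // or balance by count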
template <class Type> void Comm::PartitionN(Vector<Type>& v, Long N) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer rank = Rank();
  Integer np = Size();
  if (np == 1) return;

  Vector<Long> v_cnt(np), v_dsp(np + 1);
  Vector<Long> N_cnt(np), N_dsp(np + 1);
  { // Set v_cnt, v_dsp
    v_dsp[0] = 0;
    Long cnt = v.Dim();
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, v_cnt.begin(), 1);
    omp_par::scan(v_cnt.begin(), v_dsp.begin(), np);
    v_dsp[np] = v_cnt[np - 1] + v_dsp[np - 1];
  }
  { // Set N_cnt, N_dsp
    N_dsp[0] = 0;
    Long cnt = N;
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, N_cnt.begin(), 1);
    omp_par::scan(N_cnt.begin(), N_dsp.begin(), np);
    N_dsp[np] = N_cnt[np - 1] + N_dsp[np - 1];
  }
  { // Adjust for dof
    Long dof = (N_dsp[np] ? v_dsp[np] / N_dsp[np] : 0);
    assert(dof * N_dsp[np] == v_dsp[np]);
    if (dof == 0) return;
    if (dof != 1) {
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i < np; i++) N_cnt[i] *= dof;
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= np; i++) N_dsp[i] *= dof;
    }
  }

  Vector<Type> v_(N_cnt[rank]);
  { // Set v_
    Vector<Long> scnt(np), sdsp(np);
    Vector<Long> rcnt(np), rdsp(np);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < np; i++) {
      { // Set scnt
        Long n0 = N_dsp[i + 0];
        Long n1 = N_dsp[i + 1];
        if (n0 < v_dsp[rank + 0]) n0 = v_dsp[rank + 0];
        if (n1 < v_dsp[rank + 0]) n1 = v_dsp[rank + 0];
        if (n0 > v_dsp[rank + 1]) n0 = v_dsp[rank + 1];
        if (n1 > v_dsp[rank + 1]) n1 = v_dsp[rank + 1];
        scnt[i] = n1 - n0;
      }
      { // Set rcnt
        Long n0 = v_dsp[i + 0];
        Long n1 = v_dsp[i + 1];
        if (n0 < N_dsp[rank + 0]) n0 = N_dsp[rank + 0];
        if (n1 < N_dsp[rank + 0]) n1 = N_dsp[rank + 0];
        if (n0 > N_dsp[rank + 1]) n0 = N_dsp[rank + 1];
        if (n1 > N_dsp[rank + 1]) n1 = N_dsp[rank + 1];
        rcnt[i] = n1 - n0;
      }
    }
    sdsp[0] = 0;
    omp_par::scan(scnt.begin(), sdsp.begin(), np);
    rdsp[0] = 0;
    omp_par::scan(rcnt.begin(), rdsp.begin(), np);
    void* mpi_request = Ialltoallv_sparse(v.begin(), scnt.begin(), sdsp.begin(), v_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
  }
  v.Swap(v_);
}
template <class Type> void Comm::PartitionS(Vector<Type>& nodeList, const Type& splitter) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;

  Vector<Type> mins(npes);
  Allgather(Ptr2ConstItr<Type>(&splitter, 1), 1, mins.begin(), 1);

  Vector<Long> scnt(npes), sdsp(npes);
  Vector<Long> rcnt(npes), rdsp(npes);
  { // Compute scnt, sdsp
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sdsp[i] = std::lower_bound(nodeList.begin(), nodeList.begin() + nodeList.Dim(), mins[i]) - nodeList.begin();
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes - 1; i++) {
      scnt[i] = sdsp[i + 1] - sdsp[i];
    }
    scnt[npes - 1] = nodeList.Dim() - sdsp[npes - 1];
  }
  { // Compute rcnt, rdsp
    rdsp[0] = 0;
    Alltoall(scnt.begin(), 1, rcnt.begin(), 1);
    omp_par::scan(rcnt.begin(), rdsp.begin(), npes);
  }
  { // Redistribute nodeList
    Vector<Type> nodeList_(rdsp[npes - 1] + rcnt[npes - 1]);
    void* mpi_request = Ialltoallv_sparse(nodeList.begin(), scnt.begin(), sdsp.begin(), nodeList_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
    nodeList.Swap(nodeList_);
  }
}
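// Note: PartitionS() assumes nodeList is locally sorted; each rank passes as 'splitter'
// the smallest key it is to own, and after the exchange rank i receives the elements
// with keys in [splitter_i, splitter_{i+1}).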
template <class Type> void Comm::SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Type, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Vector<Pair_t> parray(key.Dim());
  { // Build global index.
    Long glb_dsp = 0;
    Long loc_size = key.Dim();
    Scan(Ptr2ConstItr<Long>(&loc_size, 1), Ptr2Itr<Long>(&glb_dsp, 1), 1, CommOp::SUM);
    glb_dsp -= loc_size;
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < loc_size; i++) {
      parray[i].key = key[i];
      parray[i].data = glb_dsp + i;
    }
  }

  Vector<Pair_t> psorted;
  HyperQuickSort(parray, psorted);

  if (npes > 1 && split_key_ != nullptr) { // Partition data
    Vector<Type> split_key(npes);
    Allgather(Ptr2ConstItr<Type>(split_key_, 1), 1, split_key.begin(), 1);

    Vector<Long> sendSz(npes);
    Vector<Long> recvSz(npes);
    Vector<Long> sendOff(npes);
    Vector<Long> recvOff(npes);
    Long nlSize = psorted.Dim();
    sendSz.SetZero();

    if (nlSize > 0) { // Compute sendSz
      // Determine processor range.
      Long pid1 = std::lower_bound(split_key.begin(), split_key.begin() + npes, psorted[0].key) - split_key.begin() - 1;
      Long pid2 = std::upper_bound(split_key.begin(), split_key.begin() + npes, psorted[nlSize - 1].key) - split_key.begin() + 1;
      pid1 = (pid1 < 0 ? 0 : pid1);
      pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
      for (Integer i = pid1; i < pid2; i++) {
        Pair_t p1;
        p1.key = split_key[i];
        Pair_t p2;
        p2.key = split_key[i + 1 < npes ? i + 1 : i];
        Long start = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p1, std::less<Pair_t>()) - psorted.begin();
        Long end = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p2, std::less<Pair_t>()) - psorted.begin();
        if (i == 0) start = 0;
        if (i == npes - 1) end = nlSize;
        sendSz[i] = end - start;
      }
    }

    // Exchange sendSz, recvSz
    Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

    // compute offsets ...
    { // Compute sendOff, recvOff
      sendOff[0] = 0;
      omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
      recvOff[0] = 0;
      omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
      assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
    }

    // perform All2All ...
    Vector<Pair_t> newNodes(recvSz[npes - 1] + recvOff[npes - 1]);
    void* mpi_req = Ialltoallv_sparse<Pair_t>(psorted.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
    Wait(mpi_req);

    // reset the pointer ...
    psorted.Swap(newNodes);
  }

  scatter_index.ReInit(psorted.Dim());
#pragma omp parallel for schedule(static)
  for (Long i = 0; i < psorted.Dim(); i++) {
    scatter_index[i] = psorted[i].data;
  }
}
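// Note: scatter_index[i] is the original global index of the i-th element in the
// (globally sorted, optionally repartitioned by split_key_) key sequence;
// ScatterForward() and ScatterReverse() below consume this index to move the
// associated data into and out of the sorted order.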
template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = scatter_index.Dim();
    StaticArray<Long, 2> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    data_dim = glb_size[0] / glb_size[1];
    SCTL_ASSERT(glb_size[0] == data_dim * glb_size[1]);
    send_size = data_.Dim() / data_dim;
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = scatter_index[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> glb_scan;
  { // Global scan of data size.
    glb_scan.ReInit(npes);
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&send_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= send_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted;
  { // Sort scatter_index.
    psorted.ReInit(recv_size);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + recv_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      recv_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i]) - recv_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i + 1]) - recv_indx.begin() : recv_size);
      recvSz[i] = end - start;
      recvOff[i] = start;
    }
    Alltoall(recvSz.begin(), 1, sendSz.begin(), 1);
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == send_size);
    Alltoallv(recv_indx.begin(), recvSz.begin(), recvOff.begin(), send_indx.begin(), sendSz.begin(), sendOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      assert(send_indx[i] >= glb_scan[rank]);
      send_indx[i] -= glb_scan[rank];
      assert(send_indx[i] < send_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = send_indx[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = psorted[i].data * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = loc_size_;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 3> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = scatter_index_.Dim();
    loc_size[2] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 3, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    SCTL_ASSERT(glb_size[0] % glb_size[1] == 0);
    data_dim = glb_size[0] / glb_size[1];
    SCTL_ASSERT(loc_size[0] % data_dim == 0);
    send_size = loc_size[0] / data_dim;
    if (glb_size[0] != glb_size[2] * data_dim) {
      recv_size = (((rank + 1) * (glb_size[0] / data_dim)) / npes) - ((rank * (glb_size[0] / data_dim)) / npes);
    }
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = scatter_index_[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> scatter_index;
  {
    StaticArray<Long, 2> glb_rank;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim() / data_dim;
    loc_size[1] = scatter_index_.Dim();
    Scan<Long>(loc_size, glb_rank, 2, CommOp::SUM);
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    SCTL_ASSERT(glb_size[0] == glb_size[1]);
    glb_rank[0] -= loc_size[0];
    glb_rank[1] -= loc_size[1];

    Vector<Long> glb_scan0(npes + 1);
    Vector<Long> glb_scan1(npes + 1);
    Allgather<Long>(glb_rank + 0, 1, glb_scan0.begin(), 1);
    Allgather<Long>(glb_rank + 1, 1, glb_scan1.begin(), 1);
    glb_scan0[npes] = glb_size[0];
    glb_scan1[npes] = glb_size[1];

    if (loc_size[0] != loc_size[1] || glb_rank[0] != glb_rank[1]) { // Repartition scatter_index
      scatter_index.ReInit(loc_size[0]);

      Vector<Long> send_dsp(npes + 1);
      Vector<Long> recv_dsp(npes + 1);
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= npes; i++) {
        send_dsp[i] = std::min(std::max(glb_scan0[i], glb_rank[1]), glb_rank[1] + loc_size[1]) - glb_rank[1];
        recv_dsp[i] = std::min(std::max(glb_scan1[i], glb_rank[0]), glb_rank[0] + loc_size[0]) - glb_rank[0];
      }

      // Long commCnt=0;
      Vector<Long> send_cnt(npes + 0);
      Vector<Long> recv_cnt(npes + 0);
#pragma omp parallel for schedule(static)  // reduction(+:commCnt)
      for (Integer i = 0; i < npes; i++) {
        send_cnt[i] = send_dsp[i + 1] - send_dsp[i];
        recv_cnt[i] = recv_dsp[i + 1] - recv_dsp[i];
        // if(send_cnt[i] && i!=rank) commCnt++;
        // if(recv_cnt[i] && i!=rank) commCnt++;
      }

      void* mpi_req = Ialltoallv_sparse<Long>(scatter_index_.begin(), send_cnt.begin(), send_dsp.begin(), scatter_index.begin(), recv_cnt.begin(), recv_dsp.begin(), 0);
      Wait(mpi_req);
    } else {
      scatter_index.ReInit(scatter_index_.Dim(), (Iterator<Long>)scatter_index_.begin(), false);
    }
  }

  Vector<Long> glb_scan(npes);
  { // Global data size.
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&recv_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= recv_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted(send_size);
  { // Sort scatter_index.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + send_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      send_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i]) - send_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i + 1]) - send_indx.begin() : send_size);
      sendSz[i] = end - start;
      sendOff[i] = start;
    }
    Alltoall(sendSz.begin(), 1, recvSz.begin(), 1);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(recvOff[npes - 1] + recvSz[npes - 1] == recv_size);
    Alltoallv(send_indx.begin(), sendSz.begin(), sendOff.begin(), recv_indx.begin(), recvSz.begin(), recvOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      assert(recv_indx[i] >= glb_scan[rank]);
      recv_indx[i] -= glb_scan[rank];
      assert(recv_indx[i] < recv_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = psorted[i].data * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = recv_indx[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
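// A minimal sketch of the usual sort-and-scatter workflow built from the routines
// above (names are illustrative; default arguments are assumed to be declared in the
// header):
//
//   Vector<Long> sidx;
//   comm.SortScatterIndex(keys, sidx);       // sidx maps sorted slots to original global indices
//   comm.ScatterForward(values, sidx);       // move per-key data into the sorted order
//   // ... process the data in sorted order ...
//   comm.ScatterReverse(values, sidx, n0);   // move results back; n0 = original local count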
#ifdef SCTL_HAVE_MPI
inline Vector<MPI_Request>* Comm::NewReq() const {
  if (req.empty()) req.push(new Vector<MPI_Request>);
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req.top();
  req.pop();
  return &request;
}

inline void Comm::Init(const MPI_Comm mpi_comm) {
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_dup(mpi_comm, &mpi_comm_);
  MPI_Comm_rank(mpi_comm_, &mpi_rank_);
  MPI_Comm_size(mpi_comm_, &mpi_size_);
}

inline void Comm::DelReq(Vector<MPI_Request>* req_ptr) const {
  if (req_ptr) req.push(req_ptr);
}

#define HS_MPIDATATYPE(CTYPE, MPITYPE)              \
  template <> class Comm::CommDatatype<CTYPE> {     \
   public:                                          \
    static MPI_Datatype value() { return MPITYPE; } \
    static MPI_Op sum() { return MPI_SUM; }         \
    static MPI_Op min() { return MPI_MIN; }         \
    static MPI_Op max() { return MPI_MAX; }         \
  }

HS_MPIDATATYPE(short, MPI_SHORT);
HS_MPIDATATYPE(int, MPI_INT);
HS_MPIDATATYPE(long, MPI_LONG);
HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT);
HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED);
HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG);
HS_MPIDATATYPE(float, MPI_FLOAT);
HS_MPIDATATYPE(double, MPI_DOUBLE);
HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE);
HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT);
HS_MPIDATATYPE(char, MPI_CHAR);
HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR);
#undef HS_MPIDATATYPE
#endif
template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const {  // O( ((N/p)+log(p))*(log(N/p)+log(p)) )
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer npes, myrank, omp_p;
  { // Get comm size and rank.
    npes = Size();
    myrank = Rank();
    omp_p = omp_get_max_threads();
  }
  srand(myrank);

  Long totSize, nelem = arr_.Dim();
  { // Local and global sizes. O(log p)
    Allreduce<Long>(Ptr2ConstItr<Long>(&nelem, 1), Ptr2Itr<Long>(&totSize, 1), 1, CommOp::SUM);
  }

  if (npes == 1) { // SortedElem <--- local_sort(arr_)
    SortedElem = arr_;
    omp_par::merge_sort(SortedElem.begin(), SortedElem.begin() + nelem);
    return;
  }

  Vector<Type> arr;
  { // arr <-- local_sort(arr_)
    arr = arr_;
    omp_par::merge_sort(arr.begin(), arr.begin() + nelem);
  }

  Vector<Type> nbuff, nbuff_ext, rbuff, rbuff_ext;  // Allocate memory.
  MPI_Comm comm = mpi_comm_;  // Copy comm
  bool free_comm = false;  // Flag to free comm.

  // Binary split and merge in each iteration.
  while (npes > 1 && totSize > 0) { // O(log p) iterations.
    Type split_key;
    Long totSize_new;
    { // Determine split_key. O( log(N/p) + log(p) )
      Integer glb_splt_count;
      Vector<Type> glb_splitters;
      { // Take random splitters. glb_splt_count = const = 100~1000
        Integer splt_count;
        { // Set splt_count. O( 1 ) -- Let p * splt_count = t
          splt_count = (100 * nelem) / totSize;
          if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
          if (splt_count > nelem) splt_count = nelem;
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          if (!glb_splt_count) splt_count = std::min<Long>(1, nelem);
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          SCTL_ASSERT(glb_splt_count);
        }
        Vector<Type> splitters(splt_count);
        for (Integer i = 0; i < splt_count; i++) {
          splitters[i] = arr[rand() % nelem];
        }
        Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
        { // Set glb_splt_cnts, glb_splt_disp
          MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
          glb_splt_disp[0] = 0;
          omp_par::scan(glb_splt_cnts.begin(), glb_splt_disp.begin(), npes);
          SCTL_ASSERT(glb_splt_count == glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1]);
        }
        { // Gather all splitters. O( log(p) )
          glb_splitters.ReInit(glb_splt_count);
          Vector<int> glb_splt_cnts_(npes), glb_splt_disp_(npes);
          for (Integer i = 0; i < npes; i++) {
            glb_splt_cnts_[i] = glb_splt_cnts[i];
            glb_splt_disp_[i] = glb_splt_disp[i];
          }
          MPI_Allgatherv((splt_count ? &splitters[0] : nullptr), splt_count, CommDatatype<Type>::value(), &glb_splitters[0], &glb_splt_cnts_[0], &glb_splt_disp_[0], CommDatatype<Type>::value(), comm);
        }
      }
      // Determine split key. O( log(N/p) + log(p) )
      Vector<Long> lrank(glb_splt_count);
      { // Compute local rank
#pragma omp parallel for schedule(static)
        for (Integer i = 0; i < glb_splt_count; i++) {
          lrank[i] = std::lower_bound(arr.begin(), arr.begin() + nelem, glb_splitters[i]) - arr.begin();
        }
      }
      Vector<Long> grank(glb_splt_count);
      { // Compute global rank
        MPI_Allreduce(&lrank[0], &grank[0], glb_splt_count, CommDatatype<Long>::value(), CommDatatype<Long>::sum(), comm);
      }
      { // Determine split_key, totSize_new
        ConstIterator<Long> split_disp = grank.begin();
        for (Integer i = 0; i < glb_splt_count; i++) {
          if (labs(grank[i] - totSize / 2) < labs(*split_disp - totSize / 2)) {
            split_disp = grank.begin() + i;
          }
        }
        split_key = glb_splitters[split_disp - grank.begin()];
        if (myrank <= (npes - 1) / 2)
          totSize_new = split_disp[0];
        else
          totSize_new = totSize - split_disp[0];
        // double err=(((double)*split_disp)/(totSize/2))-1.0;
        // if(fabs<double>(err)<0.01 || npes<=16) break;
        // else if(!myrank) std::cout<<err<<'\n';
      }
    }
    Integer split_id = (npes - 1) / 2;
    { // Split problem into two. O( N/p )
      Integer new_p0 = (myrank <= split_id ? 0 : split_id + 1);
      Integer cmp_p0 = (myrank > split_id ? 0 : split_id + 1);
      Integer partner;
      { // Set partner
        partner = myrank + cmp_p0 - new_p0;
        if (partner >= npes) partner = npes - 1;
        assert(partner >= 0);
      }
      bool extra_partner = (npes % 2 == 1 && npes - 1 == myrank);

      Long ssize = 0, lsize = 0;
      ConstIterator<Type> sbuff, lbuff;
      { // Set ssize, lsize, sbuff, lbuff
        Long split_indx = std::lower_bound(arr.begin(), arr.begin() + nelem, split_key) - arr.begin();
        ssize = (myrank > split_id ? split_indx : nelem - split_indx);
        sbuff = (myrank > split_id ? arr.begin() : arr.begin() + split_indx);
        lsize = (myrank <= split_id ? split_indx : nelem - split_indx);
        lbuff = (myrank <= split_id ? arr.begin() : arr.begin() + split_indx);
      }

      Long rsize = 0, ext_rsize = 0;
      { // Get rsize, ext_rsize
        Long ext_ssize = 0;
        MPI_Status status;
        MPI_Sendrecv(&ssize, 1, CommDatatype<Long>::value(), partner, 0, &rsize, 1, CommDatatype<Long>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(&ext_ssize, 1, CommDatatype<Long>::value(), split_id, 0, &ext_rsize, 1, CommDatatype<Long>::value(), split_id, 0, comm, &status);
      }

      { // Exchange data.
        rbuff.ReInit(rsize);
        rbuff_ext.ReInit(ext_rsize);
        MPI_Status status;
        MPI_Sendrecv((ssize ? &sbuff[0] : nullptr), ssize, CommDatatype<Type>::value(), partner, 0, (rsize ? &rbuff[0] : nullptr), rsize, CommDatatype<Type>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(nullptr, 0, CommDatatype<Type>::value(), split_id, 0, (ext_rsize ? &rbuff_ext[0] : nullptr), ext_rsize, CommDatatype<Type>::value(), split_id, 0, comm, &status);
      }

      Long nbuff_size = lsize + rsize + ext_rsize;
      { // nbuff <-- merge(lbuff, rbuff, rbuff_ext)
        nbuff.ReInit(lsize + rsize);
        omp_par::merge<ConstIterator<Type>>(lbuff, (lbuff + lsize), rbuff.begin(), rbuff.begin() + rsize, nbuff.begin(), omp_p, std::less<Type>());
        if (ext_rsize > 0 && nbuff.Dim() > 0) {
          nbuff_ext.ReInit(nbuff_size);
          omp_par::merge(nbuff.begin(), nbuff.begin() + (lsize + rsize), rbuff_ext.begin(), rbuff_ext.begin() + ext_rsize, nbuff_ext.begin(), omp_p, std::less<Type>());
          nbuff.Swap(nbuff_ext);
          nbuff_ext.ReInit(0);
        }
      }

      // Copy new data.
      totSize = totSize_new;
      nelem = nbuff_size;
      arr.Swap(nbuff);
      nbuff.ReInit(0);
    }

    { // Split comm. O( log(p) ) ??
      MPI_Comm scomm;
#pragma omp critical(SCTL_COMM_DUP)
      MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
#pragma omp critical(SCTL_COMM_DUP)
      if (free_comm) MPI_Comm_free(&comm);
      comm = scomm;
      free_comm = true;

      npes = (myrank <= split_id ? split_id + 1 : npes - split_id - 1);
      myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
    }
  }
#pragma omp critical(SCTL_COMM_DUP)
  if (free_comm) MPI_Comm_free(&comm);

  SortedElem = arr;
  PartitionW<Type>(SortedElem);
#else
  SortedElem = arr_;
  std::sort(SortedElem.begin(), SortedElem.begin() + SortedElem.Dim());
#endif
}
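// A minimal usage sketch of the distributed sort above (data is illustrative only):
//
//   Vector<double> in(/* local, unsorted values */), out;
//   comm.HyperQuickSort(in, out);  // 'out' is globally sorted across ranks and
//                                  // re-balanced at the end via PartitionW()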
}  // end namespace