// comm.txx

#include <pvfmm/ompUtils.hpp>
#include <pvfmm/vector.hpp>

namespace pvfmm {

Comm::Comm() {
#ifdef PVFMM_HAVE_MPI
  Init(MPI_COMM_SELF);
#endif
}

Comm::Comm(const Comm& c) {
#ifdef PVFMM_HAVE_MPI
  Init(c.mpi_comm_);
#endif
}

Comm Comm::Self() {
#ifdef PVFMM_HAVE_MPI
  Comm comm_self(MPI_COMM_SELF);
  return comm_self;
#else
  Comm comm_self;
  return comm_self;
#endif
}

Comm Comm::World() {
#ifdef PVFMM_HAVE_MPI
  Comm comm_world(MPI_COMM_WORLD);
  return comm_world;
#else
  Comm comm_self;
  return comm_self;
#endif
}

Comm& Comm::operator=(const Comm& c) {
#ifdef PVFMM_HAVE_MPI
  Init(c.mpi_comm_);
#endif
  return *this;
}

Comm::~Comm() {
#ifdef PVFMM_HAVE_MPI
  while (!req.empty()) {
    delete (Vector<MPI_Request>*)req.top();
    req.pop();
  }
  MPI_Comm_free(&mpi_comm_);
#endif
}

Comm Comm::Split(Integer clr) const {
#ifdef PVFMM_HAVE_MPI
  MPI_Comm new_comm;
  MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
  Comm c(new_comm);
  MPI_Comm_free(&new_comm);
  return c;
#else
  Comm c;
  return c;
#endif
}

Integer Comm::Rank() const {
#ifdef PVFMM_HAVE_MPI
  return mpi_rank_;
#else
  return 0;
#endif
}

Integer Comm::Size() const {
#ifdef PVFMM_HAVE_MPI
  return mpi_size_;
#else
  return 1;
#endif
}

void Comm::Barrier() const {
#ifdef PVFMM_HAVE_MPI
  MPI_Barrier(mpi_comm_);
#endif
}

template <class SType> void* Comm::Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag) const {
#ifdef PVFMM_HAVE_MPI
  if (!scount) return NULL;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  sbuf[0];           // touch the end-points of the send range
  sbuf[scount - 1];  // (exercises iterator bounds checks in debug builds)
#ifndef NDEBUG
  MPI_Issend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#else
  MPI_Isend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#endif
  return &request;
#else
  auto it = recv_req.find(tag);
  if (it == recv_req.end())
    send_req.insert(std::pair<Integer, ConstIterator<char>>(tag, (ConstIterator<char>)sbuf));
  else
    pvfmm::memcopy(it->second, (ConstIterator<char>)sbuf, scount * sizeof(SType));
  return NULL;
#endif
}

template <class RType> void* Comm::Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag) const {
#ifdef PVFMM_HAVE_MPI
  if (!rcount) return NULL;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  rbuf[0];
  rbuf[rcount - 1];
  MPI_Irecv(&rbuf[0], rcount, CommDatatype<RType>::value(), source, tag, mpi_comm_, &request[0]);
  return &request;
#else
  auto it = send_req.find(tag);
  if (it == send_req.end())
    recv_req.insert(std::pair<Integer, Iterator<char>>(tag, (Iterator<char>)rbuf));
  else
    pvfmm::memcopy((Iterator<char>)rbuf, it->second, rcount * sizeof(RType));
  return NULL;
#endif
}

void Comm::Wait(void* req_ptr) const {
#ifdef PVFMM_HAVE_MPI
  if (req_ptr == NULL) return;
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req_ptr;
  // std::vector<MPI_Status> status(request.Dim());
  if (request.Dim()) MPI_Waitall(request.Dim(), &request[0], MPI_STATUSES_IGNORE);  //&status[0]);
  DelReq(&request);
#endif
}

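// The three routines above form the non-blocking point-to-point API: Isend and
// Irecv return an opaque request handle that must later be passed to Wait. A
// minimal usage sketch (hypothetical buffer size and tag, two or more ranks):
//
//   Comm comm = Comm::World();
//   Long n = 100;
//   Vector<double> sbuf(n), rbuf(n);
//   Integer peer = (comm.Rank() + 1) % comm.Size(), tag = 0;
//   void* recv_handle = comm.Irecv(rbuf.Begin(), n, peer, tag);
//   void* send_handle = comm.Isend(sbuf.Begin(), n, peer, tag);
//   comm.Wait(recv_handle);
//   comm.Wait(send_handle);
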
template <class SType, class RType> void Comm::Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
#ifdef PVFMM_HAVE_MPI
  if (scount) {
    sbuf[0];
    sbuf[scount - 1];
  }
  if (rcount) {
    rbuf[0];
    rbuf[rcount * Size() - 1];
  }
  MPI_Allgather((scount ? &sbuf[0] : NULL), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : NULL), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
#ifdef PVFMM_HAVE_MPI
  Vector<int> rcounts_(mpi_size_), rdispls_(mpi_size_);
  Long rcount_sum = 0;
#pragma omp parallel for schedule(static) reduction(+ : rcount_sum)
  for (Integer i = 0; i < mpi_size_; i++) {
    rcounts_[i] = rcounts[i];
    rdispls_[i] = rdispls[i];
    rcount_sum += rcounts[i];
  }
  if (scount) {
    sbuf[0];
    sbuf[scount - 1];
  }
  if (rcount_sum) {
    rbuf[0];
    rbuf[rcount_sum - 1];
  }
  MPI_Allgatherv((scount ? &sbuf[0] : NULL), scount, CommDatatype<SType>::value(), (rcount_sum ? &rbuf[0] : NULL), &rcounts_.Begin()[0], &rdispls_.Begin()[0], CommDatatype<RType>::value(), mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
#ifdef PVFMM_HAVE_MPI
  if (scount) {
    sbuf[0];
    sbuf[scount * Size() - 1];
  }
  if (rcount) {
    rbuf[0];
    rbuf[rcount * Size() - 1];
  }
  MPI_Alltoall((scount ? &sbuf[0] : NULL), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : NULL), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void* Comm::Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag) const {
#ifdef PVFMM_HAVE_MPI
  Integer request_count = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) request_count++;
    if (scounts[i]) request_count++;
  }
  if (!request_count) return NULL;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(request_count);
  Integer request_iter = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) {
      rbuf[rdispls[i]];
      rbuf[rdispls[i] + rcounts[i] - 1];
      MPI_Irecv(&rbuf[rdispls[i]], rcounts[i], CommDatatype<RType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  for (Integer i = 0; i < mpi_size_; i++) {
    if (scounts[i]) {
      sbuf[sdispls[i]];
      sbuf[sdispls[i] + scounts[i] - 1];
      MPI_Isend(&sbuf[sdispls[i]], scounts[i], CommDatatype<SType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  return &request;
#else
  pvfmm::memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(SType));
  return NULL;
#endif
}

template <class Type> void Comm::Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
#ifdef PVFMM_HAVE_MPI
  { // Use Ialltoallv_sparse if the average connectivity is < 64
    Long connectivity = 0, glb_connectivity = 0;
#pragma omp parallel for schedule(static) reduction(+ : connectivity)
    for (Integer i = 0; i < mpi_size_; i++) {
      if (rcounts[i]) connectivity++;
    }
    Allreduce(PVFMM_PTR2CONSTITR(Long, &connectivity, 1), PVFMM_PTR2ITR(Long, &glb_connectivity, 1), 1, CommOp::SUM);
    if (glb_connectivity < 64 * Size()) {
      void* mpi_req = Ialltoallv_sparse(sbuf, scounts, sdispls, rbuf, rcounts, rdispls, 0);
      Wait(mpi_req);
      return;
    }
  }
  { // Use vendor MPI_Alltoallv
    //#ifndef ALLTOALLV_FIX
    Vector<int> scnt, sdsp, rcnt, rdsp;
    scnt.ReInit(mpi_size_);
    sdsp.ReInit(mpi_size_);
    rcnt.ReInit(mpi_size_);
    rdsp.ReInit(mpi_size_);
    Long stotal = 0, rtotal = 0;
#pragma omp parallel for schedule(static) reduction(+ : stotal, rtotal)
    for (Integer i = 0; i < mpi_size_; i++) {
      scnt[i] = scounts[i];
      sdsp[i] = sdispls[i];
      rcnt[i] = rcounts[i];
      rdsp[i] = rdispls[i];
      stotal += scounts[i];
      rtotal += rcounts[i];
    }
    MPI_Alltoallv((stotal ? &sbuf[0] : NULL), &scnt[0], &sdsp[0], CommDatatype<Type>::value(), (rtotal ? &rbuf[0] : NULL), &rcnt[0], &rdsp[0], CommDatatype<Type>::value(), mpi_comm_);
    return;
    //#endif
  }
  // TODO: implement hypercube scheme
#else
  pvfmm::memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(Type));
#endif
}

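// Alltoallv above switches strategy based on how dense the communication
// pattern is: if, on average, each rank exchanges data with fewer than 64
// peers (glb_connectivity < 64 * Size()), the exchange is done with
// point-to-point messages via Ialltoallv_sparse; otherwise the vendor
// MPI_Alltoallv is used. For example, on 1024 ranks where each rank talks to
// roughly 30 neighbors, glb_connectivity is about 30 * 1024 < 64 * 1024, so
// the sparse path is taken.
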
template <class Type> void Comm::Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const {
#ifdef PVFMM_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  sbuf[0];
  sbuf[count - 1];
  rbuf[0];
  rbuf[count - 1];
  MPI_Allreduce(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

template <class Type> void Comm::Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const {
#ifdef PVFMM_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  sbuf[0];
  sbuf[count - 1];
  rbuf[0];
  rbuf[count - 1];
  MPI_Scan(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

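// Allreduce and Scan are used throughout the partitioning routines below to
// turn a per-process count into a global total and an exclusive prefix offset.
// A minimal sketch of that recurring pattern (hypothetical variable names):
//
//   Long loc_cnt = v.Dim(), glb_cnt = 0, scan_cnt = 0;
//   comm.Allreduce(PVFMM_PTR2CONSTITR(Long, &loc_cnt, 1), PVFMM_PTR2ITR(Long, &glb_cnt, 1), 1, CommOp::SUM);
//   comm.Scan(PVFMM_PTR2CONSTITR(Long, &loc_cnt, 1), PVFMM_PTR2ITR(Long, &scan_cnt, 1), 1, CommOp::SUM);
//   Long offset = scan_cnt - loc_cnt;  // global index of the first element owned by this rank
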
template <class Type> void Comm::PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_) const {
  Integer npes = Size();
  if (npes == 1) return;
  Long nlSize = nodeList.Dim();

  Vector<Long> wts;
  Long localWt = 0;
  if (wts_ == NULL) { // Construct arrays of wts.
    wts.ReInit(nlSize);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < nlSize; i++) {
      wts[i] = 1;
    }
    localWt = nlSize;
  } else {
    wts.ReInit(nlSize, (Iterator<Long>)wts_->Begin(), false);
#pragma omp parallel for reduction(+ : localWt)
    for (Long i = 0; i < nlSize; i++) {
      localWt += wts[i];
    }
  }

  Long off1 = 0, off2 = 0, totalWt = 0;
  { // compute the total weight of the problem ...
    Allreduce<Long>(PVFMM_PTR2CONSTITR(Long, &localWt, 1), PVFMM_PTR2ITR(Long, &totalWt, 1), 1, CommOp::SUM);
    Scan<Long>(PVFMM_PTR2CONSTITR(Long, &localWt, 1), PVFMM_PTR2ITR(Long, &off2, 1), 1, CommOp::SUM);
    off1 = off2 - localWt;
  }

  Vector<Long> lscn;
  if (nlSize) { // perform a local scan on the weights first ...
    lscn.ReInit(nlSize);
    lscn[0] = off1;
    omp_par::scan(wts.Begin(), lscn.Begin(), nlSize);
  }

  Vector<Long> sendSz, recvSz, sendOff, recvOff;
  sendSz.ReInit(npes);
  recvSz.ReInit(npes);
  sendOff.ReInit(npes);
  recvOff.ReInit(npes);
  sendSz.SetZero();

  if (nlSize > 0 && totalWt > 0) { // Compute sendSz
    Long pid1 = (off1 * npes) / totalWt;
    Long pid2 = ((off2 + 1) * npes) / totalWt + 1;
    assert((totalWt * pid2) / npes >= off2);
    pid1 = (pid1 < 0 ? 0 : pid1);
    pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
    for (Integer i = pid1; i < pid2; i++) {
      Long wt1 = (totalWt * (i)) / npes;
      Long wt2 = (totalWt * (i + 1)) / npes;
      Long start = std::lower_bound(lscn.Begin(), lscn.Begin() + nlSize, wt1, std::less<Long>()) - lscn.Begin();
      Long end = std::lower_bound(lscn.Begin(), lscn.Begin() + nlSize, wt2, std::less<Long>()) - lscn.Begin();
      if (i == 0) start = 0;
      if (i == npes - 1) end = nlSize;
      sendSz[i] = end - start;
    }
  } else {
    sendSz[0] = nlSize;
  }

  // Exchange sendSz, recvSz
  Alltoall<Long>(sendSz.Begin(), 1, recvSz.Begin(), 1);
  { // Compute sendOff, recvOff
    sendOff[0] = 0;
    omp_par::scan(sendSz.Begin(), sendOff.Begin(), npes);
    recvOff[0] = 0;
    omp_par::scan(recvSz.Begin(), recvOff.Begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
  }

  // perform All2All ...
  Vector<Type> newNodes;
  newNodes.ReInit(recvSz[npes - 1] + recvOff[npes - 1]);
  void* mpi_req = Ialltoallv_sparse<Type>(nodeList.Begin(), sendSz.Begin(), sendOff.Begin(), newNodes.Begin(), recvSz.Begin(), recvOff.Begin());
  Wait(mpi_req);

  // reset the pointer ...
  nodeList.Swap(newNodes);
}

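// PartitionW redistributes nodeList so that each rank ends up with roughly an
// equal share of the total weight, while preserving the global ordering of
// elements across ranks. As a hypothetical example with 2 ranks, rank 0
// holding 6 unit-weight elements and rank 1 holding 2: totalWt = 8, so after
// the call each rank holds 4 elements. A minimal sketch (SomeType is a
// stand-in for any element type usable with Ialltoallv_sparse):
//
//   Vector<SomeType> nodes = ...;       // local elements
//   comm.PartitionW(nodes, NULL);       // NULL weights => every element has weight 1
//   comm.PartitionW(nodes, &weights);   // or supply a Vector<Long> of per-element weights
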
template <class Type> void Comm::PartitionN(Vector<Type>& v, Long N) const {
  Integer rank = Rank();
  Integer np = Size();
  if (np == 1) return;

  Vector<Long> v_cnt(np), v_dsp(np + 1);
  Vector<Long> N_cnt(np), N_dsp(np + 1);
  { // Set v_cnt, v_dsp
    v_dsp[0] = 0;
    Long cnt = v.Dim();
    Allgather(PVFMM_PTR2CONSTITR(Long, &cnt, 1), 1, v_cnt.Begin(), 1);
    omp_par::scan(v_cnt.Begin(), v_dsp.Begin(), np);
    v_dsp[np] = v_cnt[np - 1] + v_dsp[np - 1];
  }
  { // Set N_cnt, N_dsp
    N_dsp[0] = 0;
    Long cnt = N;
    Allgather(PVFMM_PTR2CONSTITR(Long, &cnt, 1), 1, N_cnt.Begin(), 1);
    omp_par::scan(N_cnt.Begin(), N_dsp.Begin(), np);
    N_dsp[np] = N_cnt[np - 1] + N_dsp[np - 1];
  }
  { // Adjust for dof
    Long dof = v_dsp[np] / N_dsp[np];
    assert(dof * N_dsp[np] == v_dsp[np]);
    if (dof == 0) return;
    if (dof != 1) {
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i < np; i++) N_cnt[i] *= dof;
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= np; i++) N_dsp[i] *= dof;
    }
  }

  Vector<Type> v_(N_cnt[rank]);
  { // Set v_
    Vector<Long> scnt(np), sdsp(np);
    Vector<Long> rcnt(np), rdsp(np);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < np; i++) {
      { // Set scnt
        Long n0 = N_dsp[i + 0];
        Long n1 = N_dsp[i + 1];
        if (n0 < v_dsp[rank + 0]) n0 = v_dsp[rank + 0];
        if (n1 < v_dsp[rank + 0]) n1 = v_dsp[rank + 0];
        if (n0 > v_dsp[rank + 1]) n0 = v_dsp[rank + 1];
        if (n1 > v_dsp[rank + 1]) n1 = v_dsp[rank + 1];
        scnt[i] = n1 - n0;
      }
      { // Set rcnt
        Long n0 = v_dsp[i + 0];
        Long n1 = v_dsp[i + 1];
        if (n0 < N_dsp[rank + 0]) n0 = N_dsp[rank + 0];
        if (n1 < N_dsp[rank + 0]) n1 = N_dsp[rank + 0];
        if (n0 > N_dsp[rank + 1]) n0 = N_dsp[rank + 1];
        if (n1 > N_dsp[rank + 1]) n1 = N_dsp[rank + 1];
        rcnt[i] = n1 - n0;
      }
    }
    sdsp[0] = 0;
    omp_par::scan(scnt.Begin(), sdsp.Begin(), np);
    rdsp[0] = 0;
    omp_par::scan(rcnt.Begin(), rdsp.Begin(), np);
    void* mpi_request = Ialltoallv_sparse(v.Begin(), scnt.Begin(), sdsp.Begin(), v_.Begin(), rcnt.Begin(), rdsp.Begin());
    Wait(mpi_request);
  }
  v.Swap(v_);
}

template <class Type> void Comm::PartitionS(Vector<Type>& nodeList, const Type& splitter) const {
  Integer npes = Size();
  if (npes == 1) return;

  Vector<Type> mins(npes);
  Allgather(PVFMM_PTR2CONSTITR(Type, &splitter, 1), 1, mins.Begin(), 1);

  Vector<Long> scnt(npes), sdsp(npes);
  Vector<Long> rcnt(npes), rdsp(npes);
  { // Compute scnt, sdsp
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sdsp[i] = std::lower_bound(nodeList.Begin(), nodeList.Begin() + nodeList.Dim(), mins[i]) - nodeList.Begin();
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes - 1; i++) {
      scnt[i] = sdsp[i + 1] - sdsp[i];
    }
    scnt[npes - 1] = nodeList.Dim() - sdsp[npes - 1];
  }
  { // Compute rcnt, rdsp
    rdsp[0] = 0;
    Alltoall(scnt.Begin(), 1, rcnt.Begin(), 1);
    omp_par::scan(rcnt.Begin(), rdsp.Begin(), npes);
  }
  { // Redistribute nodeList
    Vector<Type> nodeList_(rdsp[npes - 1] + rcnt[npes - 1]);
    void* mpi_request = Ialltoallv_sparse(nodeList.Begin(), scnt.Begin(), sdsp.Begin(), nodeList_.Begin(), rcnt.Begin(), rdsp.Begin());
    Wait(mpi_request);
    nodeList.Swap(nodeList_);
  }
}

template <class Type> void Comm::SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_) const {
  typedef SortPair<Type, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Vector<Pair_t> parray(key.Dim());
  { // Build global index.
    Long glb_dsp = 0;
    Long loc_size = key.Dim();
    Scan(PVFMM_PTR2CONSTITR(Long, &loc_size, 1), PVFMM_PTR2ITR(Long, &glb_dsp, 1), 1, CommOp::SUM);
    glb_dsp -= loc_size;
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < loc_size; i++) {
      parray[i].key = key[i];
      parray[i].data = glb_dsp + i;
    }
  }

  Vector<Pair_t> psorted;
  HyperQuickSort(parray, psorted);

  if (npes > 1 && split_key_ != NULL) { // Partition data
    Vector<Type> split_key(npes);
    Allgather(PVFMM_PTR2CONSTITR(Type, split_key_, 1), 1, split_key.Begin(), 1);

    Vector<Long> sendSz(npes);
    Vector<Long> recvSz(npes);
    Vector<Long> sendOff(npes);
    Vector<Long> recvOff(npes);
    Long nlSize = psorted.Dim();
    sendSz.SetZero();

    if (nlSize > 0) { // Compute sendSz
      // Determine processor range.
      Long pid1 = std::lower_bound(split_key.Begin(), split_key.Begin() + npes, psorted[0].key) - split_key.Begin() - 1;
      Long pid2 = std::upper_bound(split_key.Begin(), split_key.Begin() + npes, psorted[nlSize - 1].key) - split_key.Begin() + 1;
      pid1 = (pid1 < 0 ? 0 : pid1);
      pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
      for (Integer i = pid1; i < pid2; i++) {
        Pair_t p1;
        p1.key = split_key[i];
        Pair_t p2;
        p2.key = split_key[i + 1 < npes ? i + 1 : i];
        Long start = std::lower_bound(psorted.Begin(), psorted.Begin() + nlSize, p1, std::less<Pair_t>()) - psorted.Begin();
        Long end = std::lower_bound(psorted.Begin(), psorted.Begin() + nlSize, p2, std::less<Pair_t>()) - psorted.Begin();
        if (i == 0) start = 0;
        if (i == npes - 1) end = nlSize;
        sendSz[i] = end - start;
      }
    }

    // Exchange sendSz, recvSz
    Alltoall<Long>(sendSz.Begin(), 1, recvSz.Begin(), 1);

    // compute offsets ...
    { // Compute sendOff, recvOff
      sendOff[0] = 0;
      omp_par::scan(sendSz.Begin(), sendOff.Begin(), npes);
      recvOff[0] = 0;
      omp_par::scan(recvSz.Begin(), recvOff.Begin(), npes);
      assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
    }

    // perform All2All ...
    Vector<Pair_t> newNodes(recvSz[npes - 1] + recvOff[npes - 1]);
    void* mpi_req = Ialltoallv_sparse<Pair_t>(psorted.Begin(), sendSz.Begin(), sendOff.Begin(), newNodes.Begin(), recvSz.Begin(), recvOff.Begin());
    Wait(mpi_req);

    // reset the pointer ...
    psorted.Swap(newNodes);
  }

  scatter_index.ReInit(psorted.Dim());
#pragma omp parallel for schedule(static)
  for (Long i = 0; i < psorted.Dim(); i++) {
    scatter_index[i] = psorted[i].data;
  }
}

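// SortScatterIndex and ScatterForward/ScatterReverse (defined below) are meant
// to be used as a pair: the scatter index computed from a key vector can then
// reorder any data vector (with the same element count, or a multiple of it
// for vector-valued data) into globally sorted key order, and ScatterReverse
// undoes that permutation. A minimal sketch (hypothetical key/data vectors):
//
//   Vector<Long> keys = ...;     // one key per local element
//   Vector<double> data = ...;   // payload aligned with keys
//   Vector<Long> scatter_index;
//   comm.SortScatterIndex(keys, scatter_index, NULL);
//   comm.ScatterForward(data, scatter_index);              // data now in sorted key order
//   comm.ScatterReverse(data, scatter_index, keys.Dim());  // back to the original layout
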
template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const {
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = scatter_index.Dim();
    StaticArray<Long, 2> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = recv_size;
    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    data_dim = glb_size[0] / glb_size[1];
    PVFMM_ASSERT(glb_size[0] == data_dim * glb_size[1]);
    send_size = data_.Dim() / data_dim;
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = scatter_index[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> glb_scan;
  { // Global scan of data size.
    glb_scan.ReInit(npes);
    Long glb_rank = 0;
    Scan(PVFMM_PTR2CONSTITR(Long, &send_size, 1), PVFMM_PTR2ITR(Long, &glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= send_size;
    Allgather(PVFMM_PTR2CONSTITR(Long, &glb_rank, 1), 1, glb_scan.Begin(), 1);
  }

  Vector<Pair_t> psorted;
  { // Sort scatter_index.
    psorted.ReInit(recv_size);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.Begin(), psorted.Begin() + recv_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      recv_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(recv_indx.Begin(), recv_indx.Begin() + recv_size, glb_scan[i]) - recv_indx.Begin();
      Long end = (i + 1 < npes ? std::lower_bound(recv_indx.Begin(), recv_indx.Begin() + recv_size, glb_scan[i + 1]) - recv_indx.Begin() : recv_size);
      recvSz[i] = end - start;
      recvOff[i] = start;
    }
    Alltoall(recvSz.Begin(), 1, sendSz.Begin(), 1);
    sendOff[0] = 0;
    omp_par::scan(sendSz.Begin(), sendOff.Begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == send_size);
    Alltoallv(recv_indx.Begin(), recvSz.Begin(), recvOff.Begin(), send_indx.Begin(), sendSz.Begin(), sendOff.Begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      assert(send_indx[i] >= glb_scan[rank]);
      send_indx[i] -= glb_scan[rank];
      assert(send_indx[i] < send_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = send_indx[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.Begin(), sendSz.Begin(), sendOff.Begin(), recv_buff.Begin(), recvSz.Begin(), recvOff.Begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = psorted[i].data * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}

template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_) const {
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = loc_size_;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 3> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = scatter_index_.Dim();
    loc_size[2] = recv_size;
    Allreduce(loc_size, glb_size, 3, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    PVFMM_ASSERT(glb_size[0] % glb_size[1] == 0);
    data_dim = glb_size[0] / glb_size[1];
    PVFMM_ASSERT(loc_size[0] % data_dim == 0);
    send_size = loc_size[0] / data_dim;
    if (glb_size[0] != glb_size[2] * data_dim) {
      recv_size = (((rank + 1) * (glb_size[0] / data_dim)) / npes) - ((rank * (glb_size[0] / data_dim)) / npes);
    }
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = scatter_index_[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> scatter_index;
  {
    StaticArray<Long, 2> glb_rank;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim() / data_dim;
    loc_size[1] = scatter_index_.Dim();
    Scan(loc_size, glb_rank, 2, CommOp::SUM);
    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
    PVFMM_ASSERT(glb_size[0] == glb_size[1]);
    glb_rank[0] -= loc_size[0];
    glb_rank[1] -= loc_size[1];

    Vector<Long> glb_scan0(npes + 1);
    Vector<Long> glb_scan1(npes + 1);
    Allgather(glb_rank + 0, 1, glb_scan0.Begin(), 1);
    Allgather(glb_rank + 1, 1, glb_scan1.Begin(), 1);
    glb_scan0[npes] = glb_size[0];
    glb_scan1[npes] = glb_size[1];

    if (loc_size[0] != loc_size[1] || glb_rank[0] != glb_rank[1]) { // Repartition scatter_index
      scatter_index.ReInit(loc_size[0]);

      Vector<Long> send_dsp(npes + 1);
      Vector<Long> recv_dsp(npes + 1);
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= npes; i++) {
        send_dsp[i] = std::min(std::max(glb_scan0[i], glb_rank[1]), glb_rank[1] + loc_size[1]) - glb_rank[1];
        recv_dsp[i] = std::min(std::max(glb_scan1[i], glb_rank[0]), glb_rank[0] + loc_size[0]) - glb_rank[0];
      }

      // Long commCnt=0;
      Vector<Long> send_cnt(npes + 0);
      Vector<Long> recv_cnt(npes + 0);
#pragma omp parallel for schedule(static) // reduction(+:commCnt)
      for (Integer i = 0; i < npes; i++) {
        send_cnt[i] = send_dsp[i + 1] - send_dsp[i];
        recv_cnt[i] = recv_dsp[i + 1] - recv_dsp[i];
        // if(send_cnt[i] && i!=rank) commCnt++;
        // if(recv_cnt[i] && i!=rank) commCnt++;
      }

      void* mpi_req = Ialltoallv_sparse<Long>(scatter_index_.Begin(), send_cnt.Begin(), send_dsp.Begin(), scatter_index.Begin(), recv_cnt.Begin(), recv_dsp.Begin(), 0);
      Wait(mpi_req);
    } else {
      scatter_index.ReInit(scatter_index_.Dim(), (Iterator<Long>)scatter_index_.Begin(), false);
    }
  }

  Vector<Long> glb_scan(npes);
  { // Global data size.
    Long glb_rank = 0;
    Scan(PVFMM_PTR2CONSTITR(Long, &recv_size, 1), PVFMM_PTR2ITR(Long, &glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= recv_size;
    Allgather(PVFMM_PTR2CONSTITR(Long, &glb_rank, 1), 1, glb_scan.Begin(), 1);
  }

  Vector<Pair_t> psorted(send_size);
  { // Sort scatter_index.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.Begin(), psorted.Begin() + send_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      send_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(send_indx.Begin(), send_indx.Begin() + send_size, glb_scan[i]) - send_indx.Begin();
      Long end = (i + 1 < npes ? std::lower_bound(send_indx.Begin(), send_indx.Begin() + send_size, glb_scan[i + 1]) - send_indx.Begin() : send_size);
      sendSz[i] = end - start;
      sendOff[i] = start;
    }
    Alltoall(sendSz.Begin(), 1, recvSz.Begin(), 1);
    recvOff[0] = 0;
    omp_par::scan(recvSz.Begin(), recvOff.Begin(), npes);
    assert(recvOff[npes - 1] + recvSz[npes - 1] == recv_size);
    Alltoallv(send_indx.Begin(), sendSz.Begin(), sendOff.Begin(), recv_indx.Begin(), recvSz.Begin(), recvOff.Begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      assert(recv_indx[i] >= glb_scan[rank]);
      recv_indx[i] -= glb_scan[rank];
      assert(recv_indx[i] < recv_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = psorted[i].data * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.Begin(), sendSz.Begin(), sendOff.Begin(), recv_buff.Begin(), recvSz.Begin(), recvOff.Begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = recv_indx[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}

#ifdef PVFMM_HAVE_MPI
Vector<MPI_Request>* Comm::NewReq() const {
  if (req.empty()) req.push(new Vector<MPI_Request>);
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req.top();
  req.pop();
  return &request;
}

void Comm::DelReq(Vector<MPI_Request>* req_ptr) const {
  if (req_ptr) req.push(req_ptr);
}

#define HS_MPIDATATYPE(CTYPE, MPITYPE)              \
  template <> class Comm::CommDatatype<CTYPE> {     \
   public:                                          \
    static MPI_Datatype value() { return MPITYPE; } \
    static MPI_Op sum() { return MPI_SUM; }         \
    static MPI_Op min() { return MPI_MIN; }         \
    static MPI_Op max() { return MPI_MAX; }         \
  }

HS_MPIDATATYPE(short, MPI_SHORT);
HS_MPIDATATYPE(int, MPI_INT);
HS_MPIDATATYPE(long, MPI_LONG);
HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT);
HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED);
HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG);
HS_MPIDATATYPE(float, MPI_FLOAT);
HS_MPIDATATYPE(double, MPI_DOUBLE);
HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE);
HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT);
HS_MPIDATATYPE(char, MPI_CHAR);
HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR);
#undef HS_MPIDATATYPE
#endif

template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const { // O( ((N/p)+log(p))*(log(N/p)+log(p)) )
#ifdef PVFMM_HAVE_MPI
  Integer npes, myrank, omp_p;
  { // Get comm size and rank.
    npes = Size();
    myrank = Rank();
    omp_p = omp_get_max_threads();
  }
  srand(myrank);

  Long totSize, nelem = arr_.Dim();
  { // Local and global sizes. O(log p)
    assert(nelem); // TODO: Check if this is needed.
    Allreduce<Long>(PVFMM_PTR2CONSTITR(Long, &nelem, 1), PVFMM_PTR2ITR(Long, &totSize, 1), 1, CommOp::SUM);
  }

  if (npes == 1) { // SortedElem <--- local_sort(arr_)
    SortedElem = arr_;
    omp_par::merge_sort(SortedElem.Begin(), SortedElem.Begin() + nelem);
    return;
  }

  Vector<Type> arr;
  { // arr <-- local_sort(arr_)
    arr = arr_;
    omp_par::merge_sort(arr.Begin(), arr.Begin() + nelem);
  }

  Vector<Type> nbuff, nbuff_ext, rbuff, rbuff_ext;  // Allocate memory.
  MPI_Comm comm = mpi_comm_;                        // Copy comm
  bool free_comm = false;                           // Flag to free comm.

  // Binary split and merge in each iteration.
  while (npes > 1 && totSize > 0) { // O(log p) iterations.
    Type split_key;
    Long totSize_new;
    { // Determine split_key. O( log(N/p) + log(p) )
      Integer glb_splt_count;
      Vector<Type> glb_splitters;
      { // Take random splitters. glb_splt_count = const = 100~1000
        Integer splt_count;
        { // Set splt_count. O( 1 ) -- Let p * splt_count = t
          splt_count = (100 * nelem) / totSize;
          if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
          if (splt_count > nelem) splt_count = nelem;
        }
        Vector<Type> splitters(splt_count);
        for (Integer i = 0; i < splt_count; i++) {
          splitters[i] = arr[rand() % nelem];
        }
        Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
        { // Set glb_splt_count, glb_splt_cnts, glb_splt_disp
          MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
          glb_splt_disp[0] = 0;
          omp_par::scan(glb_splt_cnts.Begin(), glb_splt_disp.Begin(), npes);
          glb_splt_count = glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1];
          PVFMM_ASSERT(glb_splt_count);
        }
        { // Gather all splitters. O( log(p) )
          glb_splitters.ReInit(glb_splt_count);
          Vector<int> glb_splt_cnts_(npes), glb_splt_disp_(npes);
          for (Integer i = 0; i < npes; i++) {
            glb_splt_cnts_[i] = glb_splt_cnts[i];
            glb_splt_disp_[i] = glb_splt_disp[i];
          }
          MPI_Allgatherv((splt_count ? &splitters[0] : NULL), splt_count, CommDatatype<Type>::value(), &glb_splitters[0], &glb_splt_cnts_[0], &glb_splt_disp_[0], CommDatatype<Type>::value(), comm);
        }
      }

      // Determine split key. O( log(N/p) + log(p) )
      Vector<Long> lrank(glb_splt_count);
      { // Compute local rank
#pragma omp parallel for schedule(static)
        for (Integer i = 0; i < glb_splt_count; i++) {
          lrank[i] = std::lower_bound(arr.Begin(), arr.Begin() + nelem, glb_splitters[i]) - arr.Begin();
        }
      }
      Vector<Long> grank(glb_splt_count);
      { // Compute global rank
        MPI_Allreduce(&lrank[0], &grank[0], glb_splt_count, CommDatatype<Long>::value(), CommDatatype<Long>::sum(), comm);
      }
      { // Determine split_key, totSize_new
        ConstIterator<Long> split_disp = grank.Begin();
        for (Integer i = 0; i < glb_splt_count; i++) {
          if (labs(grank[i] - totSize / 2) < labs(*split_disp - totSize / 2)) {
            split_disp = grank.Begin() + i;
          }
        }
        split_key = glb_splitters[split_disp - grank.Begin()];
        if (myrank <= (npes - 1) / 2)
          totSize_new = split_disp[0];
        else
          totSize_new = totSize - split_disp[0];
        // double err=(((double)*split_disp)/(totSize/2))-1.0;
        // if(pvfmm::fabs<double>(err)<0.01 || npes<=16) break;
        // else if(!myrank) std::cout<<err<<'\n';
      }
    }

    Integer split_id = (npes - 1) / 2;
    { // Split problem into two. O( N/p )
      Integer new_p0 = (myrank <= split_id ? 0 : split_id + 1);
      Integer cmp_p0 = (myrank > split_id ? 0 : split_id + 1);
      Integer partner;
      { // Set partner
        partner = myrank + cmp_p0 - new_p0;
        if (partner >= npes) partner = npes - 1;
        assert(partner >= 0);
      }
      bool extra_partner = (npes % 2 == 1 && npes - 1 == myrank);

      Long ssize = 0, lsize = 0;
      ConstIterator<Type> sbuff, lbuff;
      { // Set ssize, lsize, sbuff, lbuff
        Long split_indx = std::lower_bound(arr.Begin(), arr.Begin() + nelem, split_key) - arr.Begin();
        ssize = (myrank > split_id ? split_indx : nelem - split_indx);
        sbuff = (myrank > split_id ? arr.Begin() : arr.Begin() + split_indx);
        lsize = (myrank <= split_id ? split_indx : nelem - split_indx);
        lbuff = (myrank <= split_id ? arr.Begin() : arr.Begin() + split_indx);
      }

      Long rsize = 0, ext_rsize = 0;
      { // Get rsize, ext_rsize
        Long ext_ssize = 0;
        MPI_Status status;
        MPI_Sendrecv(&ssize, 1, CommDatatype<Long>::value(), partner, 0, &rsize, 1, CommDatatype<Long>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(&ext_ssize, 1, CommDatatype<Long>::value(), split_id, 0, &ext_rsize, 1, CommDatatype<Long>::value(), split_id, 0, comm, &status);
      }
      { // Exchange data.
        rbuff.ReInit(rsize);
        rbuff_ext.ReInit(ext_rsize);
        MPI_Status status;
        MPI_Sendrecv((ssize ? &sbuff[0] : NULL), ssize, CommDatatype<Type>::value(), partner, 0, (rsize ? &rbuff[0] : NULL), rsize, CommDatatype<Type>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(NULL, 0, CommDatatype<Type>::value(), split_id, 0, (ext_rsize ? &rbuff_ext[0] : NULL), ext_rsize, CommDatatype<Type>::value(), split_id, 0, comm, &status);
      }

      Long nbuff_size = lsize + rsize + ext_rsize;
      { // nbuff <-- merge(lbuff, rbuff, rbuff_ext)
        nbuff.ReInit(lsize + rsize);
        omp_par::merge<ConstIterator<Type>>(lbuff, (lbuff + lsize), rbuff.Begin(), rbuff.Begin() + rsize, nbuff.Begin(), omp_p, std::less<Type>());
        if (ext_rsize > 0 && nbuff.Dim() > 0) {
          nbuff_ext.ReInit(nbuff_size);
          omp_par::merge(nbuff.Begin(), nbuff.Begin() + (lsize + rsize), rbuff_ext.Begin(), rbuff_ext.Begin() + ext_rsize, nbuff_ext.Begin(), omp_p, std::less<Type>());
          nbuff.Swap(nbuff_ext);
          nbuff_ext.ReInit(0);
        }
      }

      // Copy new data.
      totSize = totSize_new;
      nelem = nbuff_size;
      arr.Swap(nbuff);
      nbuff.ReInit(0);
    }

    { // Split comm. O( log(p) ) ??
      MPI_Comm scomm;
      MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
      if (free_comm) MPI_Comm_free(&comm);
      comm = scomm;
      free_comm = true;
      npes = (myrank <= split_id ? split_id + 1 : npes - split_id - 1);
      myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
    }
  }
  if (free_comm) MPI_Comm_free(&comm);

  SortedElem = arr;
  PartitionW<Type>(SortedElem);
#else
  SortedElem = arr_;
  std::sort(SortedElem.Begin(), SortedElem.Begin() + SortedElem.Dim());
#endif
}

} // end namespace
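
// A minimal end-to-end sketch of the distributed sort (hypothetical setup;
// Comm::World() assumes MPI has already been initialized by the caller, and the
// assert above requires every rank to hold at least one element):
//
//   pvfmm::Comm comm = pvfmm::Comm::World();
//   pvfmm::Vector<double> arr(local_n), sorted;
//   // ... fill arr with local values ...
//   comm.HyperQuickSort(arr, sorted);
//   // sorted is now globally sorted and load-balanced across ranks via PartitionW.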