// comm.txx

#include <pvfmm/ompUtils.hpp>
#include <pvfmm/vector.hpp>

namespace pvfmm {

Comm::Comm() {
#ifdef PVFMM_HAVE_MPI
  Init(MPI_COMM_SELF);
#endif
}

Comm::Comm(const Comm& c) {
#ifdef PVFMM_HAVE_MPI
  Init(c.mpi_comm_);
#endif
}

Comm Comm::Self() {
#ifdef PVFMM_HAVE_MPI
  Comm comm_self(MPI_COMM_SELF);
  return comm_self;
#else
  Comm comm_self;
  return comm_self;
#endif
}

Comm Comm::World() {
#ifdef PVFMM_HAVE_MPI
  Comm comm_world(MPI_COMM_WORLD);
  return comm_world;
#else
  Comm comm_self;
  return comm_self;
#endif
}
Comm& Comm::operator=(const Comm& c) {
#ifdef PVFMM_HAVE_MPI
  Init(c.mpi_comm_);
#endif
  return *this;
}
Comm::~Comm() {
#ifdef PVFMM_HAVE_MPI
  while (!req.empty()) {
    delete (Vector<MPI_Request>*)req.top();
    req.pop();
  }
  MPI_Comm_free(&mpi_comm_);
#endif
}

Comm Comm::Split(Integer clr) const {
#ifdef PVFMM_HAVE_MPI
  MPI_Comm new_comm;
  MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
  Comm c(new_comm);
  MPI_Comm_free(&new_comm);
  return c;
#else
  Comm c;
  return c;
#endif
}

Integer Comm::Rank() const {
#ifdef PVFMM_HAVE_MPI
  return mpi_rank_;
#else
  return 0;
#endif
}

Integer Comm::Size() const {
#ifdef PVFMM_HAVE_MPI
  return mpi_size_;
#else
  return 1;
#endif
}

void Comm::Barrier() const {
#ifdef PVFMM_HAVE_MPI
  MPI_Barrier(mpi_comm_);
#endif
}
template <class SType> void* Comm::Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag) const {
#ifdef PVFMM_HAVE_MPI
  if (!scount) return NULL;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  sbuf[0];           // touch the buffer end-points (range check when
  sbuf[scount - 1];  // iterator debugging is enabled)
#ifndef NDEBUG
  MPI_Issend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#else
  MPI_Isend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#endif
  return &request;
#else
  auto it = recv_req.find(tag);
  if (it == recv_req.end())
    send_req.insert(std::pair<Integer, ConstIterator<char>>(tag, (ConstIterator<char>)sbuf));
  else
    pvfmm::memcopy(it->second, (ConstIterator<char>)sbuf, scount * sizeof(SType));
  return NULL;
#endif
}

template <class RType> void* Comm::Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag) const {
#ifdef PVFMM_HAVE_MPI
  if (!rcount) return NULL;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  rbuf[0];           // touch the buffer end-points (range check when
  rbuf[rcount - 1];  // iterator debugging is enabled)
  MPI_Irecv(&rbuf[0], rcount, CommDatatype<RType>::value(), source, tag, mpi_comm_, &request[0]);
  return &request;
#else
  auto it = send_req.find(tag);
  if (it == send_req.end())
    recv_req.insert(std::pair<Integer, Iterator<char>>(tag, (Iterator<char>)rbuf));
  else
    pvfmm::memcopy((Iterator<char>)rbuf, it->second, rcount * sizeof(RType));
  return NULL;
#endif
}

void Comm::Wait(void* req_ptr) const {
#ifdef PVFMM_HAVE_MPI
  if (req_ptr == NULL) return;
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req_ptr;
  // std::vector<MPI_Status> status(request.Dim());
  if (request.Dim()) MPI_Waitall(request.Dim(), &request[0], MPI_STATUSES_IGNORE);  //&status[0]);
  DelReq(&request);
#endif
}
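
// Illustrative usage (a minimal sketch, not part of this file): exchanging a
// buffer with a partner rank pairs one Isend with one Irecv and waits on both
// returned request handles. The names `comm`, `partner`, and `n` are assumed
// to be provided by the caller.
//
//   Vector<double> sbuf(n), rbuf(n);
//   void* srq = comm.Isend<double>(sbuf.Begin(), n, partner, /*tag*/ 1);
//   void* rrq = comm.Irecv<double>(rbuf.Begin(), n, partner, /*tag*/ 1);
//   comm.Wait(rrq);
//   comm.Wait(srq);
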
template <class SType, class RType> void Comm::Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
#ifdef PVFMM_HAVE_MPI
  if (scount) {
    sbuf[0];
    sbuf[scount - 1];
  }
  if (rcount) {
    rbuf[0];
    rbuf[rcount * Size() - 1];
  }
  MPI_Allgather((scount ? &sbuf[0] : NULL), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : NULL), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
#ifdef PVFMM_HAVE_MPI
  Vector<int> rcounts_(mpi_size_), rdispls_(mpi_size_);
  Long rcount_sum = 0;
#pragma omp parallel for schedule(static) reduction(+ : rcount_sum)
  for (Integer i = 0; i < mpi_size_; i++) {
    rcounts_[i] = rcounts[i];
    rdispls_[i] = rdispls[i];
    rcount_sum += rcounts[i];
  }
  if (scount) {
    sbuf[0];
    sbuf[scount - 1];
  }
  if (rcount_sum) {
    rbuf[0];
    rbuf[rcount_sum - 1];
  }
  MPI_Allgatherv((scount ? &sbuf[0] : NULL), scount, CommDatatype<SType>::value(), (rcount_sum ? &rbuf[0] : NULL), &rcounts_.Begin()[0], &rdispls_.Begin()[0], CommDatatype<RType>::value(), mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
#ifdef PVFMM_HAVE_MPI
  if (scount) {
    sbuf[0];
    sbuf[scount * Size() - 1];
  }
  if (rcount) {
    rbuf[0];
    rbuf[rcount * Size() - 1];
  }
  MPI_Alltoall((scount ? &sbuf[0] : NULL), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : NULL), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}
template <class SType, class RType> void* Comm::Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag) const {
#ifdef PVFMM_HAVE_MPI
  Integer request_count = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) request_count++;
    if (scounts[i]) request_count++;
  }
  if (!request_count) return NULL;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(request_count);
  Integer request_iter = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) {
      rbuf[rdispls[i]];
      rbuf[rdispls[i] + rcounts[i] - 1];
      MPI_Irecv(&rbuf[rdispls[i]], rcounts[i], CommDatatype<RType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  for (Integer i = 0; i < mpi_size_; i++) {
    if (scounts[i]) {
      sbuf[sdispls[i]];
      sbuf[sdispls[i] + scounts[i] - 1];
      MPI_Isend(&sbuf[sdispls[i]], scounts[i], CommDatatype<SType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  return &request;
#else
  pvfmm::memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(SType));
  return NULL;
#endif
}
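
// Alltoallv (below) picks between two strategies: it first computes each
// rank's "connectivity" (the number of ranks it actually receives data from),
// sums this over the communicator, and if the average is below 64 it uses
// Ialltoallv_sparse (one Isend/Irecv per nonempty pair of ranks); otherwise
// it converts the 64-bit counts/displacements to int and calls the vendor
// MPI_Alltoallv.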
template <class Type> void Comm::Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
#ifdef PVFMM_HAVE_MPI
  { // Use Ialltoallv_sparse if average connectivity < 64
    Long connectivity = 0, glb_connectivity = 0;
#pragma omp parallel for schedule(static) reduction(+ : connectivity)
    for (Integer i = 0; i < mpi_size_; i++) {
      if (rcounts[i]) connectivity++;
    }
    Allreduce(PVFMM_PTR2CONSTITR(Long, &connectivity, 1), PVFMM_PTR2ITR(Long, &glb_connectivity, 1), 1, CommOp::SUM);
    if (glb_connectivity < 64 * Size()) {
      void* mpi_req = Ialltoallv_sparse(sbuf, scounts, sdispls, rbuf, rcounts, rdispls, 0);
      Wait(mpi_req);
      return;
    }
  }
  { // Use vendor MPI_Alltoallv
    //#ifndef ALLTOALLV_FIX
    Vector<int> scnt, sdsp, rcnt, rdsp;
    scnt.ReInit(mpi_size_);
    sdsp.ReInit(mpi_size_);
    rcnt.ReInit(mpi_size_);
    rdsp.ReInit(mpi_size_);
    Long stotal = 0, rtotal = 0;
#pragma omp parallel for schedule(static) reduction(+ : stotal, rtotal)
    for (Integer i = 0; i < mpi_size_; i++) {
      scnt[i] = scounts[i];
      sdsp[i] = sdispls[i];
      rcnt[i] = rcounts[i];
      rdsp[i] = rdispls[i];
      stotal += scounts[i];
      rtotal += rcounts[i];
    }
    MPI_Alltoallv((stotal ? &sbuf[0] : NULL), &scnt[0], &sdsp[0], CommDatatype<Type>::value(), (rtotal ? &rbuf[0] : NULL), &rcnt[0], &rdsp[0], CommDatatype<Type>::value(), mpi_comm_);
    return;
    //#endif
  }
  // TODO: implement hypercube scheme
#else
  pvfmm::memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(Type));
#endif
}
template <class Type> void Comm::Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const {
#ifdef PVFMM_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  sbuf[0];
  sbuf[count - 1];
  rbuf[0];
  rbuf[count - 1];
  MPI_Allreduce(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

template <class Type> void Comm::Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const {
#ifdef PVFMM_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  sbuf[0];
  sbuf[count - 1];
  rbuf[0];
  rbuf[count - 1];
  MPI_Scan(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  pvfmm::memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}
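
// PartitionW (below) redistributes nodeList across ranks so that each rank
// ends up with approximately the same total weight: it computes the global
// weight and each rank's exclusive prefix sum, assigns rank i the weight
// interval [i*totalWt/npes, (i+1)*totalWt/npes), and moves the elements with
// Ialltoallv_sparse. With wts_ == NULL every element has weight 1, i.e. the
// element counts are balanced.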
template <class Type> void Comm::PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_) const {
  Integer npes = Size();
  if (npes == 1) return;
  Long nlSize = nodeList.Dim();

  Vector<Long> wts;
  Long localWt = 0;
  if (wts_ == NULL) { // Construct arrays of wts.
    wts.ReInit(nlSize);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < nlSize; i++) {
      wts[i] = 1;
    }
    localWt = nlSize;
  } else {
    wts.ReInit(nlSize, (Iterator<Long>)wts_->Begin(), false);
#pragma omp parallel for reduction(+ : localWt)
    for (Long i = 0; i < nlSize; i++) {
      localWt += wts[i];
    }
  }

  Long off1 = 0, off2 = 0, totalWt = 0;
  { // compute the total weight of the problem ...
    Allreduce<Long>(PVFMM_PTR2CONSTITR(Long, &localWt, 1), PVFMM_PTR2ITR(Long, &totalWt, 1), 1, CommOp::SUM);
    Scan<Long>(PVFMM_PTR2CONSTITR(Long, &localWt, 1), PVFMM_PTR2ITR(Long, &off2, 1), 1, CommOp::SUM);
    off1 = off2 - localWt;
  }

  Vector<Long> lscn;
  if (nlSize) { // perform a local scan on the weights first ...
    lscn.ReInit(nlSize);
    lscn[0] = off1;
    omp_par::scan(wts.Begin(), lscn.Begin(), nlSize);
  }

  Vector<Long> sendSz, recvSz, sendOff, recvOff;
  sendSz.ReInit(npes);
  recvSz.ReInit(npes);
  sendOff.ReInit(npes);
  recvOff.ReInit(npes);
  sendSz.SetZero();

  if (nlSize > 0 && totalWt > 0) { // Compute sendSz
    Long pid1 = (off1 * npes) / totalWt;
    Long pid2 = ((off2 + 1) * npes) / totalWt + 1;
    assert((totalWt * pid2) / npes >= off2);
    pid1 = (pid1 < 0 ? 0 : pid1);
    pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
    for (Integer i = pid1; i < pid2; i++) {
      Long wt1 = (totalWt * (i)) / npes;
      Long wt2 = (totalWt * (i + 1)) / npes;
      Long start = std::lower_bound(lscn.Begin(), lscn.Begin() + nlSize, wt1, std::less<Long>()) - lscn.Begin();
      Long end = std::lower_bound(lscn.Begin(), lscn.Begin() + nlSize, wt2, std::less<Long>()) - lscn.Begin();
      if (i == 0) start = 0;
      if (i == npes - 1) end = nlSize;
      sendSz[i] = end - start;
    }
  } else {
    sendSz[0] = nlSize;
  }

  // Exchange sendSz, recvSz
  Alltoall<Long>(sendSz.Begin(), 1, recvSz.Begin(), 1);
  { // Compute sendOff, recvOff
    sendOff[0] = 0;
    omp_par::scan(sendSz.Begin(), sendOff.Begin(), npes);
    recvOff[0] = 0;
    omp_par::scan(recvSz.Begin(), recvOff.Begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
  }

  // perform All2All ...
  Vector<Type> newNodes;
  newNodes.ReInit(recvSz[npes - 1] + recvOff[npes - 1]);
  void* mpi_req = Ialltoallv_sparse<Type>(nodeList.Begin(), sendSz.Begin(), sendOff.Begin(), newNodes.Begin(), recvSz.Begin(), recvOff.Begin());
  Wait(mpi_req);

  // reset the pointer ...
  nodeList.Swap(newNodes);
}
template <class Type> void Comm::PartitionN(Vector<Type>& v, Long N) const {
  Integer rank = Rank();
  Integer np = Size();
  if (np == 1) return;

  Vector<Long> v_cnt(np), v_dsp(np + 1);
  Vector<Long> N_cnt(np), N_dsp(np + 1);
  { // Set v_cnt, v_dsp
    v_dsp[0] = 0;
    Long cnt = v.Dim();
    Allgather(PVFMM_PTR2CONSTITR(Long, &cnt, 1), 1, v_cnt.Begin(), 1);
    omp_par::scan(v_cnt.Begin(), v_dsp.Begin(), np);
    v_dsp[np] = v_cnt[np - 1] + v_dsp[np - 1];
  }
  { // Set N_cnt, N_dsp
    N_dsp[0] = 0;
    Long cnt = N;
    Allgather(PVFMM_PTR2CONSTITR(Long, &cnt, 1), 1, N_cnt.Begin(), 1);
    omp_par::scan(N_cnt.Begin(), N_dsp.Begin(), np);
    N_dsp[np] = N_cnt[np - 1] + N_dsp[np - 1];
  }
  { // Adjust for dof
    Long dof = v_dsp[np] / N_dsp[np];
    assert(dof * N_dsp[np] == v_dsp[np]);
    if (dof == 0) return;
    if (dof != 1) {
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i < np; i++) N_cnt[i] *= dof;
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= np; i++) N_dsp[i] *= dof;
    }
  }

  Vector<Type> v_(N_cnt[rank]);
  { // Set v_
    Vector<Long> scnt(np), sdsp(np);
    Vector<Long> rcnt(np), rdsp(np);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < np; i++) {
      { // Set scnt
        Long n0 = N_dsp[i + 0];
        Long n1 = N_dsp[i + 1];
        if (n0 < v_dsp[rank + 0]) n0 = v_dsp[rank + 0];
        if (n1 < v_dsp[rank + 0]) n1 = v_dsp[rank + 0];
        if (n0 > v_dsp[rank + 1]) n0 = v_dsp[rank + 1];
        if (n1 > v_dsp[rank + 1]) n1 = v_dsp[rank + 1];
        scnt[i] = n1 - n0;
      }
      { // Set rcnt
        Long n0 = v_dsp[i + 0];
        Long n1 = v_dsp[i + 1];
        if (n0 < N_dsp[rank + 0]) n0 = N_dsp[rank + 0];
        if (n1 < N_dsp[rank + 0]) n1 = N_dsp[rank + 0];
        if (n0 > N_dsp[rank + 1]) n0 = N_dsp[rank + 1];
        if (n1 > N_dsp[rank + 1]) n1 = N_dsp[rank + 1];
        rcnt[i] = n1 - n0;
      }
    }
    sdsp[0] = 0;
    omp_par::scan(scnt.Begin(), sdsp.Begin(), np);
    rdsp[0] = 0;
    omp_par::scan(rcnt.Begin(), rdsp.Begin(), np);
    void* mpi_request = Ialltoallv_sparse(v.Begin(), scnt.Begin(), sdsp.Begin(), v_.Begin(), rcnt.Begin(), rdsp.Begin());
    Wait(mpi_request);
  }
  v.Swap(v_);
}
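
// PartitionS (below) redistributes a locally sorted nodeList using one
// splitter per rank: every rank gathers all splitters, locates each splitter
// in its local data with std::lower_bound, and ships the resulting ranges so
// that rank i receives the elements in [splitter_i, splitter_{i+1}).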
template <class Type> void Comm::PartitionS(Vector<Type>& nodeList, const Type& splitter) const {
  Integer npes = Size();
  if (npes == 1) return;

  Vector<Type> mins(npes);
  Allgather(PVFMM_PTR2CONSTITR(Type, &splitter, 1), 1, mins.Begin(), 1);

  Vector<Long> scnt(npes), sdsp(npes);
  Vector<Long> rcnt(npes), rdsp(npes);
  { // Compute scnt, sdsp
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sdsp[i] = std::lower_bound(nodeList.Begin(), nodeList.Begin() + nodeList.Dim(), mins[i]) - nodeList.Begin();
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes - 1; i++) {
      scnt[i] = sdsp[i + 1] - sdsp[i];
    }
    scnt[npes - 1] = nodeList.Dim() - sdsp[npes - 1];
  }
  { // Compute rcnt, rdsp
    rdsp[0] = 0;
    Alltoall(scnt.Begin(), 1, rcnt.Begin(), 1);
    omp_par::scan(rcnt.Begin(), rdsp.Begin(), npes);
  }
  { // Redistribute nodeList
    Vector<Type> nodeList_(rdsp[npes - 1] + rcnt[npes - 1]);
    void* mpi_request = Ialltoallv_sparse(nodeList.Begin(), scnt.Begin(), sdsp.Begin(), nodeList_.Begin(), rcnt.Begin(), rdsp.Begin());
    Wait(mpi_request);
    nodeList.Swap(nodeList_);
  }
}
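
// The sort/scatter interface below is typically used in three steps. The
// following is an illustrative sketch, not part of this file: `comm`, `key`,
// and `data` are assumed to be set up by the caller, and it assumes the
// declaration of SortScatterIndex provides a default (NULL) for the optional
// split_key argument.
//
//   Vector<Long> scatter_idx;
//   comm.SortScatterIndex(key, scatter_idx);            // global sort order of key
//   comm.ScatterForward(data, scatter_idx);             // move data into that order
//   /* ... operate on globally sorted data ... */
//   comm.ScatterReverse(data, scatter_idx, key.Dim());  // undo the scatter
//
// SortScatterIndex returns, for each element of the sorted sequence owned by
// this rank, the global index of that element in the original ordering;
// ScatterForward/ScatterReverse then move (possibly multi-component) data
// along or against that permutation.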
template <class Type> void Comm::SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_) const {
  typedef SortPair<Type, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Vector<Pair_t> parray(key.Dim());
  { // Build global index.
    Long glb_dsp = 0;
    Long loc_size = key.Dim();
    Scan(PVFMM_PTR2CONSTITR(Long, &loc_size, 1), PVFMM_PTR2ITR(Long, &glb_dsp, 1), 1, CommOp::SUM);
    glb_dsp -= loc_size;
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < loc_size; i++) {
      parray[i].key = key[i];
      parray[i].data = glb_dsp + i;
    }
  }

  Vector<Pair_t> psorted;
  HyperQuickSort(parray, psorted);

  if (npes > 1 && split_key_ != NULL) { // Partition data
    Vector<Type> split_key(npes);
    Allgather(PVFMM_PTR2CONSTITR(Type, split_key_, 1), 1, split_key.Begin(), 1);

    Vector<Long> sendSz(npes);
    Vector<Long> recvSz(npes);
    Vector<Long> sendOff(npes);
    Vector<Long> recvOff(npes);
    Long nlSize = psorted.Dim();
    sendSz.SetZero();

    if (nlSize > 0) { // Compute sendSz
      // Determine processor range.
      Long pid1 = std::lower_bound(split_key.Begin(), split_key.Begin() + npes, psorted[0].key) - split_key.Begin() - 1;
      Long pid2 = std::upper_bound(split_key.Begin(), split_key.Begin() + npes, psorted[nlSize - 1].key) - split_key.Begin() + 1;
      pid1 = (pid1 < 0 ? 0 : pid1);
      pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
      for (Integer i = pid1; i < pid2; i++) {
        Pair_t p1;
        p1.key = split_key[i];
        Pair_t p2;
        p2.key = split_key[i + 1 < npes ? i + 1 : i];
        Long start = std::lower_bound(psorted.Begin(), psorted.Begin() + nlSize, p1, std::less<Pair_t>()) - psorted.Begin();
        Long end = std::lower_bound(psorted.Begin(), psorted.Begin() + nlSize, p2, std::less<Pair_t>()) - psorted.Begin();
        if (i == 0) start = 0;
        if (i == npes - 1) end = nlSize;
        sendSz[i] = end - start;
      }
    }

    // Exchange sendSz, recvSz
    Alltoall<Long>(sendSz.Begin(), 1, recvSz.Begin(), 1);

    // compute offsets ...
    { // Compute sendOff, recvOff
      sendOff[0] = 0;
      omp_par::scan(sendSz.Begin(), sendOff.Begin(), npes);
      recvOff[0] = 0;
      omp_par::scan(recvSz.Begin(), recvOff.Begin(), npes);
      assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
    }

    // perform All2All ...
    Vector<Pair_t> newNodes(recvSz[npes - 1] + recvOff[npes - 1]);
    void* mpi_req = Ialltoallv_sparse<Pair_t>(psorted.Begin(), sendSz.Begin(), sendOff.Begin(), newNodes.Begin(), recvSz.Begin(), recvOff.Begin());
    Wait(mpi_req);

    // reset the pointer ...
    psorted.Swap(newNodes);
  }

  scatter_index.ReInit(psorted.Dim());
#pragma omp parallel for schedule(static)
  for (Long i = 0; i < psorted.Dim(); i++) {
    scatter_index[i] = psorted[i].data;
  }
}
template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const {
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = scatter_index.Dim();
    StaticArray<Long, 2> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = recv_size;
    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    data_dim = glb_size[0] / glb_size[1];
    PVFMM_ASSERT(glb_size[0] == data_dim * glb_size[1]);
    send_size = data_.Dim() / data_dim;
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = scatter_index[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> glb_scan;
  { // Global scan of data size.
    glb_scan.ReInit(npes);
    Long glb_rank = 0;
    Scan(PVFMM_PTR2CONSTITR(Long, &send_size, 1), PVFMM_PTR2ITR(Long, &glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= send_size;
    Allgather(PVFMM_PTR2CONSTITR(Long, &glb_rank, 1), 1, glb_scan.Begin(), 1);
  }

  Vector<Pair_t> psorted;
  { // Sort scatter_index.
    psorted.ReInit(recv_size);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.Begin(), psorted.Begin() + recv_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      recv_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(recv_indx.Begin(), recv_indx.Begin() + recv_size, glb_scan[i]) - recv_indx.Begin();
      Long end = (i + 1 < npes ? std::lower_bound(recv_indx.Begin(), recv_indx.Begin() + recv_size, glb_scan[i + 1]) - recv_indx.Begin() : recv_size);
      recvSz[i] = end - start;
      recvOff[i] = start;
    }
    Alltoall(recvSz.Begin(), 1, sendSz.Begin(), 1);
    sendOff[0] = 0;
    omp_par::scan(sendSz.Begin(), sendOff.Begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == send_size);
    Alltoallv(recv_indx.Begin(), recvSz.Begin(), recvOff.Begin(), send_indx.Begin(), sendSz.Begin(), sendOff.Begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      assert(send_indx[i] >= glb_scan[rank]);
      send_indx[i] -= glb_scan[rank];
      assert(send_indx[i] < send_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = send_indx[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.Begin(), sendSz.Begin(), sendOff.Begin(), recv_buff.Begin(), recvSz.Begin(), recvOff.Begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = psorted[i].data * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_) const {
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = loc_size_;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 3> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = scatter_index_.Dim();
    loc_size[2] = recv_size;
    Allreduce(loc_size, glb_size, 3, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    PVFMM_ASSERT(glb_size[0] % glb_size[1] == 0);
    data_dim = glb_size[0] / glb_size[1];
    PVFMM_ASSERT(loc_size[0] % data_dim == 0);
    send_size = loc_size[0] / data_dim;
    if (glb_size[0] != glb_size[2] * data_dim) {
      recv_size = (((rank + 1) * (glb_size[0] / data_dim)) / npes) - ((rank * (glb_size[0] / data_dim)) / npes);
    }
  }

  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = scatter_index_[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> scatter_index;
  {
    StaticArray<Long, 2> glb_rank;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim() / data_dim;
    loc_size[1] = scatter_index_.Dim();
    Scan(loc_size, glb_rank, 2, CommOp::SUM);
    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
    PVFMM_ASSERT(glb_size[0] == glb_size[1]);
    glb_rank[0] -= loc_size[0];
    glb_rank[1] -= loc_size[1];

    Vector<Long> glb_scan0(npes + 1);
    Vector<Long> glb_scan1(npes + 1);
    Allgather(glb_rank + 0, 1, glb_scan0.Begin(), 1);
    Allgather(glb_rank + 1, 1, glb_scan1.Begin(), 1);
    glb_scan0[npes] = glb_size[0];
    glb_scan1[npes] = glb_size[1];

    if (loc_size[0] != loc_size[1] || glb_rank[0] != glb_rank[1]) { // Repartition scatter_index
      scatter_index.ReInit(loc_size[0]);

      Vector<Long> send_dsp(npes + 1);
      Vector<Long> recv_dsp(npes + 1);
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= npes; i++) {
        send_dsp[i] = std::min(std::max(glb_scan0[i], glb_rank[1]), glb_rank[1] + loc_size[1]) - glb_rank[1];
        recv_dsp[i] = std::min(std::max(glb_scan1[i], glb_rank[0]), glb_rank[0] + loc_size[0]) - glb_rank[0];
      }

      // Long commCnt=0;
      Vector<Long> send_cnt(npes + 0);
      Vector<Long> recv_cnt(npes + 0);
#pragma omp parallel for schedule(static) // reduction(+:commCnt)
      for (Integer i = 0; i < npes; i++) {
        send_cnt[i] = send_dsp[i + 1] - send_dsp[i];
        recv_cnt[i] = recv_dsp[i + 1] - recv_dsp[i];
        // if(send_cnt[i] && i!=rank) commCnt++;
        // if(recv_cnt[i] && i!=rank) commCnt++;
      }

      void* mpi_req = Ialltoallv_sparse<Long>(scatter_index_.Begin(), send_cnt.Begin(), send_dsp.Begin(), scatter_index.Begin(), recv_cnt.Begin(), recv_dsp.Begin(), 0);
      Wait(mpi_req);
    } else {
      scatter_index.ReInit(scatter_index_.Dim(), (Iterator<Long>)scatter_index_.Begin(), false);
    }
  }

  Vector<Long> glb_scan(npes);
  { // Global data size.
    Long glb_rank = 0;
    Scan(PVFMM_PTR2CONSTITR(Long, &recv_size, 1), PVFMM_PTR2ITR(Long, &glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= recv_size;
    Allgather(PVFMM_PTR2CONSTITR(Long, &glb_rank, 1), 1, glb_scan.Begin(), 1);
  }

  Vector<Pair_t> psorted(send_size);
  { // Sort scatter_index.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.Begin(), psorted.Begin() + send_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      send_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(send_indx.Begin(), send_indx.Begin() + send_size, glb_scan[i]) - send_indx.Begin();
      Long end = (i + 1 < npes ? std::lower_bound(send_indx.Begin(), send_indx.Begin() + send_size, glb_scan[i + 1]) - send_indx.Begin() : send_size);
      sendSz[i] = end - start;
      sendOff[i] = start;
    }
    Alltoall(sendSz.Begin(), 1, recvSz.Begin(), 1);
    recvOff[0] = 0;
    omp_par::scan(recvSz.Begin(), recvOff.Begin(), npes);
    assert(recvOff[npes - 1] + recvSz[npes - 1] == recv_size);
    Alltoallv(send_indx.Begin(), sendSz.Begin(), sendOff.Begin(), recv_indx.Begin(), recvSz.Begin(), recvOff.Begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      assert(recv_indx[i] >= glb_scan[rank]);
      recv_indx[i] -= glb_scan[rank];
      assert(recv_indx[i] < recv_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = psorted[i].data * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.Begin(), sendSz.Begin(), sendOff.Begin(), recv_buff.Begin(), recvSz.Begin(), recvOff.Begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.Begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = recv_indx[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
#ifdef PVFMM_HAVE_MPI
Vector<MPI_Request>* Comm::NewReq() const {
  if (req.empty()) req.push(new Vector<MPI_Request>);
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req.top();
  req.pop();
  return &request;
}

void Comm::DelReq(Vector<MPI_Request>* req_ptr) const {
  if (req_ptr) req.push(req_ptr);
}

#define HS_MPIDATATYPE(CTYPE, MPITYPE)              \
  template <> class Comm::CommDatatype<CTYPE> {     \
   public:                                          \
    static MPI_Datatype value() { return MPITYPE; } \
    static MPI_Op sum() { return MPI_SUM; }         \
    static MPI_Op min() { return MPI_MIN; }         \
    static MPI_Op max() { return MPI_MAX; }         \
  }

HS_MPIDATATYPE(short, MPI_SHORT);
HS_MPIDATATYPE(int, MPI_INT);
HS_MPIDATATYPE(long, MPI_LONG);
HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT);
HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED);
HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG);
HS_MPIDATATYPE(float, MPI_FLOAT);
HS_MPIDATATYPE(double, MPI_DOUBLE);
HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE);
HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT);
HS_MPIDATATYPE(char, MPI_CHAR);
HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR);
#undef HS_MPIDATATYPE
#endif
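
// HyperQuickSort (below) is a distributed quicksort: each rank sorts its
// local array, the communicator picks the sampled splitter whose global rank
// is closest to the median, every rank exchanges the "wrong half" of its data
// with a partner in the other half of the communicator, and the communicator
// is split in two; this repeats for O(log p) rounds, after which PartitionW
// rebalances the element counts.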
template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const {  // O( ((N/p)+log(p))*(log(N/p)+log(p)) )
#ifdef PVFMM_HAVE_MPI
  Integer npes, myrank, omp_p;
  { // Get comm size and rank.
    npes = Size();
    myrank = Rank();
    omp_p = omp_get_max_threads();
  }
  srand(myrank);

  Long totSize, nelem = arr_.Dim();
  { // Local and global sizes. O(log p)
    assert(nelem);  // TODO: Check if this is needed.
    Allreduce<Long>(PVFMM_PTR2CONSTITR(Long, &nelem, 1), PVFMM_PTR2ITR(Long, &totSize, 1), 1, CommOp::SUM);
  }

  if (npes == 1) { // SortedElem <--- local_sort(arr_)
    SortedElem = arr_;
    omp_par::merge_sort(SortedElem.Begin(), SortedElem.Begin() + nelem);
    return;
  }

  Vector<Type> arr;
  { // arr <-- local_sort(arr_)
    arr = arr_;
    omp_par::merge_sort(arr.Begin(), arr.Begin() + nelem);
  }

  Vector<Type> nbuff, nbuff_ext, rbuff, rbuff_ext;  // Allocate memory.
  MPI_Comm comm = mpi_comm_;                        // Copy comm
  bool free_comm = false;                           // Flag to free comm.

  // Binary split and merge in each iteration.
  while (npes > 1 && totSize > 0) { // O(log p) iterations.
    Type split_key;
    Long totSize_new;
    { // Determine split_key. O( log(N/p) + log(p) )
      Integer glb_splt_count;
      Vector<Type> glb_splitters;
      { // Take random splitters. glb_splt_count = const = 100~1000
        Integer splt_count;
        { // Set splt_count. O( 1 ) -- Let p * splt_count = t
          splt_count = (100 * nelem) / totSize;
          if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
          if (splt_count > nelem) splt_count = nelem;
        }
        Vector<Type> splitters(splt_count);
        for (Integer i = 0; i < splt_count; i++) {
          splitters[i] = arr[rand() % nelem];
        }
        Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
        { // Set glb_splt_count, glb_splt_cnts, glb_splt_disp
          MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
          glb_splt_disp[0] = 0;
          omp_par::scan(glb_splt_cnts.Begin(), glb_splt_disp.Begin(), npes);
          glb_splt_count = glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1];
          PVFMM_ASSERT(glb_splt_count);
        }
        { // Gather all splitters. O( log(p) )
          glb_splitters.ReInit(glb_splt_count);
          Vector<int> glb_splt_cnts_(npes), glb_splt_disp_(npes);
          for (Integer i = 0; i < npes; i++) {
            glb_splt_cnts_[i] = glb_splt_cnts[i];
            glb_splt_disp_[i] = glb_splt_disp[i];
          }
          MPI_Allgatherv((splt_count ? &splitters[0] : NULL), splt_count, CommDatatype<Type>::value(), &glb_splitters[0], &glb_splt_cnts_[0], &glb_splt_disp_[0], CommDatatype<Type>::value(), comm);
        }
      }

      // Determine split key. O( log(N/p) + log(p) )
      Vector<Long> lrank(glb_splt_count);
      { // Compute local rank
#pragma omp parallel for schedule(static)
        for (Integer i = 0; i < glb_splt_count; i++) {
          lrank[i] = std::lower_bound(arr.Begin(), arr.Begin() + nelem, glb_splitters[i]) - arr.Begin();
        }
      }
      Vector<Long> grank(glb_splt_count);
      { // Compute global rank
        MPI_Allreduce(&lrank[0], &grank[0], glb_splt_count, CommDatatype<Long>::value(), CommDatatype<Long>::sum(), comm);
      }
      { // Determine split_key, totSize_new
        ConstIterator<Long> split_disp = grank.Begin();
        for (Integer i = 0; i < glb_splt_count; i++) {
          if (labs(grank[i] - totSize / 2) < labs(*split_disp - totSize / 2)) {
            split_disp = grank.Begin() + i;
          }
        }
        split_key = glb_splitters[split_disp - grank.Begin()];
        if (myrank <= (npes - 1) / 2)
          totSize_new = split_disp[0];
        else
          totSize_new = totSize - split_disp[0];
        // double err=(((double)*split_disp)/(totSize/2))-1.0;
        // if(pvfmm::fabs<double>(err)<0.01 || npes<=16) break;
        // else if(!myrank) std::cout<<err<<'\n';
      }
    }

    Integer split_id = (npes - 1) / 2;
    { // Split problem into two. O( N/p )
      Integer new_p0 = (myrank <= split_id ? 0 : split_id + 1);
      Integer cmp_p0 = (myrank > split_id ? 0 : split_id + 1);
      Integer partner;
      { // Set partner
        partner = myrank + cmp_p0 - new_p0;
        if (partner >= npes) partner = npes - 1;
        assert(partner >= 0);
      }
      bool extra_partner = (npes % 2 == 1 && npes - 1 == myrank);

      Long ssize = 0, lsize = 0;
      ConstIterator<Type> sbuff, lbuff;
      { // Set ssize, lsize, sbuff, lbuff
        Long split_indx = std::lower_bound(arr.Begin(), arr.Begin() + nelem, split_key) - arr.Begin();
        ssize = (myrank > split_id ? split_indx : nelem - split_indx);
        sbuff = (myrank > split_id ? arr.Begin() : arr.Begin() + split_indx);
        lsize = (myrank <= split_id ? split_indx : nelem - split_indx);
        lbuff = (myrank <= split_id ? arr.Begin() : arr.Begin() + split_indx);
      }

      Long rsize = 0, ext_rsize = 0;
      { // Get rsize, ext_rsize
        Long ext_ssize = 0;
        MPI_Status status;
        MPI_Sendrecv(&ssize, 1, CommDatatype<Long>::value(), partner, 0, &rsize, 1, CommDatatype<Long>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(&ext_ssize, 1, CommDatatype<Long>::value(), split_id, 0, &ext_rsize, 1, CommDatatype<Long>::value(), split_id, 0, comm, &status);
      }

      { // Exchange data.
        rbuff.ReInit(rsize);
        rbuff_ext.ReInit(ext_rsize);
        MPI_Status status;
        MPI_Sendrecv((ssize ? &sbuff[0] : NULL), ssize, CommDatatype<Type>::value(), partner, 0, (rsize ? &rbuff[0] : NULL), rsize, CommDatatype<Type>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(NULL, 0, CommDatatype<Type>::value(), split_id, 0, (ext_rsize ? &rbuff_ext[0] : NULL), ext_rsize, CommDatatype<Type>::value(), split_id, 0, comm, &status);
      }

      Long nbuff_size = lsize + rsize + ext_rsize;
      { // nbuff <-- merge(lbuff, rbuff, rbuff_ext)
        nbuff.ReInit(lsize + rsize);
        omp_par::merge<ConstIterator<Type>>(lbuff, (lbuff + lsize), rbuff.Begin(), rbuff.Begin() + rsize, nbuff.Begin(), omp_p, std::less<Type>());
        if (ext_rsize > 0 && nbuff.Dim() > 0) {
          nbuff_ext.ReInit(nbuff_size);
          omp_par::merge(nbuff.Begin(), nbuff.Begin() + (lsize + rsize), rbuff_ext.Begin(), rbuff_ext.Begin() + ext_rsize, nbuff_ext.Begin(), omp_p, std::less<Type>());
          nbuff.Swap(nbuff_ext);
          nbuff_ext.ReInit(0);
        }
      }

      // Copy new data.
      totSize = totSize_new;
      nelem = nbuff_size;
      arr.Swap(nbuff);
      nbuff.ReInit(0);
    }

    { // Split comm. O( log(p) ) ??
      MPI_Comm scomm;
      MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
      if (free_comm) MPI_Comm_free(&comm);
      comm = scomm;
      free_comm = true;
      npes = (myrank <= split_id ? split_id + 1 : npes - split_id - 1);
      myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
    }
  }
  if (free_comm) MPI_Comm_free(&comm);

  SortedElem = arr;
  PartitionW<Type>(SortedElem);
#else
  SortedElem = arr_;
  std::sort(SortedElem.Begin(), SortedElem.Begin() + SortedElem.Dim());
#endif
}

}  // end namespace pvfmm