// comm.txx

#include <type_traits>
#include SCTL_INCLUDE(ompUtils.hpp)
#include SCTL_INCLUDE(vector.hpp)
namespace SCTL_NAMESPACE {
inline Comm::Comm() {
#ifdef SCTL_HAVE_MPI
  Init(MPI_COMM_SELF);
#endif
}
inline Comm::Comm(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  Init(c.mpi_comm_);
#endif
}
inline Comm Comm::Self() {
#ifdef SCTL_HAVE_MPI
  Comm comm_self(MPI_COMM_SELF);
  return comm_self;
#else
  Comm comm_self;
  return comm_self;
#endif
}
inline Comm Comm::World() {
#ifdef SCTL_HAVE_MPI
  Comm comm_world(MPI_COMM_WORLD);
  return comm_world;
#else
  Comm comm_world;
  return comm_world;
#endif
}
inline Comm& Comm::operator=(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  Init(c.mpi_comm_);
#endif
  return *this;
}
inline Comm::~Comm() {
#ifdef SCTL_HAVE_MPI
  while (!req.empty()) {
    delete (Vector<MPI_Request>*)req.top();
    req.pop();
  }
  MPI_Comm_free(&mpi_comm_);
#endif
}
inline Comm Comm::Split(Integer clr) const {
#ifdef SCTL_HAVE_MPI
  MPI_Comm new_comm;
  MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
  Comm c(new_comm);
  MPI_Comm_free(&new_comm);
  return c;
#else
  Comm c;
  return c;
#endif
}
inline Integer Comm::Rank() const {
#ifdef SCTL_HAVE_MPI
  return mpi_rank_;
#else
  return 0;
#endif
}
inline Integer Comm::Size() const {
#ifdef SCTL_HAVE_MPI
  return mpi_size_;
#else
  return 1;
#endif
}
inline void Comm::Barrier() const {
#ifdef SCTL_HAVE_MPI
  MPI_Barrier(mpi_comm_);
#endif
}
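// Non-blocking point-to-point operations. Isend/Irecv return an opaque request
// handle that must later be passed to Wait(). In the non-MPI build the send and
// receive are matched locally through the tag-keyed maps (send_req/recv_req)
// and the data is copied directly. The bare statements of the form
// sbuf[0]; sbuf[scount - 1]; below appear to be deliberate end-point
// dereferences so that checked iterators can verify the buffer bounds; they
// compile to no-ops for plain pointers.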
template <class SType> void* Comm::Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!scount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  sbuf[0];
  sbuf[scount - 1];
#ifndef NDEBUG
  MPI_Issend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#else
  MPI_Isend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#endif
  return &request;
#else
  auto it = recv_req.find(tag);
  if (it == recv_req.end())
    send_req.insert(std::pair<Integer, ConstIterator<char>>(tag, (ConstIterator<char>)sbuf));
  else
    memcopy(it->second, (ConstIterator<char>)sbuf, scount * sizeof(SType));
  return nullptr;
#endif
}
template <class RType> void* Comm::Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag) const {
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!rcount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);
  rbuf[0];
  rbuf[rcount - 1];
  MPI_Irecv(&rbuf[0], rcount, CommDatatype<RType>::value(), source, tag, mpi_comm_, &request[0]);
  return &request;
#else
  auto it = send_req.find(tag);
  if (it == send_req.end())
    recv_req.insert(std::pair<Integer, Iterator<char>>(tag, (Iterator<char>)rbuf));
  else
    memcopy((Iterator<char>)rbuf, it->second, rcount * sizeof(RType));
  return nullptr;
#endif
}
inline void Comm::Wait(void* req_ptr) const {
#ifdef SCTL_HAVE_MPI
  if (req_ptr == nullptr) return;
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req_ptr;
  // std::vector<MPI_Status> status(request.Dim());
  if (request.Dim()) MPI_Waitall(request.Dim(), &request[0], MPI_STATUSES_IGNORE); //&status[0]);
  DelReq(&request);
#endif
}
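// Usage sketch (illustrative only, not part of the original source; assumes a
// Comm instance `comm` and two ranks exchanging `n` doubles with matching tags):
//   Vector<double> sbuf(n), rbuf(n);
//   void* sreq = comm.Isend(sbuf.begin(), n, /*dest*/ 1, /*tag*/ 7);
//   void* rreq = comm.Irecv(rbuf.begin(), n, /*source*/ 1, /*tag*/ 7);
//   comm.Wait(rreq);
//   comm.Wait(sreq);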
template <class SType, class RType> void Comm::Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    sbuf[0];
    sbuf[scount - 1];
  }
  if (rcount) {
    rbuf[0];
    rbuf[rcount * Size() - 1];
  }
  MPI_Allgather((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}
template <class SType, class RType> void Comm::Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Vector<int> rcounts_(mpi_size_), rdispls_(mpi_size_);
  Long rcount_sum = 0;
#pragma omp parallel for schedule(static) reduction(+ : rcount_sum)
  for (Integer i = 0; i < mpi_size_; i++) {
    rcounts_[i] = rcounts[i];
    rdispls_[i] = rdispls[i];
    rcount_sum += rcounts[i];
  }
  if (scount) {
    sbuf[0];
    sbuf[scount - 1];
  }
  if (rcount_sum) {
    rbuf[0];
    rbuf[rcount_sum - 1];
  }
  MPI_Allgatherv((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount_sum ? &rbuf[0] : nullptr), &rcounts_.begin()[0], &rdispls_.begin()[0], CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}
template <class SType, class RType> void Comm::Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    sbuf[0];
    sbuf[scount * Size() - 1];
  }
  if (rcount) {
    rbuf[0];
    rbuf[rcount * Size() - 1];
  }
  MPI_Alltoall((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}
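// Ialltoallv_sparse: personalized all-to-all built from point-to-point
// Isend/Irecv pairs, intended for the case where most rank pairs exchange
// nothing. It returns a request handle (or nullptr if there is nothing to do)
// that must be passed to Wait(); scounts/rcounts and sdispls/rdispls are
// per-destination element counts and offsets into sbuf/rbuf.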
template <class SType, class RType> void* Comm::Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer request_count = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) request_count++;
    if (scounts[i]) request_count++;
  }
  if (!request_count) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(request_count);
  Integer request_iter = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) {
      rbuf[rdispls[i]];
      rbuf[rdispls[i] + rcounts[i] - 1];
      MPI_Irecv(&rbuf[rdispls[i]], rcounts[i], CommDatatype<RType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  for (Integer i = 0; i < mpi_size_; i++) {
    if (scounts[i]) {
      sbuf[sdispls[i]];
      sbuf[sdispls[i] + scounts[i] - 1];
      MPI_Isend(&sbuf[sdispls[i]], scounts[i], CommDatatype<SType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  return &request;
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(SType));
  return nullptr;
#endif
}
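// Alltoallv: blocking personalized all-to-all. The implementation first
// estimates the average connectivity (how many destinations per rank have a
// nonzero count); if it is small (< 64 on average) the exchange is routed
// through Ialltoallv_sparse, otherwise the vendor MPI_Alltoallv is used with
// counts and displacements narrowed to int.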
template <class Type> void Comm::Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  { // Use Ialltoallv_sparse if average connectivity < 64
    Long connectivity = 0, glb_connectivity = 0;
#pragma omp parallel for schedule(static) reduction(+ : connectivity)
    for (Integer i = 0; i < mpi_size_; i++) {
      if (rcounts[i]) connectivity++;
    }
    Allreduce(Ptr2ConstItr<Long>(&connectivity, 1), Ptr2Itr<Long>(&glb_connectivity, 1), 1, CommOp::SUM);
    if (glb_connectivity < 64 * Size()) {
      void* mpi_req = Ialltoallv_sparse(sbuf, scounts, sdispls, rbuf, rcounts, rdispls, 0);
      Wait(mpi_req);
      return;
    }
  }
  { // Use vendor MPI_Alltoallv
    //#ifndef ALLTOALLV_FIX
    Vector<int> scnt, sdsp, rcnt, rdsp;
    scnt.ReInit(mpi_size_);
    sdsp.ReInit(mpi_size_);
    rcnt.ReInit(mpi_size_);
    rdsp.ReInit(mpi_size_);
    Long stotal = 0, rtotal = 0;
#pragma omp parallel for schedule(static) reduction(+ : stotal, rtotal)
    for (Integer i = 0; i < mpi_size_; i++) {
      scnt[i] = scounts[i];
      sdsp[i] = sdispls[i];
      rcnt[i] = rcounts[i];
      rdsp[i] = rdispls[i];
      stotal += scounts[i];
      rtotal += rcounts[i];
    }
    MPI_Alltoallv((stotal ? &sbuf[0] : nullptr), &scnt[0], &sdsp[0], CommDatatype<Type>::value(), (rtotal ? &rbuf[0] : nullptr), &rcnt[0], &rdsp[0], CommDatatype<Type>::value(), mpi_comm_);
    return;
    //#endif
  }
  // TODO: implement hypercube scheme
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(Type));
#endif
}
template <class Type> void Comm::Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  sbuf[0];
  sbuf[count - 1];
  rbuf[0];
  rbuf[count - 1];
  MPI_Allreduce(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}
template <class Type> void Comm::Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  sbuf[0];
  sbuf[count - 1];
  rbuf[0];
  rbuf[count - 1];
  MPI_Scan(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}
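// PartitionW: redistributes nodeList across ranks so that the total weight
// (uniform weights if wts_ is nullptr) is roughly balanced while preserving the
// global ordering of elements. Rank i ends up with the elements whose weight
// prefix-sum falls in [i*W/p, (i+1)*W/p), where W is the global weight and p
// the communicator size.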
template <class Type> void Comm::PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;
  Long nlSize = nodeList.Dim();
  Vector<Long> wts;
  Long localWt = 0;
  if (wts_ == nullptr) { // Construct arrays of wts.
    wts.ReInit(nlSize);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < nlSize; i++) {
      wts[i] = 1;
    }
    localWt = nlSize;
  } else {
    wts.ReInit(nlSize, (Iterator<Long>)wts_->begin(), false);
#pragma omp parallel for reduction(+ : localWt)
    for (Long i = 0; i < nlSize; i++) {
      localWt += wts[i];
    }
  }
  Long off1 = 0, off2 = 0, totalWt = 0;
  { // compute the total weight of the problem ...
    Allreduce<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&totalWt, 1), 1, CommOp::SUM);
    Scan<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&off2, 1), 1, CommOp::SUM);
    off1 = off2 - localWt;
  }
  Vector<Long> lscn;
  if (nlSize) { // perform a local scan on the weights first ...
    lscn.ReInit(nlSize);
    lscn[0] = off1;
    omp_par::scan(wts.begin(), lscn.begin(), nlSize);
  }
  Vector<Long> sendSz, recvSz, sendOff, recvOff;
  sendSz.ReInit(npes);
  recvSz.ReInit(npes);
  sendOff.ReInit(npes);
  recvOff.ReInit(npes);
  sendSz.SetZero();
  if (nlSize > 0 && totalWt > 0) { // Compute sendSz
    Long pid1 = (off1 * npes) / totalWt;
    Long pid2 = ((off2 + 1) * npes) / totalWt + 1;
    assert((totalWt * pid2) / npes >= off2);
    pid1 = (pid1 < 0 ? 0 : pid1);
    pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
    for (Integer i = pid1; i < pid2; i++) {
      Long wt1 = (totalWt * (i)) / npes;
      Long wt2 = (totalWt * (i + 1)) / npes;
      Long start = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt1, std::less<Long>()) - lscn.begin();
      Long end = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt2, std::less<Long>()) - lscn.begin();
      if (i == 0) start = 0;
      if (i == npes - 1) end = nlSize;
      sendSz[i] = end - start;
    }
  } else {
    sendSz[0] = nlSize;
  }
  // Exchange sendSz, recvSz
  Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);
  { // Compute sendOff, recvOff
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
  }
  // perform All2All ...
  Vector<Type> newNodes;
  newNodes.ReInit(recvSz[npes - 1] + recvOff[npes - 1]);
  void* mpi_req = Ialltoallv_sparse<Type>(nodeList.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
  Wait(mpi_req);
  // reset the pointer ...
  nodeList.Swap(newNodes);
}
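// PartitionN: repartitions v so that this rank ends up with N elements (times
// the per-element vector dimension `dof`, which is deduced from the ratio of
// the global element count to the global sum of the requested N values). The
// relative order of elements across ranks is preserved.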
template <class Type> void Comm::PartitionN(Vector<Type>& v, Long N) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer rank = Rank();
  Integer np = Size();
  if (np == 1) return;
  Vector<Long> v_cnt(np), v_dsp(np + 1);
  Vector<Long> N_cnt(np), N_dsp(np + 1);
  { // Set v_cnt, v_dsp
    v_dsp[0] = 0;
    Long cnt = v.Dim();
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, v_cnt.begin(), 1);
    omp_par::scan(v_cnt.begin(), v_dsp.begin(), np);
    v_dsp[np] = v_cnt[np - 1] + v_dsp[np - 1];
  }
  { // Set N_cnt, N_dsp
    N_dsp[0] = 0;
    Long cnt = N;
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, N_cnt.begin(), 1);
    omp_par::scan(N_cnt.begin(), N_dsp.begin(), np);
    N_dsp[np] = N_cnt[np - 1] + N_dsp[np - 1];
  }
  { // Adjust for dof
    Long dof = (N_dsp[np] ? v_dsp[np] / N_dsp[np] : 0);
    assert(dof * N_dsp[np] == v_dsp[np]);
    if (dof == 0) return;
    if (dof != 1) {
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i < np; i++) N_cnt[i] *= dof;
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= np; i++) N_dsp[i] *= dof;
    }
  }
  Vector<Type> v_(N_cnt[rank]);
  { // Set v_
    Vector<Long> scnt(np), sdsp(np);
    Vector<Long> rcnt(np), rdsp(np);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < np; i++) {
      { // Set scnt
        Long n0 = N_dsp[i + 0];
        Long n1 = N_dsp[i + 1];
        if (n0 < v_dsp[rank + 0]) n0 = v_dsp[rank + 0];
        if (n1 < v_dsp[rank + 0]) n1 = v_dsp[rank + 0];
        if (n0 > v_dsp[rank + 1]) n0 = v_dsp[rank + 1];
        if (n1 > v_dsp[rank + 1]) n1 = v_dsp[rank + 1];
        scnt[i] = n1 - n0;
      }
      { // Set rcnt
        Long n0 = v_dsp[i + 0];
        Long n1 = v_dsp[i + 1];
        if (n0 < N_dsp[rank + 0]) n0 = N_dsp[rank + 0];
        if (n1 < N_dsp[rank + 0]) n1 = N_dsp[rank + 0];
        if (n0 > N_dsp[rank + 1]) n0 = N_dsp[rank + 1];
        if (n1 > N_dsp[rank + 1]) n1 = N_dsp[rank + 1];
        rcnt[i] = n1 - n0;
      }
    }
    sdsp[0] = 0;
    omp_par::scan(scnt.begin(), sdsp.begin(), np);
    rdsp[0] = 0;
    omp_par::scan(rcnt.begin(), rdsp.begin(), np);
    void* mpi_request = Ialltoallv_sparse(v.begin(), scnt.begin(), sdsp.begin(), v_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
  }
  v.Swap(v_);
}
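// PartitionS: repartitions a locally sorted nodeList by splitter keys. Each
// rank contributes its splitter; elements greater than or equal to the splitter
// of rank i (and less than the splitter of rank i+1) are sent to rank i, so the
// result stays globally sorted provided the input and the splitters are sorted.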
template <class Type> void Comm::PartitionS(Vector<Type>& nodeList, const Type& splitter) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;
  Vector<Type> mins(npes);
  Allgather(Ptr2ConstItr<Type>(&splitter, 1), 1, mins.begin(), 1);
  Vector<Long> scnt(npes), sdsp(npes);
  Vector<Long> rcnt(npes), rdsp(npes);
  { // Compute scnt, sdsp
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sdsp[i] = std::lower_bound(nodeList.begin(), nodeList.begin() + nodeList.Dim(), mins[i]) - nodeList.begin();
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes - 1; i++) {
      scnt[i] = sdsp[i + 1] - sdsp[i];
    }
    scnt[npes - 1] = nodeList.Dim() - sdsp[npes - 1];
  }
  { // Compute rcnt, rdsp
    rdsp[0] = 0;
    Alltoall(scnt.begin(), 1, rcnt.begin(), 1);
    omp_par::scan(rcnt.begin(), rdsp.begin(), npes);
  }
  { // Redistribute nodeList
    Vector<Type> nodeList_(rdsp[npes - 1] + rcnt[npes - 1]);
    void* mpi_request = Ialltoallv_sparse(nodeList.begin(), scnt.begin(), sdsp.begin(), nodeList_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
    nodeList.Swap(nodeList_);
  }
}
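// SortScatterIndex: sorts the keys globally (via HyperQuickSort) and returns,
// for each position of the sorted sequence owned by this rank, the original
// global index of the key that lands there. The resulting scatter_index can be
// passed to ScatterForward/ScatterReverse to move payload data into or out of
// the sorted order. If split_key_ is given, the sorted keys are additionally
// repartitioned so that rank boundaries match the supplied splitters.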
template <class Type> void Comm::SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Type, Long> Pair_t;
  Integer npes = Size(), rank = Rank();
  Vector<Pair_t> parray(key.Dim());
  { // Build global index.
    Long glb_dsp = 0;
    Long loc_size = key.Dim();
    Scan(Ptr2ConstItr<Long>(&loc_size, 1), Ptr2Itr<Long>(&glb_dsp, 1), 1, CommOp::SUM);
    glb_dsp -= loc_size;
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < loc_size; i++) {
      parray[i].key = key[i];
      parray[i].data = glb_dsp + i;
    }
  }
  Vector<Pair_t> psorted;
  HyperQuickSort(parray, psorted);
  if (npes > 1 && split_key_ != nullptr) { // Partition data
    Vector<Type> split_key(npes);
    Allgather(Ptr2ConstItr<Type>(split_key_, 1), 1, split_key.begin(), 1);
    Vector<Long> sendSz(npes);
    Vector<Long> recvSz(npes);
    Vector<Long> sendOff(npes);
    Vector<Long> recvOff(npes);
    Long nlSize = psorted.Dim();
    sendSz.SetZero();
    if (nlSize > 0) { // Compute sendSz
      // Determine processor range.
      Long pid1 = std::lower_bound(split_key.begin(), split_key.begin() + npes, psorted[0].key) - split_key.begin() - 1;
      Long pid2 = std::upper_bound(split_key.begin(), split_key.begin() + npes, psorted[nlSize - 1].key) - split_key.begin() + 1;
      pid1 = (pid1 < 0 ? 0 : pid1);
      pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
      for (Integer i = pid1; i < pid2; i++) {
        Pair_t p1;
        p1.key = split_key[i];
        Pair_t p2;
        p2.key = split_key[i + 1 < npes ? i + 1 : i];
        Long start = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p1, std::less<Pair_t>()) - psorted.begin();
        Long end = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p2, std::less<Pair_t>()) - psorted.begin();
        if (i == 0) start = 0;
        if (i == npes - 1) end = nlSize;
        sendSz[i] = end - start;
      }
    }
    // Exchange sendSz, recvSz
    Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);
    // compute offsets ...
    { // Compute sendOff, recvOff
      sendOff[0] = 0;
      omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
      recvOff[0] = 0;
      omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
      assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
    }
    // perform All2All ...
    Vector<Pair_t> newNodes(recvSz[npes - 1] + recvOff[npes - 1]);
    void* mpi_req = Ialltoallv_sparse<Pair_t>(psorted.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
    Wait(mpi_req);
    // reset the pointer ...
    psorted.Swap(newNodes);
  }
  scatter_index.ReInit(psorted.Dim());
#pragma omp parallel for schedule(static)
  for (Long i = 0; i < psorted.Dim(); i++) {
    scatter_index[i] = psorted[i].data;
  }
}
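// ScatterForward: rearranges data_ according to scatter_index (as produced by
// SortScatterIndex), so that output element i holds the original element with
// global index scatter_index[i]. The vector dimension (number of values per
// element) is deduced from the global sizes of data_ and scatter_index.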
template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();
  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = scatter_index.Dim();
    StaticArray<Long, 2> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = recv_size;
    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return; // Nothing to be done.
    data_dim = glb_size[0] / glb_size[1];
    SCTL_ASSERT(glb_size[0] == data_dim * glb_size[1]);
    send_size = data_.Dim() / data_dim;
  }
  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = scatter_index[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }
  Vector<Long> glb_scan;
  { // Global scan of data size.
    glb_scan.ReInit(npes);
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&send_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= send_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }
  Vector<Pair_t> psorted;
  { // Sort scatter_index.
    psorted.ReInit(recv_size);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + recv_size);
  }
  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      recv_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i]) - recv_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i + 1]) - recv_indx.begin() : recv_size);
      recvSz[i] = end - start;
      recvOff[i] = start;
    }
    Alltoall(recvSz.begin(), 1, sendSz.begin(), 1);
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == send_size);
    Alltoallv(recv_indx.begin(), recvSz.begin(), recvOff.begin(), send_indx.begin(), sendSz.begin(), sendOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      assert(send_indx[i] >= glb_scan[rank]);
      send_indx[i] -= glb_scan[rank];
      assert(send_indx[i] < send_size);
    }
  }
  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = send_indx[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }
  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }
  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = psorted[i].data * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
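// ScatterReverse: the inverse of ScatterForward. Data currently laid out in the
// sorted order described by scatter_index_ is sent back so that the original
// element with global index scatter_index_[i] receives the i-th sorted element.
// loc_size_ optionally fixes the local output size; if it is inconsistent with
// the global data size, the output is distributed uniformly instead.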
template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();
  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = loc_size_;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 3> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = scatter_index_.Dim();
    loc_size[2] = recv_size;
    Allreduce(loc_size, glb_size, 3, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return; // Nothing to be done.
    SCTL_ASSERT(glb_size[0] % glb_size[1] == 0);
    data_dim = glb_size[0] / glb_size[1];
    SCTL_ASSERT(loc_size[0] % data_dim == 0);
    send_size = loc_size[0] / data_dim;
    if (glb_size[0] != glb_size[2] * data_dim) {
      recv_size = (((rank + 1) * (glb_size[0] / data_dim)) / npes) - ((rank * (glb_size[0] / data_dim)) / npes);
    }
  }
  if (npes == 1) { // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = scatter_index_[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }
  Vector<Long> scatter_index;
  {
    StaticArray<Long, 2> glb_rank;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim() / data_dim;
    loc_size[1] = scatter_index_.Dim();
    Scan(loc_size, glb_rank, 2, CommOp::SUM);
    Allreduce(loc_size, glb_size, 2, CommOp::SUM);
    SCTL_ASSERT(glb_size[0] == glb_size[1]);
    glb_rank[0] -= loc_size[0];
    glb_rank[1] -= loc_size[1];
    Vector<Long> glb_scan0(npes + 1);
    Vector<Long> glb_scan1(npes + 1);
    Allgather(glb_rank + 0, 1, glb_scan0.begin(), 1);
    Allgather(glb_rank + 1, 1, glb_scan1.begin(), 1);
    glb_scan0[npes] = glb_size[0];
    glb_scan1[npes] = glb_size[1];
    if (loc_size[0] != loc_size[1] || glb_rank[0] != glb_rank[1]) { // Repartition scatter_index
      scatter_index.ReInit(loc_size[0]);
      Vector<Long> send_dsp(npes + 1);
      Vector<Long> recv_dsp(npes + 1);
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= npes; i++) {
        send_dsp[i] = std::min(std::max(glb_scan0[i], glb_rank[1]), glb_rank[1] + loc_size[1]) - glb_rank[1];
        recv_dsp[i] = std::min(std::max(glb_scan1[i], glb_rank[0]), glb_rank[0] + loc_size[0]) - glb_rank[0];
      }
      // Long commCnt=0;
      Vector<Long> send_cnt(npes + 0);
      Vector<Long> recv_cnt(npes + 0);
#pragma omp parallel for schedule(static) // reduction(+:commCnt)
      for (Integer i = 0; i < npes; i++) {
        send_cnt[i] = send_dsp[i + 1] - send_dsp[i];
        recv_cnt[i] = recv_dsp[i + 1] - recv_dsp[i];
        // if(send_cnt[i] && i!=rank) commCnt++;
        // if(recv_cnt[i] && i!=rank) commCnt++;
      }
      void* mpi_req = Ialltoallv_sparse<Long>(scatter_index_.begin(), send_cnt.begin(), send_dsp.begin(), scatter_index.begin(), recv_cnt.begin(), recv_dsp.begin(), 0);
      Wait(mpi_req);
    } else {
      scatter_index.ReInit(scatter_index_.Dim(), (Iterator<Long>)scatter_index_.begin(), false);
    }
  }
  Vector<Long> glb_scan(npes);
  { // Global data size.
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&recv_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= recv_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }
  Vector<Pair_t> psorted(send_size);
  { // Sort scatter_index.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + send_size);
  }
  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      send_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i]) - send_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i + 1]) - send_indx.begin() : send_size);
      sendSz[i] = end - start;
      sendOff[i] = start;
    }
    Alltoall(sendSz.begin(), 1, recvSz.begin(), 1);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(recvOff[npes - 1] + recvSz[npes - 1] == recv_size);
    Alltoallv(send_indx.begin(), sendSz.begin(), sendOff.begin(), recv_indx.begin(), recvSz.begin(), recvOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      assert(recv_indx[i] >= glb_scan[rank]);
      recv_indx[i] -= glb_scan[rank];
      assert(recv_indx[i] < recv_size);
    }
  }
  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = psorted[i].data * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }
  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }
  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = recv_indx[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
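// Request pool used by the non-blocking operations above: NewReq() hands out a
// Vector<MPI_Request> (reusing one from the stack `req` when available) and
// DelReq() returns it to the pool; any vectors still in the pool are destroyed
// in ~Comm().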
#ifdef SCTL_HAVE_MPI
inline Vector<MPI_Request>* Comm::NewReq() const {
  if (req.empty()) req.push(new Vector<MPI_Request>);
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req.top();
  req.pop();
  return &request;
}
inline void Comm::DelReq(Vector<MPI_Request>* req_ptr) const {
  if (req_ptr) req.push(req_ptr);
}
#define HS_MPIDATATYPE(CTYPE, MPITYPE) \
  template <> class Comm::CommDatatype<CTYPE> { \
   public: \
    static MPI_Datatype value() { return MPITYPE; } \
    static MPI_Op sum() { return MPI_SUM; } \
    static MPI_Op min() { return MPI_MIN; } \
    static MPI_Op max() { return MPI_MAX; } \
  }
HS_MPIDATATYPE(short, MPI_SHORT);
HS_MPIDATATYPE(int, MPI_INT);
HS_MPIDATATYPE(long, MPI_LONG);
HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT);
HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED);
HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG);
HS_MPIDATATYPE(float, MPI_FLOAT);
HS_MPIDATATYPE(double, MPI_DOUBLE);
HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE);
HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT);
HS_MPIDATATYPE(char, MPI_CHAR);
HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR);
#undef HS_MPIDATATYPE
#endif
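// HyperQuickSort: distributed sort of arr_ into SortedElem. Each iteration
// picks a split key from randomly sampled splitters, exchanges elements with a
// partner rank across the split, merges locally, and then recurses on the two
// halves of the communicator; the expected cost is roughly
// O(((N/p) + log p) * (log(N/p) + log p)). The result is finally rebalanced
// with PartitionW. Without MPI this reduces to a local std::sort.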
template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const { // O( ((N/p)+log(p))*(log(N/p)+log(p)) )
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer npes, myrank, omp_p;
  { // Get comm size and rank.
    npes = Size();
    myrank = Rank();
    omp_p = omp_get_max_threads();
  }
  srand(myrank);
  Long totSize, nelem = arr_.Dim();
  { // Local and global sizes. O(log p)
    assert(nelem); // TODO: Check if this is needed.
    Allreduce<Long>(Ptr2ConstItr<Long>(&nelem, 1), Ptr2Itr<Long>(&totSize, 1), 1, CommOp::SUM);
  }
  if (npes == 1) { // SortedElem <--- local_sort(arr_)
    SortedElem = arr_;
    omp_par::merge_sort(SortedElem.begin(), SortedElem.begin() + nelem);
    return;
  }
  Vector<Type> arr;
  { // arr <-- local_sort(arr_)
    arr = arr_;
    omp_par::merge_sort(arr.begin(), arr.begin() + nelem);
  }
  Vector<Type> nbuff, nbuff_ext, rbuff, rbuff_ext; // Allocate memory.
  MPI_Comm comm = mpi_comm_; // Copy comm
  bool free_comm = false; // Flag to free comm.
  // Binary split and merge in each iteration.
  while (npes > 1 && totSize > 0) { // O(log p) iterations.
    Type split_key;
    Long totSize_new;
    { // Determine split_key. O( log(N/p) + log(p) )
      Integer glb_splt_count;
      Vector<Type> glb_splitters;
      { // Take random splitters. glb_splt_count = const = 100~1000
        Integer splt_count;
        { // Set splt_count. O( 1 ) -- Let p * splt_count = t
          splt_count = (100 * nelem) / totSize;
          if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
          if (splt_count > nelem) splt_count = nelem;
        }
        Vector<Type> splitters(splt_count);
        for (Integer i = 0; i < splt_count; i++) {
          splitters[i] = arr[rand() % nelem];
        }
        Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
        { // Set glb_splt_count, glb_splt_cnts, glb_splt_disp
          MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
          glb_splt_disp[0] = 0;
          omp_par::scan(glb_splt_cnts.begin(), glb_splt_disp.begin(), npes);
          glb_splt_count = glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1];
          SCTL_ASSERT(glb_splt_count);
        }
        { // Gather all splitters. O( log(p) )
          glb_splitters.ReInit(glb_splt_count);
          Vector<int> glb_splt_cnts_(npes), glb_splt_disp_(npes);
          for (Integer i = 0; i < npes; i++) {
            glb_splt_cnts_[i] = glb_splt_cnts[i];
            glb_splt_disp_[i] = glb_splt_disp[i];
          }
          MPI_Allgatherv((splt_count ? &splitters[0] : nullptr), splt_count, CommDatatype<Type>::value(), &glb_splitters[0], &glb_splt_cnts_[0], &glb_splt_disp_[0], CommDatatype<Type>::value(), comm);
        }
      }
      // Determine split key. O( log(N/p) + log(p) )
      Vector<Long> lrank(glb_splt_count);
      { // Compute local rank
#pragma omp parallel for schedule(static)
        for (Integer i = 0; i < glb_splt_count; i++) {
          lrank[i] = std::lower_bound(arr.begin(), arr.begin() + nelem, glb_splitters[i]) - arr.begin();
        }
      }
      Vector<Long> grank(glb_splt_count);
      { // Compute global rank
        MPI_Allreduce(&lrank[0], &grank[0], glb_splt_count, CommDatatype<Long>::value(), CommDatatype<Long>::sum(), comm);
      }
      { // Determine split_key, totSize_new
        ConstIterator<Long> split_disp = grank.begin();
        for (Integer i = 0; i < glb_splt_count; i++) {
          if (labs(grank[i] - totSize / 2) < labs(*split_disp - totSize / 2)) {
            split_disp = grank.begin() + i;
          }
        }
        split_key = glb_splitters[split_disp - grank.begin()];
        if (myrank <= (npes - 1) / 2)
          totSize_new = split_disp[0];
        else
          totSize_new = totSize - split_disp[0];
        // double err=(((double)*split_disp)/(totSize/2))-1.0;
        // if(fabs<double>(err)<0.01 || npes<=16) break;
        // else if(!myrank) std::cout<<err<<'\n';
      }
    }
    Integer split_id = (npes - 1) / 2;
    { // Split problem into two. O( N/p )
      Integer new_p0 = (myrank <= split_id ? 0 : split_id + 1);
      Integer cmp_p0 = (myrank > split_id ? 0 : split_id + 1);
      Integer partner;
      { // Set partner
        partner = myrank + cmp_p0 - new_p0;
        if (partner >= npes) partner = npes - 1;
        assert(partner >= 0);
      }
      bool extra_partner = (npes % 2 == 1 && npes - 1 == myrank);
      Long ssize = 0, lsize = 0;
      ConstIterator<Type> sbuff, lbuff;
      { // Set ssize, lsize, sbuff, lbuff
        Long split_indx = std::lower_bound(arr.begin(), arr.begin() + nelem, split_key) - arr.begin();
        ssize = (myrank > split_id ? split_indx : nelem - split_indx);
        sbuff = (myrank > split_id ? arr.begin() : arr.begin() + split_indx);
        lsize = (myrank <= split_id ? split_indx : nelem - split_indx);
        lbuff = (myrank <= split_id ? arr.begin() : arr.begin() + split_indx);
      }
      Long rsize = 0, ext_rsize = 0;
      { // Get rsize, ext_rsize
        Long ext_ssize = 0;
        MPI_Status status;
        MPI_Sendrecv(&ssize, 1, CommDatatype<Long>::value(), partner, 0, &rsize, 1, CommDatatype<Long>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(&ext_ssize, 1, CommDatatype<Long>::value(), split_id, 0, &ext_rsize, 1, CommDatatype<Long>::value(), split_id, 0, comm, &status);
      }
      { // Exchange data.
        rbuff.ReInit(rsize);
        rbuff_ext.ReInit(ext_rsize);
        MPI_Status status;
        MPI_Sendrecv((ssize ? &sbuff[0] : nullptr), ssize, CommDatatype<Type>::value(), partner, 0, (rsize ? &rbuff[0] : nullptr), rsize, CommDatatype<Type>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(nullptr, 0, CommDatatype<Type>::value(), split_id, 0, (ext_rsize ? &rbuff_ext[0] : nullptr), ext_rsize, CommDatatype<Type>::value(), split_id, 0, comm, &status);
      }
      Long nbuff_size = lsize + rsize + ext_rsize;
      { // nbuff <-- merge(lbuff, rbuff, rbuff_ext)
        nbuff.ReInit(lsize + rsize);
        omp_par::merge<ConstIterator<Type>>(lbuff, (lbuff + lsize), rbuff.begin(), rbuff.begin() + rsize, nbuff.begin(), omp_p, std::less<Type>());
        if (ext_rsize > 0 && nbuff.Dim() > 0) {
          nbuff_ext.ReInit(nbuff_size);
          omp_par::merge(nbuff.begin(), nbuff.begin() + (lsize + rsize), rbuff_ext.begin(), rbuff_ext.begin() + ext_rsize, nbuff_ext.begin(), omp_p, std::less<Type>());
          nbuff.Swap(nbuff_ext);
          nbuff_ext.ReInit(0);
        }
      }
      // Copy new data.
      totSize = totSize_new;
      nelem = nbuff_size;
      arr.Swap(nbuff);
      nbuff.ReInit(0);
    }
    { // Split comm. O( log(p) ) ??
      MPI_Comm scomm;
      MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
      if (free_comm) MPI_Comm_free(&comm);
      comm = scomm;
      free_comm = true;
      npes = (myrank <= split_id ? split_id + 1 : npes - split_id - 1);
      myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
    }
  }
  if (free_comm) MPI_Comm_free(&comm);
  SortedElem = arr;
  PartitionW<Type>(SortedElem);
#else
  SortedElem = arr_;
  std::sort(SortedElem.begin(), SortedElem.begin() + SortedElem.Dim());
#endif
}
} // end namespace