#include <type_traits>

#include SCTL_INCLUDE(ompUtils.hpp)
#include SCTL_INCLUDE(vector.hpp)

namespace SCTL_NAMESPACE {

inline Comm::Comm() {
#ifdef SCTL_HAVE_MPI
  Init(MPI_COMM_SELF);
#endif
}

inline Comm::Comm(const Comm& c) {
#ifdef SCTL_HAVE_MPI
  Init(c.mpi_comm_);
#endif
}

inline Comm Comm::Self() {
#ifdef SCTL_HAVE_MPI
  Comm comm_self(MPI_COMM_SELF);
  return comm_self;
#else
  Comm comm_self;
  return comm_self;
#endif
}

inline Comm Comm::World() {
#ifdef SCTL_HAVE_MPI
  Comm comm_world(MPI_COMM_WORLD);
  return comm_world;
#else
  Comm comm_self;
  return comm_self;
#endif
}

inline Comm& Comm::operator=(const Comm& c) {
#ifdef SCTL_HAVE_MPI
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&mpi_comm_);
  Init(c.mpi_comm_);
#endif
  return *this;
}

inline Comm::~Comm() {
#ifdef SCTL_HAVE_MPI
  while (!req.empty()) {
    delete (Vector<MPI_Request>*)req.top();
    req.pop();
  }
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&mpi_comm_);
#endif
}

inline Comm Comm::Split(Integer clr) const {
#ifdef SCTL_HAVE_MPI
  MPI_Comm new_comm;
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
  Comm c(new_comm);
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_free(&new_comm);
  return c;
#else
  Comm c;
  return c;
#endif
}

inline Integer Comm::Rank() const {
#ifdef SCTL_HAVE_MPI
  return mpi_rank_;
#else
  return 0;
#endif
}

inline Integer Comm::Size() const {
#ifdef SCTL_HAVE_MPI
  return mpi_size_;
#else
  return 1;
#endif
}

inline void Comm::Barrier() const {
#ifdef SCTL_HAVE_MPI
  MPI_Barrier(mpi_comm_);
#endif
}
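
// Non-blocking send of scount elements of sbuf to rank dest; returns an opaque
// request handle to be passed to Wait(), or nullptr if there is nothing to send.
// Without MPI, the send is paired with an Irecv of the same tag and the data is
// copied directly.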
template <class SType> void* Comm::Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!scount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);

  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[scount - 1]);
#ifndef NDEBUG
  MPI_Issend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#else
  MPI_Isend(&sbuf[0], scount, CommDatatype<SType>::value(), dest, tag, mpi_comm_, &request[0]);
#endif
  return &request;
#else
  auto it = recv_req.find(tag);
  if (it == recv_req.end()) {
    send_req.insert(std::pair<Integer, ConstIterator<char>>(tag, (ConstIterator<char>)sbuf));
  } else {
    memcopy(it->second, (ConstIterator<char>)sbuf, scount * sizeof(SType));
    recv_req.erase(it);
  }
  return nullptr;
#endif
}
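
// Non-blocking receive of rcount elements from rank source; returns a request
// handle to be passed to Wait(), or nullptr if there is nothing to receive.
// Without MPI, the receive is paired with an Isend of the same tag.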
template <class RType> void* Comm::Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag) const {
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!rcount) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(1);

  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[rcount - 1]);
  MPI_Irecv(&rbuf[0], rcount, CommDatatype<RType>::value(), source, tag, mpi_comm_, &request[0]);
  return &request;
#else
  auto it = send_req.find(tag);
  if (it == send_req.end()) {
    recv_req.insert(std::pair<Integer, Iterator<char>>(tag, (Iterator<char>)rbuf));
  } else {
    memcopy((Iterator<char>)rbuf, it->second, rcount * sizeof(RType));
    send_req.erase(it);
  }
  return nullptr;
#endif
}
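
// Block until the operation associated with req_ptr (as returned by Isend,
// Irecv, or Ialltoallv_sparse) has completed, then recycle the request object.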
inline void Comm::Wait(void* req_ptr) const {
#ifdef SCTL_HAVE_MPI
  if (req_ptr == nullptr) return;
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req_ptr;
  // std::vector<MPI_Status> status(request.Dim());
  if (request.Dim()) MPI_Waitall(request.Dim(), &request[0], MPI_STATUSES_IGNORE);  //&status[0]);
  DelReq(&request);
#endif
}

template <class SType, class RType> void Comm::Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
#ifdef SCTL_HAVE_MPI
  MPI_Allgather((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Vector<int> rcounts_(mpi_size_), rdispls_(mpi_size_);
  Long rcount_sum = 0;
#pragma omp parallel for schedule(static) reduction(+ : rcount_sum)
  for (Integer i = 0; i < mpi_size_; i++) {
    rcounts_[i] = rcounts[i];
    rdispls_[i] = rdispls[i];
    rcount_sum += rcounts[i];
  }
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount - 1]);
  }
  if (rcount_sum) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount_sum - 1]);
  }
  MPI_Allgatherv((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount_sum ? &rbuf[0] : nullptr), &rcounts_.begin()[0], &rdispls_.begin()[0], CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}

template <class SType, class RType> void Comm::Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (scount) {
    SCTL_UNUSED(sbuf[0]);
    SCTL_UNUSED(sbuf[scount * Size() - 1]);
  }
  if (rcount) {
    SCTL_UNUSED(rbuf[0]);
    SCTL_UNUSED(rbuf[rcount * Size() - 1]);
  }
  MPI_Alltoall((scount ? &sbuf[0] : nullptr), scount, CommDatatype<SType>::value(), (rcount ? &rbuf[0] : nullptr), rcount, CommDatatype<RType>::value(), mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, scount * sizeof(SType));
#endif
}
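
// Personalized all-to-all implemented with point-to-point Isend/Irecv pairs,
// posting requests only for ranks with non-zero counts; intended for sparse
// communication patterns. Returns a request handle to be passed to Wait().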
template <class SType, class RType> void* Comm::Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag) const {
  static_assert(std::is_trivially_copyable<SType>::value, "Data is not trivially copyable!");
  static_assert(std::is_trivially_copyable<RType>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer request_count = 0;
  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) request_count++;
    if (scounts[i]) request_count++;
  }
  if (!request_count) return nullptr;
  Vector<MPI_Request>& request = *NewReq();
  request.ReInit(request_count);
  Integer request_iter = 0;

  for (Integer i = 0; i < mpi_size_; i++) {
    if (rcounts[i]) {
      SCTL_UNUSED(rbuf[rdispls[i]]);
      SCTL_UNUSED(rbuf[rdispls[i] + rcounts[i] - 1]);
      MPI_Irecv(&rbuf[rdispls[i]], rcounts[i], CommDatatype<RType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  for (Integer i = 0; i < mpi_size_; i++) {
    if (scounts[i]) {
      SCTL_UNUSED(sbuf[sdispls[i]]);
      SCTL_UNUSED(sbuf[sdispls[i] + scounts[i] - 1]);
      MPI_Isend(&sbuf[sdispls[i]], scounts[i], CommDatatype<SType>::value(), i, tag, mpi_comm_, &request[request_iter]);
      request_iter++;
    }
  }
  return &request;
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(SType));
  return nullptr;
#endif
}
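
// Personalized all-to-all that switches between the sparse point-to-point
// implementation above and the vendor MPI_Alltoallv, based on the average
// number of communication partners per rank.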
template <class Type> void Comm::Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  { // Use Ialltoallv_sparse if the average connectivity is < 64
    Long connectivity = 0, glb_connectivity = 0;
#pragma omp parallel for schedule(static) reduction(+ : connectivity)
    for (Integer i = 0; i < mpi_size_; i++) {
      if (rcounts[i]) connectivity++;
    }
    Allreduce(Ptr2ConstItr<Long>(&connectivity, 1), Ptr2Itr<Long>(&glb_connectivity, 1), 1, CommOp::SUM);
    if (glb_connectivity < 64 * Size()) {
      void* mpi_req = Ialltoallv_sparse(sbuf, scounts, sdispls, rbuf, rcounts, rdispls, 0);
      Wait(mpi_req);
      return;
    }
  }

  { // Use vendor MPI_Alltoallv
    //#ifndef ALLTOALLV_FIX
    Vector<int> scnt, sdsp, rcnt, rdsp;
    scnt.ReInit(mpi_size_);
    sdsp.ReInit(mpi_size_);
    rcnt.ReInit(mpi_size_);
    rdsp.ReInit(mpi_size_);
    Long stotal = 0, rtotal = 0;
#pragma omp parallel for schedule(static) reduction(+ : stotal, rtotal)
    for (Integer i = 0; i < mpi_size_; i++) {
      scnt[i] = scounts[i];
      sdsp[i] = sdispls[i];
      rcnt[i] = rcounts[i];
      rdsp[i] = rdispls[i];
      stotal += scounts[i];
      rtotal += rcounts[i];
    }
    MPI_Alltoallv((stotal ? &sbuf[0] : nullptr), &scnt[0], &sdsp[0], CommDatatype<Type>::value(), (rtotal ? &rbuf[0] : nullptr), &rcnt[0], &rdsp[0], CommDatatype<Type>::value(), mpi_comm_);
    return;
    //#endif
  }
  // TODO: implement hypercube scheme
#else
  memcopy((Iterator<char>)(rbuf + rdispls[0]), (ConstIterator<char>)(sbuf + sdispls[0]), scounts[0] * sizeof(Type));
#endif
}

template <class Type> void Comm::Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Allreduce(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}

template <class Type> void Comm::Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  if (!count) return;
  MPI_Op mpi_op;
  switch (op) {
    case CommOp::SUM:
      mpi_op = CommDatatype<Type>::sum();
      break;
    case CommOp::MIN:
      mpi_op = CommDatatype<Type>::min();
      break;
    case CommOp::MAX:
      mpi_op = CommDatatype<Type>::max();
      break;
    default:
      mpi_op = MPI_OP_NULL;
      break;
  }
  SCTL_UNUSED(sbuf[0]);
  SCTL_UNUSED(sbuf[count - 1]);
  SCTL_UNUSED(rbuf[0]);
  SCTL_UNUSED(rbuf[count - 1]);
  MPI_Scan(&sbuf[0], &rbuf[0], count, CommDatatype<Type>::value(), mpi_op, mpi_comm_);
#else
  memcopy((Iterator<char>)rbuf, (ConstIterator<char>)sbuf, count * sizeof(Type));
#endif
}
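
// Repartition nodeList across ranks so that the total weight (wts_, or unit
// weights if wts_ is nullptr) is approximately balanced, preserving the
// global ordering of elements.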
template <class Type> void Comm::PartitionW(Vector<Type>& nodeList, const Vector<Long>* wts_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;
  Long nlSize = nodeList.Dim();

  Vector<Long> wts;
  Long localWt = 0;
  if (wts_ == nullptr) {  // Construct arrays of wts.
    wts.ReInit(nlSize);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < nlSize; i++) {
      wts[i] = 1;
    }
    localWt = nlSize;
  } else {
    wts.ReInit(nlSize, (Iterator<Long>)wts_->begin(), false);
#pragma omp parallel for reduction(+ : localWt)
    for (Long i = 0; i < nlSize; i++) {
      localWt += wts[i];
    }
  }

  Long off1 = 0, off2 = 0, totalWt = 0;
  { // compute the total weight of the problem ...
    Allreduce<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&totalWt, 1), 1, CommOp::SUM);
    Scan<Long>(Ptr2ConstItr<Long>(&localWt, 1), Ptr2Itr<Long>(&off2, 1), 1, CommOp::SUM);
    off1 = off2 - localWt;
  }

  Vector<Long> lscn;
  if (nlSize) {  // perform a local scan on the weights first ...
    lscn.ReInit(nlSize);
    lscn[0] = off1;
    omp_par::scan(wts.begin(), lscn.begin(), nlSize);
  }

  Vector<Long> sendSz, recvSz, sendOff, recvOff;
  sendSz.ReInit(npes);
  recvSz.ReInit(npes);
  sendOff.ReInit(npes);
  recvOff.ReInit(npes);
  sendSz.SetZero();

  if (nlSize > 0 && totalWt > 0) {  // Compute sendSz
    Long pid1 = (off1 * npes) / totalWt;
    Long pid2 = ((off2 + 1) * npes) / totalWt + 1;
    assert((totalWt * pid2) / npes >= off2);
    pid1 = (pid1 < 0 ? 0 : pid1);
    pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
    for (Integer i = pid1; i < pid2; i++) {
      Long wt1 = (totalWt * (i)) / npes;
      Long wt2 = (totalWt * (i + 1)) / npes;
      Long start = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt1, std::less<Long>()) - lscn.begin();
      Long end = std::lower_bound(lscn.begin(), lscn.begin() + nlSize, wt2, std::less<Long>()) - lscn.begin();
      if (i == 0) start = 0;
      if (i == npes - 1) end = nlSize;
      sendSz[i] = end - start;
    }
  } else {
    sendSz[0] = nlSize;
  }

  // Exchange sendSz, recvSz
  Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

  { // Compute sendOff, recvOff
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
  }

  // perform All2All ...
  Vector<Type> newNodes;
  newNodes.ReInit(recvSz[npes - 1] + recvOff[npes - 1]);
  void* mpi_req = Ialltoallv_sparse<Type>(nodeList.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
  Wait(mpi_req);

  // reset the pointer ...
  nodeList.Swap(newNodes);
}
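
// Repartition v so that this rank ends up with N logical elements (scaled by
// the common degrees-of-freedom factor of the vector), preserving the global
// ordering of elements.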
template <class Type> void Comm::PartitionN(Vector<Type>& v, Long N) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer rank = Rank();
  Integer np = Size();
  if (np == 1) return;

  Vector<Long> v_cnt(np), v_dsp(np + 1);
  Vector<Long> N_cnt(np), N_dsp(np + 1);
  { // Set v_cnt, v_dsp
    v_dsp[0] = 0;
    Long cnt = v.Dim();
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, v_cnt.begin(), 1);
    omp_par::scan(v_cnt.begin(), v_dsp.begin(), np);
    v_dsp[np] = v_cnt[np - 1] + v_dsp[np - 1];
  }
  { // Set N_cnt, N_dsp
    N_dsp[0] = 0;
    Long cnt = N;
    Allgather(Ptr2ConstItr<Long>(&cnt, 1), 1, N_cnt.begin(), 1);
    omp_par::scan(N_cnt.begin(), N_dsp.begin(), np);
    N_dsp[np] = N_cnt[np - 1] + N_dsp[np - 1];
  }
  { // Adjust for dof
    Long dof = (N_dsp[np] ? v_dsp[np] / N_dsp[np] : 0);
    assert(dof * N_dsp[np] == v_dsp[np]);
    if (dof == 0) return;

    if (dof != 1) {
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i < np; i++) N_cnt[i] *= dof;
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= np; i++) N_dsp[i] *= dof;
    }
  }

  Vector<Type> v_(N_cnt[rank]);
  { // Set v_
    Vector<Long> scnt(np), sdsp(np);
    Vector<Long> rcnt(np), rdsp(np);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < np; i++) {
      { // Set scnt
        Long n0 = N_dsp[i + 0];
        Long n1 = N_dsp[i + 1];
        if (n0 < v_dsp[rank + 0]) n0 = v_dsp[rank + 0];
        if (n1 < v_dsp[rank + 0]) n1 = v_dsp[rank + 0];
        if (n0 > v_dsp[rank + 1]) n0 = v_dsp[rank + 1];
        if (n1 > v_dsp[rank + 1]) n1 = v_dsp[rank + 1];
        scnt[i] = n1 - n0;
      }
      { // Set rcnt
        Long n0 = v_dsp[i + 0];
        Long n1 = v_dsp[i + 1];
        if (n0 < N_dsp[rank + 0]) n0 = N_dsp[rank + 0];
        if (n1 < N_dsp[rank + 0]) n1 = N_dsp[rank + 0];
        if (n0 > N_dsp[rank + 1]) n0 = N_dsp[rank + 1];
        if (n1 > N_dsp[rank + 1]) n1 = N_dsp[rank + 1];
        rcnt[i] = n1 - n0;
      }
    }
    sdsp[0] = 0;
    omp_par::scan(scnt.begin(), sdsp.begin(), np);
    rdsp[0] = 0;
    omp_par::scan(rcnt.begin(), rdsp.begin(), np);
    void* mpi_request = Ialltoallv_sparse(v.begin(), scnt.begin(), sdsp.begin(), v_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
  }
  v.Swap(v_);
}
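
// Repartition a globally sorted nodeList using one splitter per rank: after
// the call, each rank holds the elements that fall between its own splitter
// and the next rank's splitter.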
template <class Type> void Comm::PartitionS(Vector<Type>& nodeList, const Type& splitter) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  Integer npes = Size();
  if (npes == 1) return;

  Vector<Type> mins(npes);
  Allgather(Ptr2ConstItr<Type>(&splitter, 1), 1, mins.begin(), 1);

  Vector<Long> scnt(npes), sdsp(npes);
  Vector<Long> rcnt(npes), rdsp(npes);
  { // Compute scnt, sdsp
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sdsp[i] = std::lower_bound(nodeList.begin(), nodeList.begin() + nodeList.Dim(), mins[i]) - nodeList.begin();
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes - 1; i++) {
      scnt[i] = sdsp[i + 1] - sdsp[i];
    }
    scnt[npes - 1] = nodeList.Dim() - sdsp[npes - 1];
  }
  { // Compute rcnt, rdsp
    rdsp[0] = 0;
    Alltoall(scnt.begin(), 1, rcnt.begin(), 1);
    omp_par::scan(rcnt.begin(), rdsp.begin(), npes);
  }
  { // Redistribute nodeList
    Vector<Type> nodeList_(rdsp[npes - 1] + rcnt[npes - 1]);
    void* mpi_request = Ialltoallv_sparse(nodeList.begin(), scnt.begin(), sdsp.begin(), nodeList_.begin(), rcnt.begin(), rdsp.begin());
    Wait(mpi_request);
    nodeList.Swap(nodeList_);
  }
}
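
// Build a scatter index mapping the globally sorted order of 'key' back to the
// original global element indices (optionally repartitioned by the split keys
// in split_key_); the result can be applied with ScatterForward/ScatterReverse.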
template <class Type> void Comm::SortScatterIndex(const Vector<Type>& key, Vector<Long>& scatter_index, const Type* split_key_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Type, Long> Pair_t;
  Integer npes = Size();

  Vector<Pair_t> parray(key.Dim());
  { // Build global index.
    Long glb_dsp = 0;
    Long loc_size = key.Dim();
    Scan(Ptr2ConstItr<Long>(&loc_size, 1), Ptr2Itr<Long>(&glb_dsp, 1), 1, CommOp::SUM);
    glb_dsp -= loc_size;
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < loc_size; i++) {
      parray[i].key = key[i];
      parray[i].data = glb_dsp + i;
    }
  }

  Vector<Pair_t> psorted;
  HyperQuickSort(parray, psorted);

  if (npes > 1 && split_key_ != nullptr) {  // Partition data
    Vector<Type> split_key(npes);
    Allgather(Ptr2ConstItr<Type>(split_key_, 1), 1, split_key.begin(), 1);

    Vector<Long> sendSz(npes);
    Vector<Long> recvSz(npes);
    Vector<Long> sendOff(npes);
    Vector<Long> recvOff(npes);
    Long nlSize = psorted.Dim();
    sendSz.SetZero();

    if (nlSize > 0) {  // Compute sendSz
      // Determine processor range.
      Long pid1 = std::lower_bound(split_key.begin(), split_key.begin() + npes, psorted[0].key) - split_key.begin() - 1;
      Long pid2 = std::upper_bound(split_key.begin(), split_key.begin() + npes, psorted[nlSize - 1].key) - split_key.begin() + 1;
      pid1 = (pid1 < 0 ? 0 : pid1);
      pid2 = (pid2 > npes ? npes : pid2);
#pragma omp parallel for schedule(static)
      for (Integer i = pid1; i < pid2; i++) {
        Pair_t p1;
        p1.key = split_key[i];
        Pair_t p2;
        p2.key = split_key[i + 1 < npes ? i + 1 : i];
        Long start = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p1, std::less<Pair_t>()) - psorted.begin();
        Long end = std::lower_bound(psorted.begin(), psorted.begin() + nlSize, p2, std::less<Pair_t>()) - psorted.begin();
        if (i == 0) start = 0;
        if (i == npes - 1) end = nlSize;
        sendSz[i] = end - start;
      }
    }

    // Exchange sendSz, recvSz
    Alltoall<Long>(sendSz.begin(), 1, recvSz.begin(), 1);

    // compute offsets ...
    { // Compute sendOff, recvOff
      sendOff[0] = 0;
      omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
      recvOff[0] = 0;
      omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
      assert(sendOff[npes - 1] + sendSz[npes - 1] == nlSize);
    }

    // perform All2All ...
    Vector<Pair_t> newNodes(recvSz[npes - 1] + recvOff[npes - 1]);
    void* mpi_req = Ialltoallv_sparse<Pair_t>(psorted.begin(), sendSz.begin(), sendOff.begin(), newNodes.begin(), recvSz.begin(), recvOff.begin());
    Wait(mpi_req);

    // reset the pointer ...
    psorted.Swap(newNodes);
  }

  scatter_index.ReInit(psorted.Dim());
#pragma omp parallel for schedule(static)
  for (Long i = 0; i < psorted.Dim(); i++) {
    scatter_index[i] = psorted[i].data;
  }
}
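
// Forward-apply a scatter index produced by SortScatterIndex: data_ is
// redistributed and reordered so that the local i-th output element is the
// element whose original global index is scatter_index[i].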
template <class Type> void Comm::ScatterForward(Vector<Type>& data_, const Vector<Long>& scatter_index) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = scatter_index.Dim();
    StaticArray<Long, 2> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.
    data_dim = glb_size[0] / glb_size[1];
    SCTL_ASSERT(glb_size[0] == data_dim * glb_size[1]);
    send_size = data_.Dim() / data_dim;
  }

  if (npes == 1) {  // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = scatter_index[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> glb_scan;
  { // Global scan of data size.
    glb_scan.ReInit(npes);
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&send_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= send_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted;
  { // Sort scatter_index.
    psorted.ReInit(recv_size);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + recv_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      recv_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i]) - recv_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(recv_indx.begin(), recv_indx.begin() + recv_size, glb_scan[i + 1]) - recv_indx.begin() : recv_size);
      recvSz[i] = end - start;
      recvOff[i] = start;
    }

    Alltoall(recvSz.begin(), 1, sendSz.begin(), 1);
    sendOff[0] = 0;
    omp_par::scan(sendSz.begin(), sendOff.begin(), npes);
    assert(sendOff[npes - 1] + sendSz[npes - 1] == send_size);

    Alltoallv(recv_indx.begin(), recvSz.begin(), recvOff.begin(), send_indx.begin(), sendSz.begin(), sendOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      assert(send_indx[i] >= glb_scan[rank]);
      send_indx[i] -= glb_scan[rank];
      assert(send_indx[i] < send_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = send_indx[i] * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = psorted[i].data * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}
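
// Reverse-apply a scatter index: undo ScatterForward by sending each local
// element back to the global position given by its scatter index. loc_size_
// requests the output size on this rank; if the requested sizes do not add up
// to the global data size, a uniform split is used instead.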
template <class Type> void Comm::ScatterReverse(Vector<Type>& data_, const Vector<Long>& scatter_index_, Long loc_size_) const {
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
  typedef SortPair<Long, Long> Pair_t;
  Integer npes = Size(), rank = Rank();

  Long data_dim = 0;
  Long send_size = 0;
  Long recv_size = 0;
  { // Set data_dim, send_size, recv_size
    recv_size = loc_size_;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 3> loc_size;
    loc_size[0] = data_.Dim();
    loc_size[1] = scatter_index_.Dim();
    loc_size[2] = recv_size;
    Allreduce<Long>(loc_size, glb_size, 3, CommOp::SUM);
    if (glb_size[0] == 0 || glb_size[1] == 0) return;  // Nothing to be done.

    SCTL_ASSERT(glb_size[0] % glb_size[1] == 0);
    data_dim = glb_size[0] / glb_size[1];

    SCTL_ASSERT(loc_size[0] % data_dim == 0);
    send_size = loc_size[0] / data_dim;

    if (glb_size[0] != glb_size[2] * data_dim) {
      recv_size = (((rank + 1) * (glb_size[0] / data_dim)) / npes) - ((rank * (glb_size[0] / data_dim)) / npes);
    }
  }

  if (npes == 1) {  // Scatter directly
    Vector<Type> data;
    data.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = scatter_index_[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = data_[src_indx + j];
    }
    data_.Swap(data);
    return;
  }

  Vector<Long> scatter_index;
  {
    StaticArray<Long, 2> glb_rank;
    StaticArray<Long, 3> glb_size;
    StaticArray<Long, 2> loc_size;
    loc_size[0] = data_.Dim() / data_dim;
    loc_size[1] = scatter_index_.Dim();
    Scan<Long>(loc_size, glb_rank, 2, CommOp::SUM);
    Allreduce<Long>(loc_size, glb_size, 2, CommOp::SUM);
    SCTL_ASSERT(glb_size[0] == glb_size[1]);
    glb_rank[0] -= loc_size[0];
    glb_rank[1] -= loc_size[1];

    Vector<Long> glb_scan0(npes + 1);
    Vector<Long> glb_scan1(npes + 1);
    Allgather<Long>(glb_rank + 0, 1, glb_scan0.begin(), 1);
    Allgather<Long>(glb_rank + 1, 1, glb_scan1.begin(), 1);
    glb_scan0[npes] = glb_size[0];
    glb_scan1[npes] = glb_size[1];

    if (loc_size[0] != loc_size[1] || glb_rank[0] != glb_rank[1]) {  // Repartition scatter_index
      scatter_index.ReInit(loc_size[0]);

      Vector<Long> send_dsp(npes + 1);
      Vector<Long> recv_dsp(npes + 1);
#pragma omp parallel for schedule(static)
      for (Integer i = 0; i <= npes; i++) {
        send_dsp[i] = std::min(std::max(glb_scan0[i], glb_rank[1]), glb_rank[1] + loc_size[1]) - glb_rank[1];
        recv_dsp[i] = std::min(std::max(glb_scan1[i], glb_rank[0]), glb_rank[0] + loc_size[0]) - glb_rank[0];
      }

      // Long commCnt=0;
      Vector<Long> send_cnt(npes + 0);
      Vector<Long> recv_cnt(npes + 0);
#pragma omp parallel for schedule(static)  // reduction(+:commCnt)
      for (Integer i = 0; i < npes; i++) {
        send_cnt[i] = send_dsp[i + 1] - send_dsp[i];
        recv_cnt[i] = recv_dsp[i + 1] - recv_dsp[i];
        // if(send_cnt[i] && i!=rank) commCnt++;
        // if(recv_cnt[i] && i!=rank) commCnt++;
      }

      void* mpi_req = Ialltoallv_sparse<Long>(scatter_index_.begin(), send_cnt.begin(), send_dsp.begin(), scatter_index.begin(), recv_cnt.begin(), recv_dsp.begin(), 0);
      Wait(mpi_req);
    } else {
      scatter_index.ReInit(scatter_index_.Dim(), (Iterator<Long>)scatter_index_.begin(), false);
    }
  }

  Vector<Long> glb_scan(npes);
  { // Global data size.
    Long glb_rank = 0;
    Scan(Ptr2ConstItr<Long>(&recv_size, 1), Ptr2Itr<Long>(&glb_rank, 1), 1, CommOp::SUM);
    glb_rank -= recv_size;
    Allgather(Ptr2ConstItr<Long>(&glb_rank, 1), 1, glb_scan.begin(), 1);
  }

  Vector<Pair_t> psorted(send_size);
  { // Sort scatter_index.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      psorted[i].key = scatter_index[i];
      psorted[i].data = i;
    }
    omp_par::merge_sort(psorted.begin(), psorted.begin() + send_size);
  }

  Vector<Long> recv_indx(recv_size);
  Vector<Long> send_indx(send_size);
  Vector<Long> sendSz(npes);
  Vector<Long> sendOff(npes);
  Vector<Long> recvSz(npes);
  Vector<Long> recvOff(npes);
  { // Exchange send, recv indices.
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      send_indx[i] = psorted[i].key;
    }
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      Long start = std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i]) - send_indx.begin();
      Long end = (i + 1 < npes ? std::lower_bound(send_indx.begin(), send_indx.begin() + send_size, glb_scan[i + 1]) - send_indx.begin() : send_size);
      sendSz[i] = end - start;
      sendOff[i] = start;
    }

    Alltoall(sendSz.begin(), 1, recvSz.begin(), 1);
    recvOff[0] = 0;
    omp_par::scan(recvSz.begin(), recvOff.begin(), npes);
    assert(recvOff[npes - 1] + recvSz[npes - 1] == recv_size);

    Alltoallv(send_indx.begin(), sendSz.begin(), sendOff.begin(), recv_indx.begin(), recvSz.begin(), recvOff.begin());
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      assert(recv_indx[i] >= glb_scan[rank]);
      recv_indx[i] -= glb_scan[rank];
      assert(recv_indx[i] < recv_size);
    }
  }

  Vector<Type> send_buff;
  { // Prepare send buffer
    send_buff.ReInit(send_size * data_dim);
    ConstIterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < send_size; i++) {
      Long src_indx = psorted[i].data * data_dim;
      Long trg_indx = i * data_dim;
      for (Long j = 0; j < data_dim; j++) send_buff[trg_indx + j] = data[src_indx + j];
    }
  }

  Vector<Type> recv_buff;
  { // All2Allv
    recv_buff.ReInit(recv_size * data_dim);
#pragma omp parallel for schedule(static)
    for (Integer i = 0; i < npes; i++) {
      sendSz[i] *= data_dim;
      sendOff[i] *= data_dim;
      recvSz[i] *= data_dim;
      recvOff[i] *= data_dim;
    }
    Alltoallv(send_buff.begin(), sendSz.begin(), sendOff.begin(), recv_buff.begin(), recvSz.begin(), recvOff.begin());
  }

  { // Build output data.
    data_.ReInit(recv_size * data_dim);
    Iterator<Type> data = data_.begin();
#pragma omp parallel for schedule(static)
    for (Long i = 0; i < recv_size; i++) {
      Long src_indx = i * data_dim;
      Long trg_indx = recv_indx[i] * data_dim;
      for (Long j = 0; j < data_dim; j++) data[trg_indx + j] = recv_buff[src_indx + j];
    }
  }
}

#ifdef SCTL_HAVE_MPI
inline Vector<MPI_Request>* Comm::NewReq() const {
  if (req.empty()) req.push(new Vector<MPI_Request>);
  Vector<MPI_Request>& request = *(Vector<MPI_Request>*)req.top();
  req.pop();
  return &request;
}

inline void Comm::Init(const MPI_Comm mpi_comm) {
#pragma omp critical(SCTL_COMM_DUP)
  MPI_Comm_dup(mpi_comm, &mpi_comm_);
  MPI_Comm_rank(mpi_comm_, &mpi_rank_);
  MPI_Comm_size(mpi_comm_, &mpi_size_);
}

inline void Comm::DelReq(Vector<MPI_Request>* req_ptr) const {
  if (req_ptr) req.push(req_ptr);
}

#define SCTL_HS_MPIDATATYPE(CTYPE, MPITYPE)           \
  template <> class Comm::CommDatatype<CTYPE> {       \
   public:                                            \
    static MPI_Datatype value() { return MPITYPE; }   \
    static MPI_Op sum() { return MPI_SUM; }           \
    static MPI_Op min() { return MPI_MIN; }           \
    static MPI_Op max() { return MPI_MAX; }           \
  }

SCTL_HS_MPIDATATYPE(short, MPI_SHORT);
SCTL_HS_MPIDATATYPE(int, MPI_INT);
SCTL_HS_MPIDATATYPE(long, MPI_LONG);
SCTL_HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT);
SCTL_HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED);
SCTL_HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG);
SCTL_HS_MPIDATATYPE(float, MPI_FLOAT);
SCTL_HS_MPIDATATYPE(double, MPI_DOUBLE);
SCTL_HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE);
SCTL_HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT);
SCTL_HS_MPIDATATYPE(char, MPI_CHAR);
SCTL_HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR);
#undef SCTL_HS_MPIDATATYPE
#endif
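
// Distributed sort (hyperquicksort): locally sort, then repeatedly pick a
// near-median splitter, exchange halves with a partner rank, and split the
// communicator until each group has a single rank; finally rebalance the
// sorted output with PartitionW.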
template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector<Type>& SortedElem) const {  // O( ((N/p)+log(p))*(log(N/p)+log(p)) )
  static_assert(std::is_trivially_copyable<Type>::value, "Data is not trivially copyable!");
#ifdef SCTL_HAVE_MPI
  Integer npes, myrank, omp_p;
  { // Get comm size and rank.
    npes = Size();
    myrank = Rank();
    omp_p = omp_get_max_threads();
  }
  srand(myrank);

  Long totSize;
  { // Local and global sizes. O(log p)
    Long nelem = arr_.Dim();
    Allreduce<Long>(Ptr2ConstItr<Long>(&nelem, 1), Ptr2Itr<Long>(&totSize, 1), 1, CommOp::SUM);
  }

  if (npes == 1) {  // SortedElem <--- local_sort(arr_)
    SortedElem = arr_;
    omp_par::merge_sort(SortedElem.begin(), SortedElem.end());
    return;
  }

  Vector<Type> arr;
  { // arr <-- local_sort(arr_)
    arr = arr_;
    omp_par::merge_sort(arr.begin(), arr.end());
  }

  Vector<Type> nbuff, nbuff_ext, rbuff, rbuff_ext;  // Allocate memory.
  MPI_Comm comm = mpi_comm_;  // Copy comm
  bool free_comm = false;  // Flag to free comm.

  // Binary split and merge in each iteration.
  while (npes > 1 && totSize > 0) {  // O(log p) iterations.
    Type split_key;
    Long totSize_new;
    { // Determine split_key. O( log(N/p) + log(p) )
      Integer glb_splt_count;
      Vector<Type> glb_splitters;
      { // Take random splitters. glb_splt_count = const = 100~1000
        Integer splt_count;
        Long nelem = arr.Dim();
        { // Set splt_count. O( 1 ) -- Let p * splt_count = t
          splt_count = (100 * nelem) / totSize;
          if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
          if (splt_count > nelem) splt_count = nelem;
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          if (!glb_splt_count) splt_count = std::min<Long>(1, nelem);
          MPI_Allreduce(&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), comm);
          SCTL_ASSERT(glb_splt_count);
        }

        Vector<Type> splitters(splt_count);
        for (Integer i = 0; i < splt_count; i++) {
          splitters[i] = arr[rand() % nelem];
        }

        Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
        { // Set glb_splt_cnts, glb_splt_disp
          MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
          glb_splt_disp[0] = 0;
          omp_par::scan(glb_splt_cnts.begin(), glb_splt_disp.begin(), npes);
          SCTL_ASSERT(glb_splt_count == glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1]);
        }

        { // Gather all splitters. O( log(p) )
          glb_splitters.ReInit(glb_splt_count);
          Vector<int> glb_splt_cnts_(npes), glb_splt_disp_(npes);
          for (Integer i = 0; i < npes; i++) {
            glb_splt_cnts_[i] = glb_splt_cnts[i];
            glb_splt_disp_[i] = glb_splt_disp[i];
          }
          MPI_Allgatherv((splt_count ? &splitters[0] : nullptr), splt_count, CommDatatype<Type>::value(), &glb_splitters[0], &glb_splt_cnts_[0], &glb_splt_disp_[0], CommDatatype<Type>::value(), comm);
        }
      }

      // Determine split key. O( log(N/p) + log(p) )
      Vector<Long> lrank(glb_splt_count);
      { // Compute local rank
#pragma omp parallel for schedule(static)
        for (Integer i = 0; i < glb_splt_count; i++) {
          lrank[i] = std::lower_bound(arr.begin(), arr.end(), glb_splitters[i]) - arr.begin();
        }
      }

      Vector<Long> grank(glb_splt_count);
      { // Compute global rank
        MPI_Allreduce(&lrank[0], &grank[0], glb_splt_count, CommDatatype<Long>::value(), CommDatatype<Long>::sum(), comm);
      }

      { // Determine split_key, totSize_new
        Integer splitter_idx = 0;
        for (Integer i = 0; i < glb_splt_count; i++) {
          if (labs(grank[i] - totSize / 2) < labs(grank[splitter_idx] - totSize / 2)) {
            splitter_idx = i;
          }
        }
        split_key = glb_splitters[splitter_idx];

        if (myrank <= (npes - 1) / 2)
          totSize_new = grank[splitter_idx];
        else
          totSize_new = totSize - grank[splitter_idx];

        // double err=(((double)grank[splitter_idx])/(totSize/2))-1.0;
        // if(fabs<double>(err)<0.01 || npes<=16) break;
        // else if(!myrank) std::cout<<err<<'\n';
      }
    }

    Integer split_id = (npes - 1) / 2;
    { // Split problem into two. O( N/p )
      Integer partner;
      { // Set partner
        partner = myrank + (split_id + 1) * (myrank <= split_id ? 1 : -1);
        if (partner >= npes) partner = npes - 1;
        assert(partner >= 0);
      }
      bool extra_partner = (npes % 2 == 1 && myrank == npes - 1);

      Long ssize = 0, lsize = 0;
      ConstIterator<Type> sbuff, lbuff;
      { // Set ssize, lsize, sbuff, lbuff
        Long split_indx = std::lower_bound(arr.begin(), arr.end(), split_key) - arr.begin();
        ssize = (myrank > split_id ? split_indx : arr.Dim() - split_indx);
        sbuff = (myrank > split_id ? arr.begin() : arr.begin() + split_indx);
        lsize = (myrank <= split_id ? split_indx : arr.Dim() - split_indx);
        lbuff = (myrank <= split_id ? arr.begin() : arr.begin() + split_indx);
      }

      Long rsize = 0, ext_rsize = 0;
      { // Get rsize, ext_rsize
        Long ext_ssize = 0;
        MPI_Status status;
        MPI_Sendrecv(&ssize, 1, CommDatatype<Long>::value(), partner, 0, &rsize, 1, CommDatatype<Long>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(&ext_ssize, 1, CommDatatype<Long>::value(), split_id, 0, &ext_rsize, 1, CommDatatype<Long>::value(), split_id, 0, comm, &status);
      }

      { // Exchange data.
        rbuff.ReInit(rsize);
        rbuff_ext.ReInit(ext_rsize);
        MPI_Status status;
        MPI_Sendrecv((ssize ? &sbuff[0] : nullptr), ssize, CommDatatype<Type>::value(), partner, 0, (rsize ? &rbuff[0] : nullptr), rsize, CommDatatype<Type>::value(), partner, 0, comm, &status);
        if (extra_partner) MPI_Sendrecv(nullptr, 0, CommDatatype<Type>::value(), split_id, 0, (ext_rsize ? &rbuff_ext[0] : nullptr), ext_rsize, CommDatatype<Type>::value(), split_id, 0, comm, &status);
      }

      { // nbuff <-- merge(lbuff, rbuff, rbuff_ext)
        nbuff.ReInit(lsize + rsize);
        omp_par::merge<ConstIterator<Type>>(lbuff, (lbuff + lsize), rbuff.begin(), rbuff.begin() + rsize, nbuff.begin(), omp_p, std::less<Type>());
        if (ext_rsize > 0) {
          if (nbuff.Dim() > 0) {
            nbuff_ext.ReInit(lsize + rsize + ext_rsize);
            omp_par::merge(nbuff.begin(), nbuff.begin() + (lsize + rsize), rbuff_ext.begin(), rbuff_ext.begin() + ext_rsize, nbuff_ext.begin(), omp_p, std::less<Type>());
            nbuff.Swap(nbuff_ext);
            nbuff_ext.ReInit(0);
          } else {
            nbuff.Swap(rbuff_ext);
          }
        }
      }

      // Copy new data.
      totSize = totSize_new;
      arr.Swap(nbuff);
    }

    { // Split comm. O( log(p) ) ??
      MPI_Comm scomm;
#pragma omp critical(SCTL_COMM_DUP)
      MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
#pragma omp critical(SCTL_COMM_DUP)
      if (free_comm) MPI_Comm_free(&comm);
      comm = scomm;
      free_comm = true;

      npes = (myrank <= split_id ? split_id + 1 : npes - split_id - 1);
      myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
    }
  }
#pragma omp critical(SCTL_COMM_DUP)
  if (free_comm) MPI_Comm_free(&comm);

  SortedElem = arr;
  PartitionW<Type>(SortedElem);
#else
  SortedElem = arr_;
  std::sort(SortedElem.begin(), SortedElem.begin() + SortedElem.Dim());
#endif
}

} // end namespace