tree.txx 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019
  1. #include <vector>
  2. namespace SCTL_NAMESPACE {
  3. template <Integer DIM> constexpr Integer Tree<DIM>::Dim() {
  4. return DIM;
  5. }
  6. template <Integer DIM> Tree<DIM>::Tree(const Comm& comm_) : comm(comm_) {
  7. Integer rank = comm.Rank();
  8. Integer np = comm.Size();
  9. Vector<double> coord;
  10. { // Set coord
  11. Long N0 = 1;
  12. while (sctl::pow<DIM,Long>(N0) < np) N0++;
  13. Long N = sctl::pow<DIM,Long>(N0);
  14. Long start = N * (rank+0) / np;
  15. Long end = N * (rank+1) / np;
  16. coord.ReInit((end-start)*DIM);
  17. for (Long i = start; i < end; i++) {
  18. Long idx = i;
  19. for (Integer k = 0; k < DIM; k++) {
  20. coord[(i-start)*DIM+k] = (idx % N0) / (double)N0;
  21. idx /= N0;
  22. }
  23. }
  24. }
  25. this->UpdateRefinement(coord);
  26. }
  27. template <Integer DIM> Tree<DIM>::~Tree() {
  28. #ifdef SCTL_MEMDEBUG
  29. for (auto& pair : node_data) {
  30. SCTL_ASSERT(node_cnt.find(pair.first) != node_cnt.end());
  31. }
  32. #endif
  33. }
  34. template <Integer DIM> const Vector<Morton<DIM>>& Tree<DIM>::GetPartitionMID() const {
  35. return mins;
  36. }
  37. template <Integer DIM> const Vector<Morton<DIM>>& Tree<DIM>::GetNodeMID() const {
  38. return node_mid;
  39. }
  40. template <Integer DIM> const Vector<typename Tree<DIM>::NodeAttr>& Tree<DIM>::GetNodeAttr() const {
  41. return node_attr;
  42. }
  43. template <Integer DIM> const Vector<typename Tree<DIM>::NodeLists>& Tree<DIM>::GetNodeLists() const {
  44. return node_lst;
  45. }
  46. template <Integer DIM> const Comm& Tree<DIM>::GetComm() const {
  47. return comm;
  48. }
  49. template <Integer DIM> template <class Real> void Tree<DIM>::UpdateRefinement(const Vector<Real>& coord, Long M, bool balance21, bool periodic) {
  50. Integer np = comm.Size();
  51. Integer rank = comm.Rank();
  52. Vector<Morton<DIM>> node_mid_orig;
  53. Long start_idx_orig, end_idx_orig;
  54. if (mins.Dim()) { // Set start_idx_orig, end_idx_orig
  55. start_idx_orig = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  56. end_idx_orig = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  57. node_mid_orig.ReInit(end_idx_orig - start_idx_orig, node_mid.begin() + start_idx_orig, true);
  58. } else {
  59. start_idx_orig = 0;
  60. end_idx_orig = 0;
  61. }
  62. auto coarsest_ancestor_mid = [](const Morton<DIM>& m0) {
  63. Morton<DIM> md;
  64. Integer d0 = m0.Depth();
  65. for (Integer d = 0; d <= d0; d++) {
  66. md = m0.Ancestor(d);
  67. if (md.Ancestor(d0) == m0) break;
  68. }
  69. return md;
  70. };
  71. Morton<DIM> pt_mid0;
  72. Vector<Morton<DIM>> pt_mid;
  73. { // Construct sorted pt_mid
  74. Long Npt = coord.Dim() / DIM;
  75. pt_mid.ReInit(Npt);
  76. for (Long i = 0; i < Npt; i++) {
  77. pt_mid[i] = Morton<DIM>(coord.begin() + i*DIM);
  78. }
  79. Vector<Morton<DIM>> sorted_mid;
  80. comm.HyperQuickSort(pt_mid, sorted_mid);
  81. pt_mid.Swap(sorted_mid);
  82. SCTL_ASSERT(pt_mid.Dim());
  83. pt_mid0 = pt_mid[0];
  84. }
  85. { // Update M = global_min(pt_mid.Dim(), M)
  86. Long M0, M1, Npt = pt_mid.Dim();
  87. comm.Allreduce(Ptr2ConstItr<Long>(&M,1), Ptr2Itr<Long>(&M0,1), 1, Comm::CommOp::MIN);
  88. comm.Allreduce(Ptr2ConstItr<Long>(&Npt,1), Ptr2Itr<Long>(&M1,1), 1, Comm::CommOp::MIN);
  89. M = std::min(M0,M1);
  90. SCTL_ASSERT(M > 0);
  91. }
  92. { // pt_mid <-- [M points from rank-1; pt_mid; M points from rank+1]
  93. Long send_size0 = (rank+1<np ? M : 0);
  94. Long send_size1 = (rank > 0 ? M : 0);
  95. Long recv_size0 = (rank > 0 ? M : 0);
  96. Long recv_size1 = (rank+1<np ? M : 0);
  97. Vector<Morton<DIM>> pt_mid_(recv_size0 + pt_mid.Dim() + recv_size1);
  98. memcopy(pt_mid_.begin()+recv_size0, pt_mid.begin(), pt_mid.Dim());
  99. void* recv_req0 = comm.Irecv(pt_mid_.begin(), recv_size0, (rank+np-1)%np, 0);
  100. void* recv_req1 = comm.Irecv(pt_mid_.begin() + recv_size0 + pt_mid.Dim(), recv_size1, (rank+1)%np, 1);
  101. void* send_req0 = comm.Isend(pt_mid .begin() + pt_mid.Dim() - send_size0, send_size0, (rank+1)%np, 0);
  102. void* send_req1 = comm.Isend(pt_mid .begin(), send_size1, (rank+np-1)%np, 1);
  103. comm.Wait(recv_req0);
  104. comm.Wait(recv_req1);
  105. comm.Wait(send_req0);
  106. comm.Wait(send_req1);
  107. pt_mid.Swap(pt_mid_);
  108. }
  109. { // Build linear MortonID tree from pt_mid
  110. node_mid.ReInit(0);
  111. Long idx = 0;
  112. Morton<DIM> m0;
  113. Morton<DIM> mend = Morton<DIM>().Next();
  114. while (m0 < mend) {
  115. Integer d = m0.Depth();
  116. Morton<DIM> m1 = (idx + M < pt_mid.Dim() ? pt_mid[idx+M] : Morton<DIM>().Next());
  117. while (d < Morton<DIM>::MAX_DEPTH && m0.Ancestor(d) == m1.Ancestor(d)) {
  118. node_mid.PushBack(m0.Ancestor(d));
  119. d++;
  120. }
  121. m0 = m0.Ancestor(d);
  122. node_mid.PushBack(m0);
  123. m0 = m0.Next();
  124. idx = std::lower_bound(pt_mid.begin(), pt_mid.end(), m0) - pt_mid.begin();
  125. }
  126. }
  127. { // Set mins
  128. mins.ReInit(np);
  129. Long min_idx = std::lower_bound(node_mid.begin(), node_mid.end(), pt_mid0) - node_mid.begin() - 1;
  130. if (!rank || min_idx < 0) min_idx = 0;
  131. Morton<DIM> m0 = coarsest_ancestor_mid(node_mid[min_idx]);
  132. comm.Allgather(Ptr2ConstItr<Morton<DIM>>(&m0,1), 1, mins.begin(), 1);
  133. }
  134. if (balance21) { // 2:1 balance refinement // TODO: optimize
  135. Vector<Morton<DIM>> parent_mid;
  136. { // add balancing Morton IDs
  137. Vector<std::set<Morton<DIM>>> parent_mid_set(Morton<DIM>::MAX_DEPTH+1);
  138. Vector<Morton<DIM>> nlst;
  139. for (const auto& m0 : node_mid) {
  140. Integer d0 = m0.Depth();
  141. parent_mid_set[m0.Depth()].insert(m0.Ancestor(d0-1));
  142. }
  143. for (Integer d = Morton<DIM>::MAX_DEPTH; d > 0; d--) {
  144. for (const auto& m : parent_mid_set[d]) {
  145. m.NbrList(nlst, d-1, periodic);
  146. parent_mid_set[d-1].insert(nlst.begin(), nlst.end());
  147. parent_mid.PushBack(m);
  148. }
  149. }
  150. }
  151. Vector<Morton<DIM>> parent_mid_sorted;
  152. { // sort and repartition
  153. comm.HyperQuickSort(parent_mid, parent_mid_sorted);
  154. comm.PartitionS(parent_mid_sorted, mins[comm.Rank()]);
  155. }
  156. Vector<Morton<DIM>> tmp_mid;
  157. { // add children
  158. Vector<Morton<DIM>> clst;
  159. tmp_mid.PushBack(Morton<DIM>()); // include root node
  160. for (Long i = 0; i < parent_mid_sorted.Dim(); i++) {
  161. if (i+1 == parent_mid_sorted.Dim() || parent_mid_sorted[i] != parent_mid_sorted[i+1]) {
  162. const auto& m = parent_mid_sorted[i];
  163. tmp_mid.PushBack(m);
  164. m.Children(clst);
  165. for (const auto& c : clst) tmp_mid.PushBack(c);
  166. }
  167. }
  168. auto insert_ancestor_children = [](Vector<Morton<DIM>>& mvec, const Morton<DIM>& m0) {
  169. Integer d0 = m0.Depth();
  170. Vector<Morton<DIM>> clst;
  171. for (Integer d = 0; d < d0; d++) {
  172. m0.Ancestor(d).Children(clst);
  173. for (const auto& m : clst) mvec.PushBack(m);
  174. }
  175. };
  176. insert_ancestor_children(tmp_mid, mins[rank]);
  177. omp_par::merge_sort(tmp_mid.begin(), tmp_mid.end());
  178. }
  179. node_mid.ReInit(0);
  180. for (Long i = 0; i < tmp_mid.Dim(); i++) { // remove duplicates
  181. if (i+1 == tmp_mid.Dim() || tmp_mid[i] != tmp_mid[i+1]) {
  182. node_mid.PushBack(tmp_mid[i]);
  183. }
  184. }
  185. }
  186. { // Add place-holder for ghost nodes
  187. Long start_idx, end_idx;
  188. { // Set start_idx, end_idx
  189. start_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  190. end_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  191. }
  192. { // Set user_mid, user_cnt
  193. Vector<SortPair<Long,Morton<DIM>>> user_node_lst;
  194. Vector<Morton<DIM>> nlst;
  195. std::set<Long> user_procs;
  196. for (Long i = start_idx; i < end_idx; i++) {
  197. Morton<DIM> m0 = node_mid[i];
  198. Integer d0 = m0.Depth();
  199. m0.NbrList(nlst, std::max<Integer>(d0-2,0), periodic);
  200. user_procs.clear();
  201. for (const auto& m : nlst) {
  202. Morton<DIM> m_start = m.DFD();
  203. Morton<DIM> m_end = m.Next();
  204. Integer p_start = std::lower_bound(mins.begin(), mins.end(), m_start) - mins.begin() - 1;
  205. Integer p_end = std::lower_bound(mins.begin(), mins.end(), m_end ) - mins.begin();
  206. SCTL_ASSERT(0 <= p_start);
  207. SCTL_ASSERT(p_start < p_end);
  208. SCTL_ASSERT(p_end <= np);
  209. for (Long p = p_start; p < p_end; p++) {
  210. if (p != rank) user_procs.insert(p);
  211. }
  212. }
  213. for (const auto p : user_procs) {
  214. SortPair<Long,Morton<DIM>> pair;
  215. pair.key = p;
  216. pair.data = m0;
  217. user_node_lst.PushBack(pair);
  218. }
  219. }
  220. omp_par::merge_sort(user_node_lst.begin(), user_node_lst.end());
  221. user_cnt.ReInit(np);
  222. user_mid.ReInit(user_node_lst.Dim());
  223. for (Integer i = 0; i < np; i++) {
  224. SortPair<Long,Morton<DIM>> pair_start, pair_end;
  225. pair_start.key = i;
  226. pair_end.key = i+1;
  227. Long cnt_start = std::lower_bound(user_node_lst.begin(), user_node_lst.end(), pair_start) - user_node_lst.begin();
  228. Long cnt_end = std::lower_bound(user_node_lst.begin(), user_node_lst.end(), pair_end ) - user_node_lst.begin();
  229. user_cnt[i] = cnt_end - cnt_start;
  230. for (Long j = cnt_start; j < cnt_end; j++) {
  231. user_mid[j] = user_node_lst[j].data;
  232. }
  233. std::sort(user_mid.begin() + cnt_start, user_mid.begin() + cnt_end);
  234. }
  235. }
  236. Vector<Morton<DIM>> ghost_mid;
  237. { // SendRecv user_mid
  238. const Vector<Long>& send_cnt = user_cnt;
  239. Vector<Long> send_dsp(np);
  240. scan(send_dsp, send_cnt);
  241. Vector<Long> recv_cnt(np), recv_dsp(np);
  242. comm.Alltoall(send_cnt.begin(), 1, recv_cnt.begin(), 1);
  243. scan(recv_dsp, recv_cnt);
  244. const Vector<Morton<DIM>>& send_mid = user_mid;
  245. Long Nsend = send_dsp[np-1] + send_cnt[np-1];
  246. Long Nrecv = recv_dsp[np-1] + recv_cnt[np-1];
  247. SCTL_ASSERT(send_mid.Dim() == Nsend);
  248. ghost_mid.ReInit(Nrecv);
  249. comm.Alltoallv(send_mid.begin(), send_cnt.begin(), send_dsp.begin(), ghost_mid.begin(), recv_cnt.begin(), recv_dsp.begin());
  250. }
  251. { // Update node_mid <-- ghost_mid + node_mid
  252. Vector<Morton<DIM>> new_mid(end_idx-start_idx + ghost_mid.Dim());
  253. Long Nsplit = std::lower_bound(ghost_mid.begin(), ghost_mid.end(), mins[rank]) - ghost_mid.begin();
  254. for (Long i = 0; i < Nsplit; i++) {
  255. new_mid[i] = ghost_mid[i];
  256. }
  257. for (Long i = 0; i < end_idx - start_idx; i++) {
  258. new_mid[Nsplit + i] = node_mid[start_idx + i];
  259. }
  260. for (Long i = Nsplit; i < ghost_mid.Dim(); i++) {
  261. new_mid[end_idx - start_idx + i] = ghost_mid[i];
  262. }
  263. node_mid.Swap(new_mid);
  264. }
  265. }
  266. { // Set node_mid, node_attr
  267. Morton<DIM> m0 = (rank ? mins[rank] : Morton<DIM>() );
  268. Morton<DIM> m1 = (rank+1<np ? mins[rank+1] : Morton<DIM>().Next());
  269. Long Nnodes = node_mid.Dim();
  270. node_attr.ReInit(Nnodes);
  271. for (Long i = 0; i < Nnodes; i++) {
  272. node_attr[i].Leaf = !(i+1<Nnodes && node_mid[i].isAncestor(node_mid[i+1]));
  273. node_attr[i].Ghost = (node_mid[i] < m0 || node_mid[i] >= m1);
  274. }
  275. }
  276. { // Set node_lst
  277. static constexpr Integer MAX_CHILD = (1u << DIM);
  278. static constexpr Integer MAX_NBRS = sctl::pow<DIM,Integer>(3);
  279. Long Nnodes = node_mid.Dim();
  280. node_lst.ReInit(Nnodes);
  281. Vector<Long> ancestors(Morton<DIM>::MAX_DEPTH);
  282. Vector<Long> child_cnt(Morton<DIM>::MAX_DEPTH);
  283. #pragma omp parallel for schedule(static)
  284. for (Long i = 0; i < Nnodes; i++) {
  285. node_lst[i].p2n = -1;
  286. node_lst[i].parent = -1;
  287. for (Integer j = 0; j < MAX_CHILD; j++) node_lst[i].child[j] = -1;
  288. for (Integer j = 0; j < MAX_NBRS; j++) node_lst[i].nbr[j] = -1;
  289. }
  290. for (Long i = 0; i < Nnodes; i++) { // Set parent_lst, child_lst_
  291. Integer depth = node_mid[i].Depth();
  292. ancestors[depth] = i;
  293. child_cnt[depth] = 0;
  294. if (depth) {
  295. Long p = ancestors[depth-1];
  296. Long& c = child_cnt[depth-1];
  297. node_lst[i].parent = p;
  298. node_lst[p].child[c] = i;
  299. node_lst[p].p2n = c;
  300. c++;
  301. }
  302. }
  303. Vector<Morton<DIM>> nlst;
  304. for (Long i = 0; i < Nnodes; i++) { // Set nbr-list // TODO: optimize this
  305. node_mid[i].NbrList(nlst, node_mid[i].Depth(), periodic);
  306. for (Long k = 0; k < nlst.Dim(); k++) {
  307. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), nlst[k]) - node_mid.begin();
  308. if (idx < node_mid.Dim() && node_mid[idx] == nlst[k]) node_lst[i].nbr[k] = idx;
  309. }
  310. }
  311. }
  312. if (0) { // Check tree
  313. Morton<DIM> m0;
  314. SCTL_ASSERT(node_mid.Dim() && m0 == node_mid[0]);
  315. for (Long i = 1; i < node_mid.Dim(); i++) {
  316. const auto& m = node_mid[i];
  317. if (m0.isAncestor(m)) m0 = m0.Ancestor(m0.Depth()+1);
  318. else m0 = m0.Next();
  319. SCTL_ASSERT(m0 == m);
  320. }
  321. SCTL_ASSERT(m0.Next() == Morton<DIM>().Next());
  322. }
  323. { // Update node_data, node_cnt
  324. Long start_idx, end_idx;
  325. { // Set start_idx, end_idx
  326. start_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  327. end_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  328. }
  329. comm.PartitionS(node_mid_orig, mins[comm.Rank()]);
  330. Vector<Long> new_cnt_range0(node_mid.Dim()), new_cnt_range1(node_mid.Dim());
  331. { // Set new_cnt_range0, new_cnt_range1
  332. for (Long i = 0; i < start_idx; i++) {
  333. new_cnt_range0[i] = 0;
  334. new_cnt_range1[i] = 0;
  335. }
  336. for (Long i = start_idx; i < end_idx; i++) {
  337. auto m0 = (node_mid[i+0]);
  338. auto m1 = (i+1==end_idx ? Morton<DIM>().Next() : (node_mid[i+1]));
  339. new_cnt_range0[i] = std::lower_bound(node_mid_orig.begin(), node_mid_orig.begin() + node_mid_orig.Dim(), m0) - node_mid_orig.begin();
  340. new_cnt_range1[i] = std::lower_bound(node_mid_orig.begin(), node_mid_orig.begin() + node_mid_orig.Dim(), m1) - node_mid_orig.begin();
  341. }
  342. for (Long i = end_idx; i < node_mid.Dim(); i++) {
  343. new_cnt_range0[i] = 0;
  344. new_cnt_range1[i] = 0;
  345. }
  346. }
  347. Vector<Long> cnt_tmp;
  348. Vector<char> data_tmp;
  349. for (const auto& pair : node_data) {
  350. const std::string& data_name = pair.first;
  351. Long dof;
  352. Iterator<Vector<char>> data_;
  353. Iterator<Vector<Long>> cnt_;
  354. GetData_(data_, cnt_, data_name);
  355. { // Set dof
  356. StaticArray<Long,2> Nl, Ng;
  357. Nl[0] = data_->Dim();
  358. Nl[1] = omp_par::reduce(cnt_->begin(), cnt_->Dim());
  359. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  360. dof = Ng[0] / std::max<Long>(Ng[1],1);
  361. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  362. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  363. }
  364. Long data_dsp = omp_par::reduce(cnt_->begin(), start_idx_orig);
  365. Long data_cnt = omp_par::reduce(cnt_->begin() + start_idx_orig, end_idx_orig - start_idx_orig);
  366. data_tmp.ReInit(data_cnt * dof, data_->begin() + data_dsp * dof, true);
  367. cnt_tmp.ReInit(end_idx_orig - start_idx_orig, cnt_->begin() + start_idx_orig, true);
  368. comm.PartitionN(cnt_tmp, node_mid_orig.Dim());
  369. cnt_->ReInit(node_mid.Dim());
  370. for (Long i = 0; i < node_mid.Dim(); i++) {
  371. Long sum = 0;
  372. Long j0 = new_cnt_range0[i];
  373. Long j1 = new_cnt_range1[i];
  374. for (Long j = j0; j < j1; j++) sum += cnt_tmp[j];
  375. cnt_[0][i] = sum;
  376. }
  377. SCTL_ASSERT(omp_par::reduce(cnt_->begin(), cnt_->Dim()) == omp_par::reduce(cnt_tmp.begin(), cnt_tmp.Dim()));
  378. Long Ndata = omp_par::reduce(cnt_->begin(), cnt_->Dim()) * dof;
  379. comm.PartitionN(data_tmp, Ndata);
  380. SCTL_ASSERT(data_tmp.Dim() == Ndata);
  381. data_->Swap(data_tmp);
  382. }
  383. }
  384. }
  385. template <Integer DIM> template <class ValueType> void Tree<DIM>::AddData(const std::string& name, const Vector<ValueType>& data, const Vector<Long>& cnt) {
  386. Long dof;
  387. { // Check dof
  388. StaticArray<Long,2> Nl, Ng;
  389. Nl[0] = data.Dim();
  390. Nl[1] = omp_par::reduce(cnt.begin(), cnt.Dim());
  391. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  392. dof = Ng[0] / std::max<Long>(Ng[1],1);
  393. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  394. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  395. }
  396. if (dof) SCTL_ASSERT(cnt.Dim() == node_mid.Dim());
  397. SCTL_ASSERT(node_data.find(name) == node_data.end());
  398. node_data[name].ReInit(data.Dim()*sizeof(ValueType), (Iterator<char>)data.begin(), true);
  399. node_cnt [name] = cnt;
  400. }
  401. template <Integer DIM> template <class ValueType> void Tree<DIM>::GetData(Vector<ValueType>& data, Vector<Long>& cnt, const std::string& name) const {
  402. const auto data_ = node_data.find(name);
  403. const auto cnt_ = node_cnt.find(name);
  404. SCTL_ASSERT(data_ != node_data.end());
  405. SCTL_ASSERT( cnt_ != node_cnt .end());
  406. data.ReInit(data_->second.Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->second.begin(), false);
  407. SCTL_ASSERT(data.Dim()*(Long)sizeof(ValueType) == data_->second.Dim());
  408. cnt .ReInit( cnt_->second.Dim(), (Iterator<Long>)cnt_->second.begin(), false);
  409. }
  410. template <Integer DIM> template <class ValueType> void Tree<DIM>::ReduceBroadcast(const std::string& name) {
  411. Integer np = comm.Size();
  412. Integer rank = comm.Rank();
  413. Vector<Long> dsp;
  414. Iterator<Vector<char>> data_;
  415. Iterator<Vector<Long>> cnt_;
  416. GetData_(data_, cnt_, name);
  417. Vector<ValueType> data(data_->Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->begin(), false);
  418. Vector<Long>& cnt = *cnt_;
  419. scan(dsp, cnt);
  420. Long dof;
  421. { // Set dof
  422. StaticArray<Long,2> Nl, Ng;
  423. Nl[0] = data.Dim();
  424. Nl[1] = omp_par::reduce(cnt.begin(), cnt.Dim());
  425. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  426. dof = Ng[0] / std::max<Long>(Ng[1],1);
  427. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  428. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  429. }
  430. { // Reduce
  431. Vector<Morton<DIM>> send_mid, recv_mid;
  432. Vector<Long> send_node_cnt(np), send_node_dsp(np);
  433. Vector<Long> recv_node_cnt(np), recv_node_dsp(np);
  434. { // Set send_mid, send_node_cnt, send_node_dsp, recv_mid, recv_node_cnt, recv_node_dsp
  435. { // Set send_mid
  436. Morton<DIM> m0 = mins[rank];
  437. for (Integer d = 0; d < m0.Depth(); d++) {
  438. send_mid.PushBack(m0.Ancestor(d));
  439. }
  440. }
  441. for (Integer p = 0; p < np; p++) {
  442. Long start_idx = std::lower_bound(send_mid.begin(), send_mid.end(), mins[p]) - send_mid.begin();
  443. Long end_idx = std::lower_bound(send_mid.begin(), send_mid.end(), (p+1==np ? Morton<DIM>().Next() : mins[p+1])) - send_mid.begin();
  444. send_node_cnt[p] = end_idx - start_idx;
  445. }
  446. scan(send_node_dsp, send_node_cnt);
  447. SCTL_ASSERT(send_node_dsp[np-1]+send_node_cnt[np-1] == send_mid.Dim());
  448. comm.Alltoall(send_node_cnt.begin(), 1, recv_node_cnt.begin(), 1);
  449. scan(recv_node_dsp, recv_node_cnt);
  450. recv_mid.ReInit(recv_node_dsp[np-1] + recv_node_cnt[np-1]);
  451. comm.Alltoallv(send_mid.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_mid.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  452. }
  453. Vector<Long> send_data_cnt, send_data_dsp;
  454. Vector<Long> recv_data_cnt, recv_data_dsp;
  455. { // Set send_data_cnt, send_data_dsp
  456. send_data_cnt.ReInit(send_mid.Dim());
  457. recv_data_cnt.ReInit(recv_mid.Dim());
  458. for (Long i = 0; i < send_mid.Dim(); i++) {
  459. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  460. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  461. send_data_cnt[i] = cnt[idx];
  462. }
  463. scan(send_data_dsp, send_data_cnt);
  464. comm.Alltoallv(send_data_cnt.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_data_cnt.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  465. scan(recv_data_dsp, recv_data_cnt);
  466. }
  467. Vector<ValueType> send_buff, recv_buff;
  468. Vector<Long> send_buff_cnt(np), send_buff_dsp(np);
  469. Vector<Long> recv_buff_cnt(np), recv_buff_dsp(np);
  470. { // Set send_buff, send_buff_cnt, send_buff_dsp, recv_buff, recv_buff_cnt, recv_buff_dsp
  471. Long N_send_nodes = send_mid.Dim();
  472. Long N_recv_nodes = recv_mid.Dim();
  473. if (N_send_nodes) send_buff.ReInit((send_data_dsp[N_send_nodes-1] + send_data_cnt[N_send_nodes-1]) * dof);
  474. if (N_recv_nodes) recv_buff.ReInit((recv_data_dsp[N_recv_nodes-1] + recv_data_cnt[N_recv_nodes-1]) * dof);
  475. for (Long i = 0; i < N_send_nodes; i++) {
  476. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  477. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  478. Long dsp_ = dsp[idx] * dof;
  479. Long cnt_ = cnt[idx] * dof;
  480. Long send_data_dsp_ = send_data_dsp[i] * dof;
  481. Long send_data_cnt_ = send_data_cnt[i] * dof;
  482. SCTL_ASSERT(send_data_cnt_ == cnt_);
  483. for (Long j = 0; j < cnt_; j++) {
  484. send_buff[send_data_dsp_+j] = data[dsp_+j];
  485. }
  486. }
  487. for (Integer p = 0; p < np; p++) {
  488. Long send_buff_cnt_ = 0;
  489. Long recv_buff_cnt_ = 0;
  490. for (Long i = 0; i < send_node_cnt[p]; i++) {
  491. send_buff_cnt_ += send_data_cnt[send_node_dsp[p]+i];
  492. }
  493. for (Long i = 0; i < recv_node_cnt[p]; i++) {
  494. recv_buff_cnt_ += recv_data_cnt[recv_node_dsp[p]+i];
  495. }
  496. send_buff_cnt[p] = send_buff_cnt_ * dof;
  497. recv_buff_cnt[p] = recv_buff_cnt_ * dof;
  498. }
  499. scan(send_buff_dsp, send_buff_cnt);
  500. scan(recv_buff_dsp, recv_buff_cnt);
  501. comm.Alltoallv(send_buff.begin(), send_buff_cnt.begin(), send_buff_dsp.begin(), recv_buff.begin(), recv_buff_cnt.begin(), recv_buff_dsp.begin());
  502. }
  503. { // Reduction
  504. Long N_recv_nodes = recv_mid.Dim();
  505. for (Long i = 0; i < N_recv_nodes; i++) {
  506. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), recv_mid[i]) - node_mid.begin();
  507. Long dsp_ = dsp[idx] * dof;
  508. Long cnt_ = cnt[idx] * dof;
  509. Long recv_data_dsp_ = recv_data_dsp[i] * dof;
  510. Long recv_data_cnt_ = recv_data_cnt[i] * dof;
  511. if (recv_data_cnt_ == cnt_) {
  512. for (Long j = 0; j < cnt_; j++) {
  513. data[dsp_+j] += recv_buff[recv_data_dsp_+j];
  514. }
  515. }
  516. }
  517. }
  518. }
  519. Broadcast<ValueType>(name);
  520. }
  521. template <Integer DIM> template <class ValueType> void Tree<DIM>::Broadcast(const std::string& name) {
  522. Integer np = comm.Size();
  523. Integer rank = comm.Rank();
  524. Vector<Long> dsp;
  525. Iterator<Vector<char>> data_;
  526. Iterator<Vector<Long>> cnt_;
  527. GetData_(data_, cnt_, name);
  528. Vector<ValueType> data(data_->Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->begin(), false);
  529. Vector<Long>& cnt = *cnt_;
  530. scan(dsp, cnt);
  531. Long dof;
  532. { // Set dof
  533. StaticArray<Long,2> Nl, Ng;
  534. Nl[0] = data.Dim();
  535. Nl[1] = omp_par::reduce(cnt.begin(), cnt.Dim());
  536. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  537. dof = Ng[0] / std::max<Long>(Ng[1],1);
  538. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  539. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  540. }
  541. { // Broadcast
  542. const Vector<Morton<DIM>>& send_mid = user_mid;
  543. const Vector<Long>& send_node_cnt = user_cnt;
  544. Vector<Long> send_node_dsp(np);
  545. { // Set send_dsp
  546. SCTL_ASSERT(send_node_cnt.Dim() == np);
  547. scan(send_node_dsp, send_node_cnt);
  548. SCTL_ASSERT(send_node_dsp[np-1] + send_node_cnt[np-1] == send_mid.Dim());
  549. }
  550. Vector<Morton<DIM>> recv_mid;
  551. Vector<Long> recv_node_cnt(np), recv_node_dsp(np);
  552. { // Set recv_mid, recv_node_cnt, recv_node_dsp
  553. comm.Alltoall(send_node_cnt.begin(), 1, recv_node_cnt.begin(), 1);
  554. scan(recv_node_dsp, recv_node_cnt);
  555. recv_mid.ReInit(recv_node_dsp[np-1] + recv_node_cnt[np-1]);
  556. comm.Alltoallv(send_mid.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_mid.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  557. }
  558. Vector<Long> send_data_cnt, send_data_dsp;
  559. Vector<Long> recv_data_cnt, recv_data_dsp;
  560. { // Set send_data_cnt, send_data_dsp
  561. send_data_cnt.ReInit(send_mid.Dim());
  562. recv_data_cnt.ReInit(recv_mid.Dim());
  563. for (Long i = 0; i < send_mid.Dim(); i++) {
  564. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  565. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  566. send_data_cnt[i] = cnt[idx];
  567. }
  568. scan(send_data_dsp, send_data_cnt);
  569. comm.Alltoallv(send_data_cnt.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_data_cnt.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  570. scan(recv_data_dsp, recv_data_cnt);
  571. }
  572. Vector<ValueType> send_buff, recv_buff;
  573. Vector<Long> send_buff_cnt(np), send_buff_dsp(np);
  574. Vector<Long> recv_buff_cnt(np), recv_buff_dsp(np);
  575. { // Set send_buff, send_buff_cnt, send_buff_dsp, recv_buff, recv_buff_cnt, recv_buff_dsp
  576. Long N_send_nodes = send_mid.Dim();
  577. Long N_recv_nodes = recv_mid.Dim();
  578. if (N_send_nodes) send_buff.ReInit((send_data_dsp[N_send_nodes-1] + send_data_cnt[N_send_nodes-1]) * dof);
  579. if (N_recv_nodes) recv_buff.ReInit((recv_data_dsp[N_recv_nodes-1] + recv_data_cnt[N_recv_nodes-1]) * dof);
  580. for (Long i = 0; i < N_send_nodes; i++) {
  581. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  582. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  583. Long dsp_ = dsp[idx] * dof;
  584. Long cnt_ = cnt[idx] * dof;
  585. Long send_data_dsp_ = send_data_dsp[i] * dof;
  586. Long send_data_cnt_ = send_data_cnt[i] * dof;
  587. SCTL_ASSERT(send_data_cnt_ == cnt_);
  588. for (Long j = 0; j < cnt_; j++) {
  589. send_buff[send_data_dsp_+j] = data[dsp_+j];
  590. }
  591. }
  592. for (Integer p = 0; p < np; p++) {
  593. Long send_buff_cnt_ = 0;
  594. Long recv_buff_cnt_ = 0;
  595. for (Long i = 0; i < send_node_cnt[p]; i++) {
  596. send_buff_cnt_ += send_data_cnt[send_node_dsp[p]+i];
  597. }
  598. for (Long i = 0; i < recv_node_cnt[p]; i++) {
  599. recv_buff_cnt_ += recv_data_cnt[recv_node_dsp[p]+i];
  600. }
  601. send_buff_cnt[p] = send_buff_cnt_ * dof;
  602. recv_buff_cnt[p] = recv_buff_cnt_ * dof;
  603. }
  604. scan(send_buff_dsp, send_buff_cnt);
  605. scan(recv_buff_dsp, recv_buff_cnt);
  606. comm.Alltoallv(send_buff.begin(), send_buff_cnt.begin(), send_buff_dsp.begin(), recv_buff.begin(), recv_buff_cnt.begin(), recv_buff_dsp.begin());
  607. }
  608. Long start_idx, end_idx;
  609. { // Set start_idx, end_idx
  610. start_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  611. end_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  612. SCTL_ASSERT(0 <= start_idx);
  613. SCTL_ASSERT(start_idx < end_idx);
  614. SCTL_ASSERT(end_idx <= node_mid.Dim());
  615. }
  616. { // Update data <-- data + recv_buff
  617. Long Nsplit = std::lower_bound(recv_mid.begin(), recv_mid.end(), mins[rank]) - recv_mid.begin();
  618. SCTL_ASSERT(recv_mid.Dim()-Nsplit == node_mid.Dim()-end_idx);
  619. SCTL_ASSERT(Nsplit == start_idx);
  620. Long N0 = (start_idx ? dsp[start_idx-1] + cnt[start_idx-1] : 0) * dof;
  621. Long N1 = (end_idx ? dsp[end_idx-1] + cnt[end_idx-1] : 0) * dof;
  622. Long Ns = (Nsplit ? recv_data_dsp[Nsplit-1] + recv_data_cnt[Nsplit-1] : 0) * dof;
  623. if (N0 != Ns || recv_buff.Dim() != N0+data.Dim()-N1) { // resize data and preserve non-ghost data
  624. Vector<char> data_new((recv_buff.Dim() + N1-N0) * sizeof(ValueType));
  625. memcopy(data_new.begin() + Ns * sizeof(ValueType), data_->begin() + N0 * sizeof(ValueType), (N1-N0) * sizeof(ValueType));
  626. data_->Swap(data_new);
  627. data.ReInit(data_->Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->begin(), false);
  628. }
  629. memcopy(cnt.begin(), recv_data_cnt.begin(), start_idx);
  630. memcopy(cnt.begin()+end_idx, recv_data_cnt.begin()+Nsplit, node_mid.Dim()-end_idx);
  631. memcopy(data.begin(), recv_buff.begin(), Ns);
  632. memcopy(data.begin()+data.Dim()+Ns-recv_buff.Dim(), recv_buff.begin()+Ns, recv_buff.Dim()-Ns);
  633. }
  634. }
  635. }
  636. template <Integer DIM> void Tree<DIM>::DeleteData(const std::string& name) {
  637. SCTL_ASSERT(node_data.find(name) != node_data.end());
  638. SCTL_ASSERT(node_cnt .find(name) != node_cnt .end());
  639. node_data.erase(name);
  640. node_cnt .erase(name);
  641. }
  642. template <Integer DIM> void Tree<DIM>::WriteTreeVTK(std::string fname, bool show_ghost) const {
  643. typedef typename VTUData::VTKReal VTKReal;
  644. VTUData vtu_data;
  645. if (DIM <= 3) { // Set vtu data
  646. static const Integer Ncorner = (1u << DIM);
  647. Vector<VTKReal> &coord = vtu_data.coord;
  648. //Vector<VTKReal> &value = vtu_data.value;
  649. Vector<int32_t> &connect = vtu_data.connect;
  650. Vector<int32_t> &offset = vtu_data.offset;
  651. Vector<uint8_t> &types = vtu_data.types;
  652. StaticArray<VTKReal, DIM> c;
  653. Long point_cnt = coord.Dim() / 3;
  654. Long connect_cnt = connect.Dim();
  655. for (Long nid = 0; nid < node_mid.Dim(); nid++) {
  656. const Morton<DIM> &mid = node_mid[nid];
  657. const NodeAttr &attr = node_attr[nid];
  658. if (!show_ghost && attr.Ghost) continue;
  659. if (!attr.Leaf) continue;
  660. mid.Coord((Iterator<VTKReal>)c);
  661. VTKReal s = sctl::pow<VTKReal>(0.5, mid.Depth());
  662. for (Integer j = 0; j < Ncorner; j++) {
  663. for (Integer i = 0; i < DIM; i++) coord.PushBack(c[i] + ((j & (1u << i)) ? 1 : 0) * s);
  664. for (Integer i = DIM; i < 3; i++) coord.PushBack(0);
  665. connect.PushBack(point_cnt);
  666. connect_cnt++;
  667. point_cnt++;
  668. }
  669. offset.PushBack(connect_cnt);
  670. if (DIM == 2)
  671. types.PushBack(8);
  672. else if (DIM == 3)
  673. types.PushBack(11);
  674. else
  675. types.PushBack(4);
  676. }
  677. }
  678. vtu_data.WriteVTK(fname, comm);
  679. }
  680. template <Integer DIM> void Tree<DIM>::GetData_(Iterator<Vector<char>>& data, Iterator<Vector<Long>>& cnt, const std::string& name) {
  681. auto data_ = node_data.find(name);
  682. const auto cnt_ = node_cnt.find(name);
  683. SCTL_ASSERT(data_ != node_data.end());
  684. SCTL_ASSERT( cnt_ != node_cnt .end());
  685. data = Ptr2Itr<Vector<char>>(&data_->second,1);
  686. cnt = Ptr2Itr<Vector<Long>>(& cnt_->second,1);
  687. }
  688. template <Integer DIM> void Tree<DIM>::scan(Vector<Long>& dsp, const Vector<Long>& cnt) {
  689. dsp.ReInit(cnt.Dim());
  690. if (cnt.Dim()) dsp[0] = 0;
  691. omp_par::scan(cnt.begin(), dsp.begin(), cnt.Dim());
  692. }
  693. template <class Real, Integer DIM, class BaseTree> PtTree<Real,DIM,BaseTree>::PtTree(const Comm& comm) : BaseTree(comm) {}
  694. template <class Real, Integer DIM, class BaseTree> PtTree<Real,DIM,BaseTree>::~PtTree() {
  695. #ifdef SCTL_MEMDEBUG
  696. for (auto& pair : data_pt_name) {
  697. Vector<Real> data;
  698. Vector<Long> cnt;
  699. this->GetData(data, cnt, pair.second);
  700. SCTL_ASSERT(scatter_idx.find(pair.second) != scatter_idx.end());
  701. }
  702. #endif
  703. }
  704. template <class Real, Integer DIM, class BaseTree> void PtTree<Real,DIM,BaseTree>::UpdateRefinement(const Vector<Real>& coord, Long M, bool balance21, bool periodic) {
  705. const auto& comm = this->GetComm();
  706. BaseTree::UpdateRefinement(coord, M, balance21, periodic);
  707. Long start_node_idx, end_node_idx;
  708. { // Set start_node_idx, end_node_idx
  709. const auto& mins = this->GetPartitionMID();
  710. const auto& node_mid = this->GetNodeMID();
  711. Integer np = comm.Size();
  712. Integer rank = comm.Rank();
  713. start_node_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  714. end_node_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  715. }
  716. const auto& mins = this->GetPartitionMID();
  717. const auto& node_mid = this->GetNodeMID();
  718. for (const auto& pair : pt_mid) {
  719. const auto& pt_name = pair.first;
  720. auto& pt_mid_ = pt_mid[pt_name];
  721. auto& scatter_idx_ = scatter_idx[pt_name];
  722. comm.PartitionS(pt_mid_, mins[comm.Rank()]);
  723. comm.PartitionN(scatter_idx_, pt_mid_.Dim());
  724. Vector<Long> pt_cnt(node_mid.Dim());
  725. for (Long i = 0; i < node_mid.Dim(); i++) { // Set pt_cnt
  726. Long start = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), node_mid[i]) - pt_mid_.begin();
  727. Long end = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), (i+1==node_mid.Dim() ? Morton<DIM>().Next() : node_mid[i+1])) - pt_mid_.begin();
  728. if (i == 0) SCTL_ASSERT(start == 0);
  729. if (i+1 == node_mid.Dim()) SCTL_ASSERT(end == pt_mid_.Dim());
  730. pt_cnt[i] = end - start;
  731. }
  732. for (const auto& pair : data_pt_name) {
  733. if (pair.second == pt_name) {
  734. const auto& data_name = pair.first;
  735. Iterator<Vector<char>> data;
  736. Iterator<Vector<Long>> cnt;
  737. this->GetData_(data, cnt, data_name);
  738. { // Update data
  739. Long dof = 0;
  740. { // Set dof
  741. StaticArray<Long,2> Nl {0, 0}, Ng;
  742. Nl[0] = data->Dim();
  743. for (Long i = 0; i < cnt->Dim(); i++) Nl[1] += cnt[0][i];
  744. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  745. dof = Ng[0] / std::max<Long>(Ng[1],1);
  746. }
  747. Long offset = 0, count = 0;
  748. SCTL_ASSERT(0 <= start_node_idx);
  749. SCTL_ASSERT(start_node_idx <= end_node_idx);
  750. SCTL_ASSERT(end_node_idx <= cnt->Dim());
  751. for (Long i = 0; i < start_node_idx; i++) offset += cnt[0][i];
  752. for (Long i = start_node_idx; i < end_node_idx; i++) count += cnt[0][i];
  753. offset *= dof;
  754. count *= dof;
  755. Vector<char> data_(count, data->begin() + offset);
  756. comm.PartitionN(data_, pt_mid_.Dim());
  757. data->Swap(data_);
  758. }
  759. cnt[0] = pt_cnt;
  760. }
  761. }
  762. }
  763. }
  764. template <class Real, Integer DIM, class BaseTree> void PtTree<Real,DIM,BaseTree>::AddParticles(const std::string& name, const Vector<Real>& coord) {
  765. const auto& mins = this->GetPartitionMID();
  766. const auto& node_mid = this->GetNodeMID();
  767. const auto& comm = this->GetComm();
  768. SCTL_ASSERT(scatter_idx.find(name) == scatter_idx.end());
  769. Vector<Long>& scatter_idx_ = scatter_idx[name];
  770. Long N = coord.Dim() / DIM;
  771. SCTL_ASSERT(coord.Dim() == N * DIM);
  772. Nlocal[name] = N;
  773. Vector<Morton<DIM>>& pt_mid_ = pt_mid[name];
  774. if (pt_mid_.Dim() != N) pt_mid_.ReInit(N);
  775. for (Long i = 0; i < N; i++) {
  776. pt_mid_[i] = Morton<DIM>(coord.begin() + i*DIM);
  777. }
  778. comm.SortScatterIndex(pt_mid_, scatter_idx_, &mins[comm.Rank()]);
  779. comm.ScatterForward(pt_mid_, scatter_idx_);
  780. AddParticleData(name, name, coord);
  781. { // Set node_cnt
  782. Iterator<Vector<char>> data_;
  783. Iterator<Vector<Long>> cnt_;
  784. this->GetData_(data_,cnt_,name);
  785. cnt_[0].ReInit(node_mid.Dim());
  786. for (Long i = 0; i < node_mid.Dim(); i++) {
  787. Long start = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), node_mid[i]) - pt_mid_.begin();
  788. Long end = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), (i+1==node_mid.Dim() ? Morton<DIM>().Next() : node_mid[i+1])) - pt_mid_.begin();
  789. if (i == 0) SCTL_ASSERT(start == 0);
  790. if (i+1 == node_mid.Dim()) SCTL_ASSERT(end == pt_mid_.Dim());
  791. cnt_[0][i] = end - start;
  792. }
  793. }
  794. }
  795. template <class Real, Integer DIM, class BaseTree> void PtTree<Real,DIM,BaseTree>::AddParticleData(const std::string& data_name, const std::string& particle_name, const Vector<Real>& data) {
  796. SCTL_ASSERT(scatter_idx.find(particle_name) != scatter_idx.end());
  797. SCTL_ASSERT(data_pt_name.find(data_name) == data_pt_name.end());
  798. data_pt_name[data_name] = particle_name;
  799. Iterator<Vector<char>> data_;
  800. Iterator<Vector<Long>> cnt_;
  801. this->AddData(data_name, Vector<Real>(), Vector<Long>());
  802. this->GetData_(data_,cnt_,data_name);
  803. { // Set data_[0]
  804. data_[0].ReInit(data.Dim()*sizeof(Real), (Iterator<char>)data.begin(), true);
  805. this->GetComm().ScatterForward(data_[0], scatter_idx[particle_name]);
  806. }
  807. if (data_name != particle_name) { // Set cnt_[0]
  808. Vector<Real> pt_coord;
  809. Vector<Long> pt_cnt;
  810. this->GetData(pt_coord, pt_cnt, particle_name);
  811. cnt_[0] = pt_cnt;
  812. }
  813. }
  814. template <class Real, Integer DIM, class BaseTree> void PtTree<Real,DIM,BaseTree>::GetParticleData(Vector<Real>& data, const std::string& data_name) const {
  815. SCTL_ASSERT(data_pt_name.find(data_name) != data_pt_name.end());
  816. const std::string& particle_name = data_pt_name.find(data_name)->second;
  817. SCTL_ASSERT(scatter_idx.find(particle_name) != scatter_idx.end());
  818. const auto& scatter_idx_ = scatter_idx.find(particle_name)->second;
  819. const Long Nlocal_ = Nlocal.find(particle_name)->second;
  820. const auto& mins = this->GetPartitionMID();
  821. const auto& node_mid = this->GetNodeMID();
  822. const auto& comm = this->GetComm();
  823. Long dof;
  824. Vector<Long> dsp;
  825. Vector<Long> cnt_;
  826. Vector<Real> data_;
  827. this->GetData(data_, cnt_, data_name);
  828. SCTL_ASSERT(cnt_.Dim() == node_mid.Dim());
  829. BaseTree::scan(dsp, cnt_);
  830. { // Set dof
  831. Long Nn = node_mid.Dim();
  832. StaticArray<Long,2> Ng, Nl = {data_.Dim(), dsp[Nn-1]+cnt_[Nn-1]};
  833. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  834. dof = Ng[0] / std::max<Long>(Ng[1],1);
  835. }
  836. { // Set data
  837. Integer np = comm.Size();
  838. Integer rank = comm.Rank();
  839. Long N0 = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  840. Long N1 = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  841. Long start = dsp[N0] * dof;
  842. Long end = (N1<dsp.Dim() ? dsp[N1] : dsp[N1-1]+cnt_[N1-1]) * dof;
  843. data.ReInit(end-start, data_.begin()+start, true);
  844. comm.ScatterReverse(data, scatter_idx_, Nlocal_ * dof);
  845. }
  846. }
  847. template <class Real, Integer DIM, class BaseTree> void PtTree<Real,DIM,BaseTree>::DeleteParticleData(const std::string& data_name) {
  848. SCTL_ASSERT(data_pt_name.find(data_name) != data_pt_name.end());
  849. auto particle_name = data_pt_name[data_name];
  850. if (data_name == particle_name) {
  851. std::vector<std::string> data_name_lst;
  852. for (auto& pair : data_pt_name) {
  853. if (pair.second == particle_name) {
  854. data_name_lst.push_back(pair.first);
  855. }
  856. }
  857. for (auto x : data_name_lst) {
  858. if (x != particle_name) {
  859. DeleteParticleData(x);
  860. }
  861. }
  862. Nlocal.erase(particle_name);
  863. }
  864. this->DeleteData(data_name);
  865. data_pt_name.erase(data_name);
  866. }
  867. template <class Real, Integer DIM, class BaseTree> void PtTree<Real,DIM,BaseTree>::WriteParticleVTK(std::string fname, std::string data_name, bool show_ghost) const {
  868. typedef typename VTUData::VTKReal VTKReal;
  869. const auto& node_mid = this->GetNodeMID();
  870. const auto& node_attr = this->GetNodeAttr();
  871. VTUData vtu_data;
  872. if (DIM <= 3) { // Set vtu data
  873. SCTL_ASSERT(data_pt_name.find(data_name) != data_pt_name.end());
  874. std::string particle_name = data_pt_name.find(data_name)->second;
  875. Vector<Real> pt_coord;
  876. Vector<Real> pt_value;
  877. Vector<Long> pt_cnt;
  878. Vector<Long> pt_dsp;
  879. Long value_dof = 0;
  880. { // Set pt_coord, pt_cnt, pt_dsp
  881. this->GetData(pt_coord, pt_cnt, particle_name);
  882. Tree<DIM>::scan(pt_dsp, pt_cnt);
  883. }
  884. if (particle_name != data_name) { // Set pt_value, value_dof
  885. Vector<Long> pt_cnt;
  886. this->GetData(pt_value, pt_cnt, data_name);
  887. Long Npt = omp_par::reduce(pt_cnt.begin(), pt_cnt.Dim());
  888. value_dof = pt_value.Dim() / std::max<Long>(Npt,1);
  889. }
  890. Vector<VTKReal> &coord = vtu_data.coord;
  891. Vector<VTKReal> &value = vtu_data.value;
  892. Vector<int32_t> &connect = vtu_data.connect;
  893. Vector<int32_t> &offset = vtu_data.offset;
  894. Vector<uint8_t> &types = vtu_data.types;
  895. Long point_cnt = coord.Dim() / DIM;
  896. Long connect_cnt = connect.Dim();
  897. value.ReInit(point_cnt * value_dof);
  898. value.SetZero();
  899. SCTL_ASSERT(node_mid.Dim() == node_attr.Dim());
  900. SCTL_ASSERT(node_mid.Dim() == pt_cnt.Dim());
  901. for (Long i = 0; i < node_mid.Dim(); i++) {
  902. if (!show_ghost && node_attr[i].Ghost) continue;
  903. if (!node_attr[i].Leaf) continue;
  904. for (Long j = 0; j < pt_cnt[i]; j++) {
  905. ConstIterator<Real> pt_coord_ = pt_coord.begin() + (pt_dsp[i] + j) * DIM;
  906. ConstIterator<Real> pt_value_ = (value_dof ? pt_value.begin() + (pt_dsp[i] + j) * value_dof : NullIterator<Real>());
  907. for (Integer k = 0; k < DIM; k++) coord.PushBack((VTKReal)pt_coord_[k]);
  908. for (Integer k = DIM; k < 3; k++) coord.PushBack(0);
  909. for (Integer k = 0; k < value_dof; k++) value.PushBack((VTKReal)pt_value_[k]);
  910. connect.PushBack(point_cnt);
  911. connect_cnt++;
  912. point_cnt++;
  913. offset.PushBack(connect_cnt);
  914. types.PushBack(1);
  915. }
  916. }
  917. }
  918. vtu_data.WriteVTK(fname, this->GetComm());
  919. }
  920. }