tree.hpp 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327
  1. #ifndef _SCTL_TREE_
  2. #define _SCTL_TREE_
  3. #include SCTL_INCLUDE(common.hpp)
  4. #include SCTL_INCLUDE(morton.hpp)
  5. #include SCTL_INCLUDE(comm.hpp)
  6. #include SCTL_INCLUDE(matrix.hpp) // TODO: fix issues when this is before #Include <comm.hpp>
  7. #include <fstream>
  8. #include <algorithm>
  9. namespace SCTL_NAMESPACE {
  10. struct VTUData {
  11. typedef float VTKReal;
  12. // Point data
  13. Vector<VTKReal> coord; // always 3D
  14. Vector<VTKReal> value;
  15. // Cell data
  16. Vector<int32_t> connect;
  17. Vector<int32_t> offset;
  18. Vector<uint8_t> types;
  19. void WriteVTK(const std::string& fname, const Comm& comm = Comm::Self()) const {
  20. typedef typename VTUData::VTKReal VTKReal;
  21. Long value_dof = 0;
  22. { // Write vtu file.
  23. std::ofstream vtufile;
  24. { // Open file for writing.
  25. std::stringstream vtufname;
  26. vtufname << fname << std::setfill('0') << std::setw(6) << comm.Rank() << ".vtu";
  27. vtufile.open(vtufname.str().c_str());
  28. if (vtufile.fail()) return;
  29. }
  30. { // Write to file.
  31. Long pt_cnt = coord.Dim() / 3;
  32. Long cell_cnt = types.Dim();
  33. value_dof = (pt_cnt ? value.Dim() / pt_cnt : 0);
  34. Vector<int32_t> mpi_rank;
  35. { // Set mpi_rank
  36. Integer new_myrank = comm.Rank();
  37. mpi_rank.ReInit(pt_cnt);
  38. for (Long i = 0; i < mpi_rank.Dim(); i++) mpi_rank[i] = new_myrank;
  39. }
  40. bool isLittleEndian;
  41. { // Set isLittleEndian
  42. uint16_t number = 0x1;
  43. uint8_t *numPtr = (uint8_t *)&number;
  44. isLittleEndian = (numPtr[0] == 1);
  45. }
  46. Long data_size = 0;
  47. vtufile << "<?xml version=\"1.0\"?>\n";
  48. vtufile << "<VTKFile type=\"UnstructuredGrid\" version=\"0.1\" byte_order=\"" << (isLittleEndian ? "LittleEndian" : "BigEndian") << "\">\n";
  49. // ===========================================================================
  50. vtufile << " <UnstructuredGrid>\n";
  51. vtufile << " <Piece NumberOfPoints=\"" << pt_cnt << "\" NumberOfCells=\"" << cell_cnt << "\">\n";
  52. //---------------------------------------------------------------------------
  53. vtufile << " <Points>\n";
  54. vtufile << " <DataArray type=\"Float" << sizeof(VTKReal) * 8 << "\" NumberOfComponents=\"3\" Name=\"Position\" format=\"appended\" offset=\"" << data_size << "\" />\n";
  55. data_size += sizeof(uint32_t) + coord.Dim() * sizeof(VTKReal);
  56. vtufile << " </Points>\n";
  57. //---------------------------------------------------------------------------
  58. vtufile << " <PointData>\n";
  59. if (value_dof) { // value
  60. vtufile << " <DataArray type=\"Float" << sizeof(VTKReal) * 8 << "\" NumberOfComponents=\"" << value_dof << "\" Name=\"value\" format=\"appended\" offset=\"" << data_size << "\" />\n";
  61. data_size += sizeof(uint32_t) + value.Dim() * sizeof(VTKReal);
  62. }
  63. { // mpi_rank
  64. vtufile << " <DataArray type=\"Int32\" NumberOfComponents=\"1\" Name=\"mpi_rank\" format=\"appended\" offset=\"" << data_size << "\" />\n";
  65. data_size += sizeof(uint32_t) + pt_cnt * sizeof(int32_t);
  66. }
  67. vtufile << " </PointData>\n";
  68. //---------------------------------------------------------------------------
  69. //---------------------------------------------------------------------------
  70. vtufile << " <Cells>\n";
  71. vtufile << " <DataArray type=\"Int32\" Name=\"connectivity\" format=\"appended\" offset=\"" << data_size << "\" />\n";
  72. data_size += sizeof(uint32_t) + connect.Dim() * sizeof(int32_t);
  73. vtufile << " <DataArray type=\"Int32\" Name=\"offsets\" format=\"appended\" offset=\"" << data_size << "\" />\n";
  74. data_size += sizeof(uint32_t) + offset.Dim() * sizeof(int32_t);
  75. vtufile << " <DataArray type=\"UInt8\" Name=\"types\" format=\"appended\" offset=\"" << data_size << "\" />\n";
  76. //data_size += sizeof(uint32_t) + types.Dim() * sizeof(uint8_t);
  77. vtufile << " </Cells>\n";
  78. //---------------------------------------------------------------------------
  79. vtufile << " </Piece>\n";
  80. vtufile << " </UnstructuredGrid>\n";
  81. // ===========================================================================
  82. vtufile << " <AppendedData encoding=\"raw\">\n";
  83. vtufile << " _";
  84. int32_t block_size;
  85. { // coord
  86. block_size = coord.Dim() * sizeof(VTKReal);
  87. vtufile.write((char *)&block_size, sizeof(int32_t));
  88. if (coord.Dim()) vtufile.write((char *)&coord[0], coord.Dim() * sizeof(VTKReal));
  89. }
  90. if (value_dof) { // value
  91. block_size = value.Dim() * sizeof(VTKReal);
  92. vtufile.write((char *)&block_size, sizeof(int32_t));
  93. if (value.Dim()) vtufile.write((char *)&value[0], value.Dim() * sizeof(VTKReal));
  94. }
  95. { // mpi_rank
  96. block_size = mpi_rank.Dim() * sizeof(int32_t);
  97. vtufile.write((char *)&block_size, sizeof(int32_t));
  98. if (mpi_rank.Dim()) vtufile.write((char *)&mpi_rank[0], mpi_rank.Dim() * sizeof(int32_t));
  99. }
  100. { // block_size
  101. block_size = connect.Dim() * sizeof(int32_t);
  102. vtufile.write((char *)&block_size, sizeof(int32_t));
  103. if (connect.Dim()) vtufile.write((char *)&connect[0], connect.Dim() * sizeof(int32_t));
  104. }
  105. { // offset
  106. block_size = offset.Dim() * sizeof(int32_t);
  107. vtufile.write((char *)&block_size, sizeof(int32_t));
  108. if (offset.Dim()) vtufile.write((char *)&offset[0], offset.Dim() * sizeof(int32_t));
  109. }
  110. { // types
  111. block_size = types.Dim() * sizeof(uint8_t);
  112. vtufile.write((char *)&block_size, sizeof(int32_t));
  113. if (types.Dim()) vtufile.write((char *)&types[0], types.Dim() * sizeof(uint8_t));
  114. }
  115. vtufile << "\n";
  116. vtufile << " </AppendedData>\n";
  117. // ===========================================================================
  118. vtufile << "</VTKFile>\n";
  119. }
  120. vtufile.close(); // close file
  121. }
  122. if (!comm.Rank()) { // Write pvtu file
  123. std::ofstream pvtufile;
  124. { // Open file for writing
  125. std::stringstream pvtufname;
  126. pvtufname << fname << ".pvtu";
  127. pvtufile.open(pvtufname.str().c_str());
  128. if (pvtufile.fail()) return;
  129. }
  130. { // Write to file.
  131. pvtufile << "<?xml version=\"1.0\"?>\n";
  132. pvtufile << "<VTKFile type=\"PUnstructuredGrid\">\n";
  133. pvtufile << " <PUnstructuredGrid GhostLevel=\"0\">\n";
  134. pvtufile << " <PPoints>\n";
  135. pvtufile << " <PDataArray type=\"Float" << sizeof(VTKReal) * 8 << "\" NumberOfComponents=\"3\" Name=\"Position\"/>\n";
  136. pvtufile << " </PPoints>\n";
  137. pvtufile << " <PPointData>\n";
  138. if (value_dof) { // value
  139. pvtufile << " <PDataArray type=\"Float" << sizeof(VTKReal) * 8 << "\" NumberOfComponents=\"" << value_dof << "\" Name=\"value\"/>\n";
  140. }
  141. { // mpi_rank
  142. pvtufile << " <PDataArray type=\"Int32\" NumberOfComponents=\"1\" Name=\"mpi_rank\"/>\n";
  143. }
  144. pvtufile << " </PPointData>\n";
  145. {
  146. // Extract filename from path.
  147. std::stringstream vtupath;
  148. vtupath << '/' << fname;
  149. std::string pathname = vtupath.str();
  150. std::string fname_ = pathname.substr(pathname.find_last_of("/\\") + 1);
  151. // char *fname_ = (char*)strrchr(vtupath.str().c_str(), '/') + 1;
  152. // std::string fname_ =
  153. // boost::filesystem::path(fname).filename().string().
  154. for (Integer i = 0; i < comm.Size(); i++) pvtufile << " <Piece Source=\"" << fname_ << std::setfill('0') << std::setw(6) << i << ".vtu\"/>\n";
  155. }
  156. pvtufile << " </PUnstructuredGrid>\n";
  157. pvtufile << "</VTKFile>\n";
  158. }
  159. pvtufile.close(); // close file
  160. }
  161. };
  162. template <class ElemLst> void AddElems(const ElemLst elem_lst, Integer order, const Comm& comm = Comm::Self()) {
  163. constexpr Integer COORD_DIM = ElemLst::CoordDim();
  164. constexpr Integer ElemDim = ElemLst::ElemDim();
  165. using CoordBasis = typename ElemLst::CoordBasis;
  166. using CoordType = typename ElemLst::CoordType;
  167. Long N0 = coord.Dim() / COORD_DIM;
  168. Long NElem = elem_lst.NElem();
  169. Matrix<CoordType> nodes = VTK_Nodes<CoordType, ElemDim>(order);
  170. Integer Nnodes = sctl::pow<ElemDim,Integer>(order);
  171. SCTL_ASSERT(nodes.Dim(0) == ElemDim);
  172. SCTL_ASSERT(nodes.Dim(1) == Nnodes);
  173. { // Set coord
  174. Matrix<CoordType> vtk_coord;
  175. auto M = CoordBasis::SetupEval(nodes);
  176. CoordBasis::Eval(vtk_coord, elem_lst.ElemVector(), M);
  177. for (Long k = 0; k < NElem; k++) {
  178. for (Integer i = 0; i < Nnodes; i++) {
  179. constexpr Integer dim = (COORD_DIM < 3 ? COORD_DIM : 3);
  180. for (Integer j = 0; j < dim; j++) {
  181. coord.PushBack((VTUData::VTKReal)vtk_coord[k*COORD_DIM+j][i]);
  182. }
  183. for (Integer j = dim; j < 3; j++) {
  184. coord.PushBack((VTUData::VTKReal)0);
  185. }
  186. }
  187. }
  188. }
  189. if (ElemLst::ElemDim() == 2) {
  190. for (Long k = 0; k < NElem; k++) {
  191. for (Integer i = 0; i < order-1; i++) {
  192. for (Integer j = 0; j < order-1; j++) {
  193. Long idx = k*Nnodes + i*order + j;
  194. connect.PushBack(N0+idx);
  195. connect.PushBack(N0+idx+1);
  196. connect.PushBack(N0+idx+order+1);
  197. connect.PushBack(N0+idx+order);
  198. offset.PushBack(connect.Dim());
  199. types.PushBack(9);
  200. }
  201. }
  202. }
  203. } else {
  204. // TODO
  205. SCTL_ASSERT(false);
  206. }
  207. }
  208. template <class ElemLst, class ValueBasis> void AddElems(const ElemLst elem_lst, const Vector<ValueBasis>& elem_value, Integer order, const Comm& comm = Comm::Self()) {
  209. constexpr Integer ElemDim = ElemLst::ElemDim();
  210. using ValueType = typename ValueBasis::ValueType;
  211. Long NElem = elem_lst.NElem();
  212. Integer dof = (NElem==0 ? 0 : elem_value.Dim() / NElem);
  213. SCTL_ASSERT(elem_value.Dim() == NElem * dof);
  214. AddElems(elem_lst, order, comm);
  215. Matrix<ValueType> nodes = VTK_Nodes<ValueType, ElemDim>(order);
  216. Integer Nnodes = sctl::pow<ElemDim,Integer>(order);
  217. SCTL_ASSERT(nodes.Dim(0) == ElemDim);
  218. SCTL_ASSERT(nodes.Dim(1) == Nnodes);
  219. { // Set value
  220. Matrix<ValueType> vtk_value;
  221. auto M = ValueBasis::SetupEval(nodes);
  222. ValueBasis::Eval(vtk_value, elem_value, M);
  223. for (Long k = 0; k < NElem; k++) {
  224. for (Integer i = 0; i < Nnodes; i++) {
  225. for (Integer j = 0; j < dof; j++) {
  226. value.PushBack((VTUData::VTKReal)vtk_value[k*dof+j][i]);
  227. }
  228. }
  229. }
  230. }
  231. }
  232. private:
  233. template <class CoordType, Integer ELEM_DIM> static Matrix<CoordType> VTK_Nodes(Integer order) {
  234. Matrix<CoordType> nodes;
  235. if (ELEM_DIM == 2) {
  236. Integer Nnodes = order*order;
  237. nodes.ReInit(ELEM_DIM, Nnodes);
  238. for (Integer i = 0; i < order; i++) {
  239. for (Integer j = 0; j < order; j++) {
  240. //nodes[0][i*order+j] = i / (CoordType)(order-1);
  241. //nodes[1][i*order+j] = j / (CoordType)(order-1);
  242. nodes[0][i*order+j] = 0.5 - 0.5 * sctl::cos<CoordType>((2*i+1) * const_pi<CoordType>() / (2*order));
  243. nodes[1][i*order+j] = 0.5 - 0.5 * sctl::cos<CoordType>((2*j+1) * const_pi<CoordType>() / (2*order));
  244. }
  245. }
  246. } else {
  247. // TODO
  248. SCTL_ASSERT(false);
  249. }
  250. return nodes;
  251. }
  252. };
  253. template <Integer DIM> class Tree {
  254. public:
  // Per-node flags, stored as two single-bit fields.
  struct NodeAttr {
    unsigned char Leaf : 1;   // set when the next node in the node list is not a descendant of this node
    unsigned char Ghost : 1;  // set when the node's Morton ID lies outside this rank's range
  };
  258. struct NodeLists {
  259. Long p2n;
  260. Long parent;
  261. Long child[1 << DIM];
  262. Long nbr[sctl::pow<DIM,Integer>(3)];
  263. };
  264. static constexpr Integer Dim() {
  265. return DIM;
  266. }
  267. Tree(const Comm& comm_ = Comm::Self()) : comm(comm_) {
  268. Integer rank = comm.Rank();
  269. Integer np = comm.Size();
  270. Vector<double> coord;
  271. { // Set coord
  272. Long N0 = 1;
  273. while (sctl::pow<DIM,Long>(N0) < np) N0++;
  274. Long N = sctl::pow<DIM,Long>(N0);
  275. Long start = N * (rank+0) / np;
  276. Long end = N * (rank+1) / np;
  277. coord.ReInit((end-start)*DIM);
  278. for (Long i = start; i < end; i++) {
  279. Long idx = i;
  280. for (Integer k = 0; k < DIM; k++) {
  281. coord[(i-start)*DIM+k] = (idx % N0) / (double)N0;
  282. idx /= N0;
  283. }
  284. }
  285. }
  286. this->UpdateRefinement(coord);
  287. }
  288. ~Tree() {
  289. #ifdef SCTL_MEMDEBUG
  290. for (auto& pair : node_data) {
  291. SCTL_ASSERT(node_cnt.find(pair.first) != node_cnt.end());
  292. }
  293. #endif
  294. }
  295. const Vector<Morton<DIM>>& GetPartitionMID() const {
  296. return mins;
  297. }
  298. const Vector<Morton<DIM>>& GetNodeMID() const {
  299. return node_mid;
  300. }
  301. const Vector<NodeAttr>& GetNodeAttr() const {
  302. return node_attr;
  303. }
  304. const Vector<NodeLists>& GetNodeLists() const {
  305. return node_lst;
  306. }
  307. const Comm& GetComm() const {
  308. return comm;
  309. }
  310. template <class Real> void UpdateRefinement(const Vector<Real>& coord, Long M = 1, bool balance21 = 0, bool periodic = 0) {
  311. Integer np = comm.Size();
  312. Integer rank = comm.Rank();
  313. Vector<Morton<DIM>> node_mid_orig;
  314. Long start_idx_orig, end_idx_orig;
  315. if (mins.Dim()) { // Set start_idx_orig, end_idx_orig
  316. start_idx_orig = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  317. end_idx_orig = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  318. node_mid_orig.ReInit(end_idx_orig - start_idx_orig, node_mid.begin() + start_idx_orig, true);
  319. } else {
  320. start_idx_orig = 0;
  321. end_idx_orig = 0;
  322. }
  323. auto coarsest_ancestor_mid = [](const Morton<DIM>& m0) {
  324. Morton<DIM> md;
  325. Integer d0 = m0.Depth();
  326. for (Integer d = 0; d <= d0; d++) {
  327. md = m0.Ancestor(d);
  328. if (md.Ancestor(d0) == m0) break;
  329. }
  330. return md;
  331. };
  332. Morton<DIM> pt_mid0;
  333. Vector<Morton<DIM>> pt_mid;
  334. { // Construct sorted pt_mid
  335. Long Npt = coord.Dim() / DIM;
  336. pt_mid.ReInit(Npt);
  337. for (Long i = 0; i < Npt; i++) {
  338. pt_mid[i] = Morton<DIM>(coord.begin() + i*DIM);
  339. }
  340. Vector<Morton<DIM>> sorted_mid;
  341. comm.HyperQuickSort(pt_mid, sorted_mid);
  342. pt_mid.Swap(sorted_mid);
  343. SCTL_ASSERT(pt_mid.Dim());
  344. pt_mid0 = pt_mid[0];
  345. }
  346. { // Update M = global_min(pt_mid.Dim(), M)
  347. Long M0, M1, Npt = pt_mid.Dim();
  348. comm.Allreduce(Ptr2ConstItr<Long>(&M,1), Ptr2Itr<Long>(&M0,1), 1, Comm::CommOp::MIN);
  349. comm.Allreduce(Ptr2ConstItr<Long>(&Npt,1), Ptr2Itr<Long>(&M1,1), 1, Comm::CommOp::MIN);
  350. M = std::min(M0,M1);
  351. SCTL_ASSERT(M > 0);
  352. }
  353. { // pt_mid <-- [M points from rank-1; pt_mid; M points from rank+1]
  354. Long send_size0 = (rank+1<np ? M : 0);
  355. Long send_size1 = (rank > 0 ? M : 0);
  356. Long recv_size0 = (rank > 0 ? M : 0);
  357. Long recv_size1 = (rank+1<np ? M : 0);
  358. Vector<Morton<DIM>> pt_mid_(recv_size0 + pt_mid.Dim() + recv_size1);
  359. memcopy(pt_mid_.begin()+recv_size0, pt_mid.begin(), pt_mid.Dim());
  360. void* recv_req0 = comm.Irecv(pt_mid_.begin(), recv_size0, (rank+np-1)%np, 0);
  361. void* recv_req1 = comm.Irecv(pt_mid_.begin() + recv_size0 + pt_mid.Dim(), recv_size1, (rank+1)%np, 1);
  362. void* send_req0 = comm.Isend(pt_mid .begin() + pt_mid.Dim() - send_size0, send_size0, (rank+1)%np, 0);
  363. void* send_req1 = comm.Isend(pt_mid .begin(), send_size1, (rank+np-1)%np, 1);
  364. comm.Wait(recv_req0);
  365. comm.Wait(recv_req1);
  366. comm.Wait(send_req0);
  367. comm.Wait(send_req1);
  368. pt_mid.Swap(pt_mid_);
  369. }
  370. { // Build linear MortonID tree from pt_mid
  371. node_mid.ReInit(0);
  372. Long idx = 0;
  373. Morton<DIM> m0;
  374. Morton<DIM> mend = Morton<DIM>().Next();
  375. while (m0 < mend) {
  376. Integer d = m0.Depth();
  377. Morton<DIM> m1 = (idx + M < pt_mid.Dim() ? pt_mid[idx+M] : Morton<DIM>().Next());
  378. while (d < Morton<DIM>::MAX_DEPTH && m0.Ancestor(d) == m1.Ancestor(d)) {
  379. node_mid.PushBack(m0.Ancestor(d));
  380. d++;
  381. }
  382. m0 = m0.Ancestor(d);
  383. node_mid.PushBack(m0);
  384. m0 = m0.Next();
  385. idx = std::lower_bound(pt_mid.begin(), pt_mid.end(), m0) - pt_mid.begin();
  386. }
  387. }
  388. { // Set mins
  389. mins.ReInit(np);
  390. Long min_idx = std::lower_bound(node_mid.begin(), node_mid.end(), pt_mid0) - node_mid.begin() - 1;
  391. if (!rank || min_idx < 0) min_idx = 0;
  392. Morton<DIM> m0 = coarsest_ancestor_mid(node_mid[min_idx]);
  393. comm.Allgather(Ptr2ConstItr<Morton<DIM>>(&m0,1), 1, mins.begin(), 1);
  394. }
  395. if (balance21) { // 2:1 balance refinement // TODO: optimize
  396. Vector<Morton<DIM>> parent_mid;
  397. { // add balancing Morton IDs
  398. Vector<std::set<Morton<DIM>>> parent_mid_set(Morton<DIM>::MAX_DEPTH+1);
  399. Vector<Morton<DIM>> nlst;
  400. for (const auto& m0 : node_mid) {
  401. Integer d0 = m0.Depth();
  402. parent_mid_set[m0.Depth()].insert(m0.Ancestor(d0-1));
  403. }
  404. for (Integer d = Morton<DIM>::MAX_DEPTH; d > 0; d--) {
  405. for (const auto& m : parent_mid_set[d]) {
  406. m.NbrList(nlst, d-1, periodic);
  407. parent_mid_set[d-1].insert(nlst.begin(), nlst.end());
  408. parent_mid.PushBack(m);
  409. }
  410. }
  411. }
  412. Vector<Morton<DIM>> parent_mid_sorted;
  413. { // sort and repartition
  414. comm.HyperQuickSort(parent_mid, parent_mid_sorted);
  415. comm.PartitionS(parent_mid_sorted, mins[comm.Rank()]);
  416. }
  417. Vector<Morton<DIM>> tmp_mid;
  418. { // add children
  419. Vector<Morton<DIM>> clst;
  420. tmp_mid.PushBack(Morton<DIM>()); // include root node
  421. for (Long i = 0; i < parent_mid_sorted.Dim(); i++) {
  422. if (i+1 == parent_mid_sorted.Dim() || parent_mid_sorted[i] != parent_mid_sorted[i+1]) {
  423. const auto& m = parent_mid_sorted[i];
  424. tmp_mid.PushBack(m);
  425. m.Children(clst);
  426. for (const auto& c : clst) tmp_mid.PushBack(c);
  427. }
  428. }
  429. auto insert_ancestor_children = [](Vector<Morton<DIM>>& mvec, const Morton<DIM>& m0) {
  430. Integer d0 = m0.Depth();
  431. Vector<Morton<DIM>> clst;
  432. for (Integer d = 0; d < d0; d++) {
  433. m0.Ancestor(d).Children(clst);
  434. for (const auto& m : clst) mvec.PushBack(m);
  435. }
  436. };
  437. insert_ancestor_children(tmp_mid, mins[rank]);
  438. omp_par::merge_sort(tmp_mid.begin(), tmp_mid.end());
  439. }
  440. node_mid.ReInit(0);
  441. for (Long i = 0; i < tmp_mid.Dim(); i++) { // remove duplicates
  442. if (i+1 == tmp_mid.Dim() || tmp_mid[i] != tmp_mid[i+1]) {
  443. node_mid.PushBack(tmp_mid[i]);
  444. }
  445. }
  446. }
  447. { // Add place-holder for ghost nodes
  448. Long start_idx, end_idx;
  449. { // Set start_idx, end_idx
  450. start_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  451. end_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  452. }
  453. { // Set user_mid, user_cnt
  454. Vector<SortPair<Long,Morton<DIM>>> user_node_lst;
  455. Vector<Morton<DIM>> nlst;
  456. std::set<Long> user_procs;
  457. for (Long i = start_idx; i < end_idx; i++) {
  458. Morton<DIM> m0 = node_mid[i];
  459. Integer d0 = m0.Depth();
  460. m0.NbrList(nlst, std::max<Integer>(d0-2,0), periodic);
  461. user_procs.clear();
  462. for (const auto& m : nlst) {
  463. Morton<DIM> m_start = m.DFD();
  464. Morton<DIM> m_end = m.Next();
  465. Integer p_start = std::lower_bound(mins.begin(), mins.end(), m_start) - mins.begin() - 1;
  466. Integer p_end = std::lower_bound(mins.begin(), mins.end(), m_end ) - mins.begin();
  467. SCTL_ASSERT(0 <= p_start);
  468. SCTL_ASSERT(p_start < p_end);
  469. SCTL_ASSERT(p_end <= np);
  470. for (Long p = p_start; p < p_end; p++) {
  471. if (p != rank) user_procs.insert(p);
  472. }
  473. }
  474. for (const auto p : user_procs) {
  475. SortPair<Long,Morton<DIM>> pair;
  476. pair.key = p;
  477. pair.data = m0;
  478. user_node_lst.PushBack(pair);
  479. }
  480. }
  481. omp_par::merge_sort(user_node_lst.begin(), user_node_lst.end());
  482. user_cnt.ReInit(np);
  483. user_mid.ReInit(user_node_lst.Dim());
  484. for (Integer i = 0; i < np; i++) {
  485. SortPair<Long,Morton<DIM>> pair_start, pair_end;
  486. pair_start.key = i;
  487. pair_end.key = i+1;
  488. Long cnt_start = std::lower_bound(user_node_lst.begin(), user_node_lst.end(), pair_start) - user_node_lst.begin();
  489. Long cnt_end = std::lower_bound(user_node_lst.begin(), user_node_lst.end(), pair_end ) - user_node_lst.begin();
  490. user_cnt[i] = cnt_end - cnt_start;
  491. for (Long j = cnt_start; j < cnt_end; j++) {
  492. user_mid[j] = user_node_lst[j].data;
  493. }
  494. std::sort(user_mid.begin() + cnt_start, user_mid.begin() + cnt_end);
  495. }
  496. }
  497. Vector<Morton<DIM>> ghost_mid;
  498. { // SendRecv user_mid
  499. const Vector<Long>& send_cnt = user_cnt;
  500. Vector<Long> send_dsp(np);
  501. scan(send_dsp, send_cnt);
  502. Vector<Long> recv_cnt(np), recv_dsp(np);
  503. comm.Alltoall(send_cnt.begin(), 1, recv_cnt.begin(), 1);
  504. scan(recv_dsp, recv_cnt);
  505. const Vector<Morton<DIM>>& send_mid = user_mid;
  506. Long Nsend = send_dsp[np-1] + send_cnt[np-1];
  507. Long Nrecv = recv_dsp[np-1] + recv_cnt[np-1];
  508. SCTL_ASSERT(send_mid.Dim() == Nsend);
  509. ghost_mid.ReInit(Nrecv);
  510. comm.Alltoallv(send_mid.begin(), send_cnt.begin(), send_dsp.begin(), ghost_mid.begin(), recv_cnt.begin(), recv_dsp.begin());
  511. }
  512. { // Update node_mid <-- ghost_mid + node_mid
  513. Vector<Morton<DIM>> new_mid(end_idx-start_idx + ghost_mid.Dim());
  514. Long Nsplit = std::lower_bound(ghost_mid.begin(), ghost_mid.end(), mins[rank]) - ghost_mid.begin();
  515. for (Long i = 0; i < Nsplit; i++) {
  516. new_mid[i] = ghost_mid[i];
  517. }
  518. for (Long i = 0; i < end_idx - start_idx; i++) {
  519. new_mid[Nsplit + i] = node_mid[start_idx + i];
  520. }
  521. for (Long i = Nsplit; i < ghost_mid.Dim(); i++) {
  522. new_mid[end_idx - start_idx + i] = ghost_mid[i];
  523. }
  524. node_mid.Swap(new_mid);
  525. }
  526. }
  527. { // Set node_mid, node_attr
  528. Morton<DIM> m0 = (rank ? mins[rank] : Morton<DIM>() );
  529. Morton<DIM> m1 = (rank+1<np ? mins[rank+1] : Morton<DIM>().Next());
  530. Long Nnodes = node_mid.Dim();
  531. node_attr.ReInit(Nnodes);
  532. for (Long i = 0; i < Nnodes; i++) {
  533. node_attr[i].Leaf = !(i+1<Nnodes && node_mid[i].isAncestor(node_mid[i+1]));
  534. node_attr[i].Ghost = (node_mid[i] < m0 || node_mid[i] >= m1);
  535. }
  536. }
  537. { // Set node_lst
  538. static constexpr Integer MAX_CHILD = (1u << DIM);
  539. static constexpr Integer MAX_NBRS = sctl::pow<DIM,Integer>(3);
  540. Long Nnodes = node_mid.Dim();
  541. node_lst.ReInit(Nnodes);
  542. Vector<Long> ancestors(Morton<DIM>::MAX_DEPTH);
  543. Vector<Long> child_cnt(Morton<DIM>::MAX_DEPTH);
  544. #pragma omp parallel for schedule(static)
  545. for (Long i = 0; i < Nnodes; i++) {
  546. node_lst[i].p2n = -1;
  547. node_lst[i].parent = -1;
  548. for (Integer j = 0; j < MAX_CHILD; j++) node_lst[i].child[j] = -1;
  549. for (Integer j = 0; j < MAX_NBRS; j++) node_lst[i].nbr[j] = -1;
  550. }
  551. for (Long i = 0; i < Nnodes; i++) { // Set parent_lst, child_lst_
  552. Integer depth = node_mid[i].Depth();
  553. ancestors[depth] = i;
  554. child_cnt[depth] = 0;
  555. if (depth) {
  556. Long p = ancestors[depth-1];
  557. Long& c = child_cnt[depth-1];
  558. node_lst[i].parent = p;
  559. node_lst[p].child[c] = i;
  560. node_lst[p].p2n = c;
  561. c++;
  562. }
  563. }
  564. // TODO: add nbr-list
  565. }
  566. if (0) { // Check tree
  567. Morton<DIM> m0;
  568. SCTL_ASSERT(node_mid.Dim() && m0 == node_mid[0]);
  569. for (Long i = 1; i < node_mid.Dim(); i++) {
  570. const auto& m = node_mid[i];
  571. if (m0.isAncestor(m)) m0 = m0.Ancestor(m0.Depth()+1);
  572. else m0 = m0.Next();
  573. SCTL_ASSERT(m0 == m);
  574. }
  575. SCTL_ASSERT(m0.Next() == Morton<DIM>().Next());
  576. }
  577. { // Update node_data, node_cnt
  578. Long start_idx, end_idx;
  579. { // Set start_idx, end_idx
  580. start_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  581. end_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  582. }
  583. comm.PartitionS(node_mid_orig, mins[comm.Rank()]);
  584. Vector<Long> new_cnt_range0(node_mid.Dim()), new_cnt_range1(node_mid.Dim());
  585. { // Set new_cnt_range0, new_cnt_range1
  586. for (Long i = 0; i < start_idx; i++) {
  587. new_cnt_range0[i] = 0;
  588. new_cnt_range1[i] = 0;
  589. }
  590. for (Long i = start_idx; i < end_idx; i++) {
  591. auto m0 = (node_mid[i+0]);
  592. auto m1 = (i+1==end_idx ? Morton<DIM>().Next() : (node_mid[i+1]));
  593. new_cnt_range0[i] = std::lower_bound(node_mid_orig.begin(), node_mid_orig.begin() + node_mid_orig.Dim(), m0) - node_mid_orig.begin();
  594. new_cnt_range1[i] = std::lower_bound(node_mid_orig.begin(), node_mid_orig.begin() + node_mid_orig.Dim(), m1) - node_mid_orig.begin();
  595. }
  596. for (Long i = end_idx; i < node_mid.Dim(); i++) {
  597. new_cnt_range0[i] = 0;
  598. new_cnt_range1[i] = 0;
  599. }
  600. }
  601. Vector<Long> cnt_tmp;
  602. Vector<char> data_tmp;
  603. for (const auto& pair : node_data) {
  604. const std::string& data_name = pair.first;
  605. Long dof;
  606. Iterator<Vector<char>> data_;
  607. Iterator<Vector<Long>> cnt_;
  608. GetData_(data_, cnt_, data_name);
  609. { // Set dof
  610. StaticArray<Long,2> Nl, Ng;
  611. Nl[0] = data_->Dim();
  612. Nl[1] = omp_par::reduce(cnt_->begin(), cnt_->Dim());
  613. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  614. dof = Ng[0] / std::max<Long>(Ng[1],1);
  615. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  616. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  617. }
  618. Long data_dsp = omp_par::reduce(cnt_->begin(), start_idx_orig);
  619. Long data_cnt = omp_par::reduce(cnt_->begin() + start_idx_orig, end_idx_orig - start_idx_orig);
  620. data_tmp.ReInit(data_cnt * dof, data_->begin() + data_dsp * dof, true);
  621. cnt_tmp.ReInit(end_idx_orig - start_idx_orig, cnt_->begin() + start_idx_orig, true);
  622. comm.PartitionN(cnt_tmp, node_mid_orig.Dim());
  623. cnt_->ReInit(node_mid.Dim());
  624. for (Long i = 0; i < node_mid.Dim(); i++) {
  625. Long sum = 0;
  626. Long j0 = new_cnt_range0[i];
  627. Long j1 = new_cnt_range1[i];
  628. for (Long j = j0; j < j1; j++) sum += cnt_tmp[j];
  629. cnt_[0][i] = sum;
  630. }
  631. SCTL_ASSERT(omp_par::reduce(cnt_->begin(), cnt_->Dim()) == omp_par::reduce(cnt_tmp.begin(), cnt_tmp.Dim()));
  632. Long Ndata = omp_par::reduce(cnt_->begin(), cnt_->Dim()) * dof;
  633. comm.PartitionN(data_tmp, Ndata);
  634. SCTL_ASSERT(data_tmp.Dim() == Ndata);
  635. data_->Swap(data_tmp);
  636. }
  637. }
  638. }
  639. template <class ValueType> void AddData(const std::string& name, const Vector<ValueType>& data, const Vector<Long>& cnt) {
  640. Long dof;
  641. { // Check dof
  642. StaticArray<Long,2> Nl, Ng;
  643. Nl[0] = data.Dim();
  644. Nl[1] = omp_par::reduce(cnt.begin(), cnt.Dim());
  645. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  646. dof = Ng[0] / std::max<Long>(Ng[1],1);
  647. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  648. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  649. }
  650. if (dof) SCTL_ASSERT(cnt.Dim() == node_mid.Dim());
  651. SCTL_ASSERT(node_data.find(name) == node_data.end());
  652. node_data[name].ReInit(data.Dim()*sizeof(ValueType), (Iterator<char>)data.begin(), true);
  653. node_cnt [name] = cnt;
  654. }
  655. template <class ValueType> void GetData(Vector<ValueType>& data, Vector<Long>& cnt, const std::string& name) const {
  656. const auto data_ = node_data.find(name);
  657. const auto cnt_ = node_cnt.find(name);
  658. SCTL_ASSERT(data_ != node_data.end());
  659. SCTL_ASSERT( cnt_ != node_cnt .end());
  660. data.ReInit(data_->second.Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->second.begin(), false);
  661. SCTL_ASSERT(data.Dim()*(Long)sizeof(ValueType) == data_->second.Dim());
  662. cnt .ReInit( cnt_->second.Dim(), (Iterator<Long>)cnt_->second.begin(), false);
  663. }
  664. template <class ValueType> void ReduceBroadcast(const std::string& name) {
  665. Integer np = comm.Size();
  666. Integer rank = comm.Rank();
  667. Vector<Long> dsp;
  668. Iterator<Vector<char>> data_;
  669. Iterator<Vector<Long>> cnt_;
  670. GetData_(data_, cnt_, name);
  671. Vector<ValueType> data(data_->Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->begin(), false);
  672. Vector<Long>& cnt = *cnt_;
  673. scan(dsp, cnt);
  674. Long dof;
  675. { // Set dof
  676. StaticArray<Long,2> Nl, Ng;
  677. Nl[0] = data.Dim();
  678. Nl[1] = omp_par::reduce(cnt.begin(), cnt.Dim());
  679. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  680. dof = Ng[0] / std::max<Long>(Ng[1],1);
  681. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  682. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  683. }
  684. { // Reduce
  685. Vector<Morton<DIM>> send_mid, recv_mid;
  686. Vector<Long> send_node_cnt(np), send_node_dsp(np);
  687. Vector<Long> recv_node_cnt(np), recv_node_dsp(np);
  688. { // Set send_mid, send_node_cnt, send_node_dsp, recv_mid, recv_node_cnt, recv_node_dsp
  689. { // Set send_mid
  690. Morton<DIM> m0 = mins[rank];
  691. for (Integer d = 0; d < m0.Depth(); d++) {
  692. send_mid.PushBack(m0.Ancestor(d));
  693. }
  694. }
  695. for (Integer p = 0; p < np; p++) {
  696. Long start_idx = std::lower_bound(send_mid.begin(), send_mid.end(), mins[p]) - send_mid.begin();
  697. Long end_idx = std::lower_bound(send_mid.begin(), send_mid.end(), (p+1==np ? Morton<DIM>().Next() : mins[p+1])) - send_mid.begin();
  698. send_node_cnt[p] = end_idx - start_idx;
  699. }
  700. scan(send_node_dsp, send_node_cnt);
  701. SCTL_ASSERT(send_node_dsp[np-1]+send_node_cnt[np-1] == send_mid.Dim());
  702. comm.Alltoall(send_node_cnt.begin(), 1, recv_node_cnt.begin(), 1);
  703. scan(recv_node_dsp, recv_node_cnt);
  704. recv_mid.ReInit(recv_node_dsp[np-1] + recv_node_cnt[np-1]);
  705. comm.Alltoallv(send_mid.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_mid.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  706. }
  707. Vector<Long> send_data_cnt, send_data_dsp;
  708. Vector<Long> recv_data_cnt, recv_data_dsp;
  709. { // Set send_data_cnt, send_data_dsp
  710. send_data_cnt.ReInit(send_mid.Dim());
  711. recv_data_cnt.ReInit(recv_mid.Dim());
  712. for (Long i = 0; i < send_mid.Dim(); i++) {
  713. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  714. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  715. send_data_cnt[i] = cnt[idx];
  716. }
  717. scan(send_data_dsp, send_data_cnt);
  718. comm.Alltoallv(send_data_cnt.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_data_cnt.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  719. scan(recv_data_dsp, recv_data_cnt);
  720. }
  721. Vector<ValueType> send_buff, recv_buff;
  722. Vector<Long> send_buff_cnt(np), send_buff_dsp(np);
  723. Vector<Long> recv_buff_cnt(np), recv_buff_dsp(np);
  724. { // Set send_buff, send_buff_cnt, send_buff_dsp, recv_buff, recv_buff_cnt, recv_buff_dsp
  725. Long N_send_nodes = send_mid.Dim();
  726. Long N_recv_nodes = recv_mid.Dim();
  727. if (N_send_nodes) send_buff.ReInit((send_data_dsp[N_send_nodes-1] + send_data_cnt[N_send_nodes-1]) * dof);
  728. if (N_recv_nodes) recv_buff.ReInit((recv_data_dsp[N_recv_nodes-1] + recv_data_cnt[N_recv_nodes-1]) * dof);
  729. for (Long i = 0; i < N_send_nodes; i++) {
  730. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  731. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  732. Long dsp_ = dsp[idx] * dof;
  733. Long cnt_ = cnt[idx] * dof;
  734. Long send_data_dsp_ = send_data_dsp[i] * dof;
  735. Long send_data_cnt_ = send_data_cnt[i] * dof;
  736. SCTL_ASSERT(send_data_cnt_ == cnt_);
  737. for (Long j = 0; j < cnt_; j++) {
  738. send_buff[send_data_dsp_+j] = data[dsp_+j];
  739. }
  740. }
  741. for (Integer p = 0; p < np; p++) {
  742. Long send_buff_cnt_ = 0;
  743. Long recv_buff_cnt_ = 0;
  744. for (Long i = 0; i < send_node_cnt[p]; i++) {
  745. send_buff_cnt_ += send_data_cnt[send_node_dsp[p]+i];
  746. }
  747. for (Long i = 0; i < recv_node_cnt[p]; i++) {
  748. recv_buff_cnt_ += recv_data_cnt[recv_node_dsp[p]+i];
  749. }
  750. send_buff_cnt[p] = send_buff_cnt_ * dof;
  751. recv_buff_cnt[p] = recv_buff_cnt_ * dof;
  752. }
  753. scan(send_buff_dsp, send_buff_cnt);
  754. scan(recv_buff_dsp, recv_buff_cnt);
  755. comm.Alltoallv(send_buff.begin(), send_buff_cnt.begin(), send_buff_dsp.begin(), recv_buff.begin(), recv_buff_cnt.begin(), recv_buff_dsp.begin());
  756. }
  757. { // Reduction
  758. Long N_recv_nodes = recv_mid.Dim();
  759. for (Long i = 0; i < N_recv_nodes; i++) {
  760. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), recv_mid[i]) - node_mid.begin();
  761. Long dsp_ = dsp[idx] * dof;
  762. Long cnt_ = cnt[idx] * dof;
  763. Long recv_data_dsp_ = recv_data_dsp[i] * dof;
  764. Long recv_data_cnt_ = recv_data_cnt[i] * dof;
  765. if (recv_data_cnt_ == cnt_) {
  766. for (Long j = 0; j < cnt_; j++) {
  767. data[dsp_+j] += recv_buff[recv_data_dsp_+j];
  768. }
  769. }
  770. }
  771. }
  772. }
  773. Broadcast<ValueType>(name);
  774. }
  775. template <class ValueType> void Broadcast(const std::string& name) {
  776. Integer np = comm.Size();
  777. Integer rank = comm.Rank();
  778. Vector<Long> dsp;
  779. Iterator<Vector<char>> data_;
  780. Iterator<Vector<Long>> cnt_;
  781. GetData_(data_, cnt_, name);
  782. Vector<ValueType> data(data_->Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->begin(), false);
  783. Vector<Long>& cnt = *cnt_;
  784. scan(dsp, cnt);
  785. Long dof;
  786. { // Set dof
  787. StaticArray<Long,2> Nl, Ng;
  788. Nl[0] = data.Dim();
  789. Nl[1] = omp_par::reduce(cnt.begin(), cnt.Dim());
  790. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  791. dof = Ng[0] / std::max<Long>(Ng[1],1);
  792. SCTL_ASSERT(Nl[0] == Nl[1] * dof);
  793. SCTL_ASSERT(Ng[0] == Ng[1] * dof);
  794. }
  795. { // Broadcast
  796. const Vector<Morton<DIM>>& send_mid = user_mid;
  797. const Vector<Long>& send_node_cnt = user_cnt;
  798. Vector<Long> send_node_dsp(np);
  799. { // Set send_dsp
  800. SCTL_ASSERT(send_node_cnt.Dim() == np);
  801. scan(send_node_dsp, send_node_cnt);
  802. SCTL_ASSERT(send_node_dsp[np-1] + send_node_cnt[np-1] == send_mid.Dim());
  803. }
  804. Vector<Morton<DIM>> recv_mid;
  805. Vector<Long> recv_node_cnt(np), recv_node_dsp(np);
  806. { // Set recv_mid, recv_node_cnt, recv_node_dsp
  807. comm.Alltoall(send_node_cnt.begin(), 1, recv_node_cnt.begin(), 1);
  808. scan(recv_node_dsp, recv_node_cnt);
  809. recv_mid.ReInit(recv_node_dsp[np-1] + recv_node_cnt[np-1]);
  810. comm.Alltoallv(send_mid.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_mid.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  811. }
  812. Vector<Long> send_data_cnt, send_data_dsp;
  813. Vector<Long> recv_data_cnt, recv_data_dsp;
  814. { // Set send_data_cnt, send_data_dsp
  815. send_data_cnt.ReInit(send_mid.Dim());
  816. recv_data_cnt.ReInit(recv_mid.Dim());
  817. for (Long i = 0; i < send_mid.Dim(); i++) {
  818. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  819. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  820. send_data_cnt[i] = cnt[idx];
  821. }
  822. scan(send_data_dsp, send_data_cnt);
  823. comm.Alltoallv(send_data_cnt.begin(), send_node_cnt.begin(), send_node_dsp.begin(), recv_data_cnt.begin(), recv_node_cnt.begin(), recv_node_dsp.begin());
  824. scan(recv_data_dsp, recv_data_cnt);
  825. }
  826. Vector<ValueType> send_buff, recv_buff;
  827. Vector<Long> send_buff_cnt(np), send_buff_dsp(np);
  828. Vector<Long> recv_buff_cnt(np), recv_buff_dsp(np);
  829. { // Set send_buff, send_buff_cnt, send_buff_dsp, recv_buff, recv_buff_cnt, recv_buff_dsp
  830. Long N_send_nodes = send_mid.Dim();
  831. Long N_recv_nodes = recv_mid.Dim();
  832. if (N_send_nodes) send_buff.ReInit((send_data_dsp[N_send_nodes-1] + send_data_cnt[N_send_nodes-1]) * dof);
  833. if (N_recv_nodes) recv_buff.ReInit((recv_data_dsp[N_recv_nodes-1] + recv_data_cnt[N_recv_nodes-1]) * dof);
  834. for (Long i = 0; i < N_send_nodes; i++) {
  835. Long idx = std::lower_bound(node_mid.begin(), node_mid.end(), send_mid[i]) - node_mid.begin();
  836. SCTL_ASSERT(send_mid[i] == node_mid[idx]);
  837. Long dsp_ = dsp[idx] * dof;
  838. Long cnt_ = cnt[idx] * dof;
  839. Long send_data_dsp_ = send_data_dsp[i] * dof;
  840. Long send_data_cnt_ = send_data_cnt[i] * dof;
  841. SCTL_ASSERT(send_data_cnt_ == cnt_);
  842. for (Long j = 0; j < cnt_; j++) {
  843. send_buff[send_data_dsp_+j] = data[dsp_+j];
  844. }
  845. }
  846. for (Integer p = 0; p < np; p++) {
  847. Long send_buff_cnt_ = 0;
  848. Long recv_buff_cnt_ = 0;
  849. for (Long i = 0; i < send_node_cnt[p]; i++) {
  850. send_buff_cnt_ += send_data_cnt[send_node_dsp[p]+i];
  851. }
  852. for (Long i = 0; i < recv_node_cnt[p]; i++) {
  853. recv_buff_cnt_ += recv_data_cnt[recv_node_dsp[p]+i];
  854. }
  855. send_buff_cnt[p] = send_buff_cnt_ * dof;
  856. recv_buff_cnt[p] = recv_buff_cnt_ * dof;
  857. }
  858. scan(send_buff_dsp, send_buff_cnt);
  859. scan(recv_buff_dsp, recv_buff_cnt);
  860. comm.Alltoallv(send_buff.begin(), send_buff_cnt.begin(), send_buff_dsp.begin(), recv_buff.begin(), recv_buff_cnt.begin(), recv_buff_dsp.begin());
  861. }
  862. Long start_idx, end_idx;
  863. { // Set start_idx, end_idx
  864. start_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  865. end_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  866. SCTL_ASSERT(0 <= start_idx);
  867. SCTL_ASSERT(start_idx < end_idx);
  868. SCTL_ASSERT(end_idx <= node_mid.Dim());
  869. }
  870. { // Update data <-- data + recv_buff
  871. Long Nsplit = std::lower_bound(recv_mid.begin(), recv_mid.end(), mins[rank]) - recv_mid.begin();
  872. SCTL_ASSERT(recv_mid.Dim()-Nsplit == node_mid.Dim()-end_idx);
  873. SCTL_ASSERT(Nsplit == start_idx);
  874. Long N0 = (start_idx ? dsp[start_idx-1] + cnt[start_idx-1] : 0) * dof;
  875. Long N1 = (end_idx ? dsp[end_idx-1] + cnt[end_idx-1] : 0) * dof;
  876. Long Ns = (Nsplit ? recv_data_dsp[Nsplit-1] + recv_data_cnt[Nsplit-1] : 0) * dof;
  877. if (N0 != Ns || recv_buff.Dim() != N0+data.Dim()-N1) { // resize data and preserve non-ghost data
  878. Vector<char> data_new((recv_buff.Dim() + N1-N0) * sizeof(ValueType));
  879. memcopy(data_new.begin() + Ns * sizeof(ValueType), data_->begin() + N0 * sizeof(ValueType), (N1-N0) * sizeof(ValueType));
  880. data_->Swap(data_new);
  881. data.ReInit(data_->Dim()/sizeof(ValueType), (Iterator<ValueType>)data_->begin(), false);
  882. }
  883. memcopy(cnt.begin(), recv_data_cnt.begin(), start_idx);
  884. memcopy(cnt.begin()+end_idx, recv_data_cnt.begin()+Nsplit, node_mid.Dim()-end_idx);
  885. memcopy(data.begin(), recv_buff.begin(), Ns);
  886. memcopy(data.begin()+data.Dim()+Ns-recv_buff.Dim(), recv_buff.begin()+Ns, recv_buff.Dim()-Ns);
  887. }
  888. }
  889. }
  890. void DeleteData(const std::string& name) {
  891. SCTL_ASSERT(node_data.find(name) != node_data.end());
  892. SCTL_ASSERT(node_cnt .find(name) != node_cnt .end());
  893. node_data.erase(name);
  894. node_cnt .erase(name);
  895. }
  896. void WriteTreeVTK(std::string fname, bool show_ghost = false) const {
  897. typedef typename VTUData::VTKReal VTKReal;
  898. VTUData vtu_data;
  899. if (DIM <= 3) { // Set vtu data
  900. static const Integer Ncorner = (1u << DIM);
  901. Vector<VTKReal> &coord = vtu_data.coord;
  902. //Vector<VTKReal> &value = vtu_data.value;
  903. Vector<int32_t> &connect = vtu_data.connect;
  904. Vector<int32_t> &offset = vtu_data.offset;
  905. Vector<uint8_t> &types = vtu_data.types;
  906. StaticArray<VTKReal, DIM> c;
  907. Long point_cnt = coord.Dim() / 3;
  908. Long connect_cnt = connect.Dim();
  909. for (Long nid = 0; nid < node_mid.Dim(); nid++) {
  910. const Morton<DIM> &mid = node_mid[nid];
  911. const NodeAttr &attr = node_attr[nid];
  912. if (!show_ghost && attr.Ghost) continue;
  913. if (!attr.Leaf) continue;
  914. mid.Coord((Iterator<VTKReal>)c);
  915. VTKReal s = sctl::pow<VTKReal>(0.5, mid.Depth());
  916. for (Integer j = 0; j < Ncorner; j++) {
  917. for (Integer i = 0; i < DIM; i++) coord.PushBack(c[i] + (j & (1u << i) ? 1 : 0) * s);
  918. for (Integer i = DIM; i < 3; i++) coord.PushBack(0);
  919. connect.PushBack(point_cnt);
  920. connect_cnt++;
  921. point_cnt++;
  922. }
  923. offset.PushBack(connect_cnt);
  924. if (DIM == 2)
  925. types.PushBack(8);
  926. else if (DIM == 3)
  927. types.PushBack(11);
  928. else
  929. types.PushBack(4);
  930. }
  931. }
  932. vtu_data.WriteVTK(fname, comm);
  933. }
  934. protected:
  935. void GetData_(Iterator<Vector<char>>& data, Iterator<Vector<Long>>& cnt, const std::string& name) {
  936. auto data_ = node_data.find(name);
  937. const auto cnt_ = node_cnt.find(name);
  938. SCTL_ASSERT(data_ != node_data.end());
  939. SCTL_ASSERT( cnt_ != node_cnt .end());
  940. data = Ptr2Itr<Vector<char>>(&data_->second,1);
  941. cnt = Ptr2Itr<Vector<Long>>(& cnt_->second,1);
  942. }
  943. static void scan(Vector<Long>& dsp, const Vector<Long>& cnt) {
  944. dsp.ReInit(cnt.Dim());
  945. if (cnt.Dim()) dsp[0] = 0;
  946. omp_par::scan(cnt.begin(), dsp.begin(), cnt.Dim());
  947. }
  /**
   * Key/payload pair that is ordered by its key only, so that a payload can
   * be carried along while sorting on the key.
   *
   * @tparam A  key type; must support operator<.
   * @tparam B  payload type; never inspected by the comparison.
   */
  template <typename A, typename B> struct SortPair {
    /** Strict ordering on `key`; `data` does not take part in the comparison. */
    bool operator<(const SortPair<A, B>& other) const { return key < other.key; }

    A key;
    B data;
  };
  953. private:
  954. Vector<Morton<DIM>> mins;
  955. Vector<Morton<DIM>> node_mid;
  956. Vector<NodeAttr> node_attr;
  957. Vector<NodeLists> node_lst;
  958. std::map<std::string, Vector<char>> node_data;
  959. std::map<std::string, Vector<Long>> node_cnt;
  960. Vector<Morton<DIM>> user_mid;
  961. Vector<Long> user_cnt;
  962. Comm comm;
  963. };
  964. template <class Real, Integer DIM, class BaseTree = Tree<DIM>> class PtTree : public BaseTree {
  965. public:
  966. PtTree(const Comm& comm = Comm::Self()) : BaseTree(comm) {}
  967. ~PtTree() {
  968. #ifdef SCTL_MEMDEBUG
  969. for (auto& pair : data_pt_name) {
  970. Vector<Real> data;
  971. Vector<Long> cnt;
  972. this->GetData(data, cnt, pair.second);
  973. SCTL_ASSERT(scatter_idx.find(pair.second) != scatter_idx.end());
  974. }
  975. #endif
  976. }
  977. void UpdateRefinement(const Vector<Real>& coord, Long M = 1, bool balance21 = 0, bool periodic = 0) {
  978. const auto& comm = this->GetComm();
  979. BaseTree::UpdateRefinement(coord, M, balance21, periodic);
  980. Long start_node_idx, end_node_idx;
  981. { // Set start_node_idx, end_node_idx
  982. const auto& mins = this->GetPartitionMID();
  983. const auto& node_mid = this->GetNodeMID();
  984. Integer np = comm.Size();
  985. Integer rank = comm.Rank();
  986. start_node_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  987. end_node_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  988. }
  989. const auto& mins = this->GetPartitionMID();
  990. const auto& node_mid = this->GetNodeMID();
  991. for (const auto& pair : pt_mid) {
  992. const auto& pt_name = pair.first;
  993. auto& pt_mid_ = pt_mid[pt_name];
  994. auto& scatter_idx_ = scatter_idx[pt_name];
  995. comm.PartitionS(pt_mid_, mins[comm.Rank()]);
  996. comm.PartitionN(scatter_idx_, pt_mid_.Dim());
  997. Vector<Long> pt_cnt(node_mid.Dim());
  998. for (Long i = 0; i < node_mid.Dim(); i++) { // Set pt_cnt
  999. Long start = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), node_mid[i]) - pt_mid_.begin();
  1000. Long end = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), (i+1==node_mid.Dim() ? Morton<DIM>().Next() : node_mid[i+1])) - pt_mid_.begin();
  1001. if (i == 0) SCTL_ASSERT(start == 0);
  1002. if (i+1 == node_mid.Dim()) SCTL_ASSERT(end == pt_mid_.Dim());
  1003. pt_cnt[i] = end - start;
  1004. }
  1005. for (const auto& pair : data_pt_name) {
  1006. if (pair.second == pt_name) {
  1007. const auto& data_name = pair.first;
  1008. Iterator<Vector<char>> data;
  1009. Iterator<Vector<Long>> cnt;
  1010. this->GetData_(data, cnt, data_name);
  1011. { // Update data
  1012. Long dof = 0;
  1013. { // Set dof
  1014. StaticArray<Long,2> Nl = {0, 0}, Ng;
  1015. Nl[0] = data->Dim();
  1016. for (Long i = 0; i < cnt->Dim(); i++) Nl[1] += cnt[0][i];
  1017. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  1018. dof = Ng[0] / std::max<Long>(Ng[1],1);
  1019. }
  1020. Long offset = 0, count = 0;
  1021. SCTL_ASSERT(0 <= start_node_idx);
  1022. SCTL_ASSERT(start_node_idx <= end_node_idx);
  1023. SCTL_ASSERT(end_node_idx <= cnt->Dim());
  1024. for (Long i = 0; i < start_node_idx; i++) offset += cnt[0][i];
  1025. for (Long i = start_node_idx; i < end_node_idx; i++) count += cnt[0][i];
  1026. offset *= dof;
  1027. count *= dof;
  1028. Vector<char> data_(count, data->begin() + offset);
  1029. comm.PartitionN(data_, pt_mid_.Dim());
  1030. data->Swap(data_);
  1031. }
  1032. cnt[0] = pt_cnt;
  1033. }
  1034. }
  1035. }
  1036. }
  1037. void AddParticles(const std::string& name, const Vector<Real>& coord) {
  1038. const auto& mins = this->GetPartitionMID();
  1039. const auto& node_mid = this->GetNodeMID();
  1040. const auto& comm = this->GetComm();
  1041. SCTL_ASSERT(scatter_idx.find(name) == scatter_idx.end());
  1042. Vector<Long>& scatter_idx_ = scatter_idx[name];
  1043. Long N = coord.Dim() / DIM;
  1044. SCTL_ASSERT(coord.Dim() == N * DIM);
  1045. Nlocal[name] = N;
  1046. Vector<Morton<DIM>>& pt_mid_ = pt_mid[name];
  1047. if (pt_mid_.Dim() != N) pt_mid_.ReInit(N);
  1048. for (Long i = 0; i < N; i++) {
  1049. pt_mid_[i] = Morton<DIM>(coord.begin() + i*DIM);
  1050. }
  1051. comm.SortScatterIndex(pt_mid_, scatter_idx_, &mins[comm.Rank()]);
  1052. comm.ScatterForward(pt_mid_, scatter_idx_);
  1053. AddParticleData(name, name, coord);
  1054. { // Set node_cnt
  1055. Iterator<Vector<char>> data_;
  1056. Iterator<Vector<Long>> cnt_;
  1057. this->GetData_(data_,cnt_,name);
  1058. cnt_[0].ReInit(node_mid.Dim());
  1059. for (Long i = 0; i < node_mid.Dim(); i++) {
  1060. Long start = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), node_mid[i]) - pt_mid_.begin();
  1061. Long end = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), (i+1==node_mid.Dim() ? Morton<DIM>().Next() : node_mid[i+1])) - pt_mid_.begin();
  1062. if (i == 0) SCTL_ASSERT(start == 0);
  1063. if (i+1 == node_mid.Dim()) SCTL_ASSERT(end == pt_mid_.Dim());
  1064. cnt_[0][i] = end - start;
  1065. }
  1066. }
  1067. }
  1068. void AddParticleData(const std::string& data_name, const std::string& particle_name, const Vector<Real>& data) {
  1069. SCTL_ASSERT(scatter_idx.find(particle_name) != scatter_idx.end());
  1070. SCTL_ASSERT(data_pt_name.find(data_name) == data_pt_name.end());
  1071. data_pt_name[data_name] = particle_name;
  1072. Iterator<Vector<char>> data_;
  1073. Iterator<Vector<Long>> cnt_;
  1074. this->AddData(data_name, Vector<Real>(), Vector<Long>());
  1075. this->GetData_(data_,cnt_,data_name);
  1076. { // Set data_[0]
  1077. data_[0].ReInit(data.Dim()*sizeof(Real), (Iterator<char>)data.begin(), true);
  1078. this->GetComm().ScatterForward(data_[0], scatter_idx[particle_name]);
  1079. }
  1080. if (data_name != particle_name) { // Set cnt_[0]
  1081. Vector<Real> pt_coord;
  1082. Vector<Long> pt_cnt;
  1083. this->GetData(pt_coord, pt_cnt, particle_name);
  1084. cnt_[0] = pt_cnt;
  1085. }
  1086. }
  1087. void GetParticleData(Vector<Real>& data, const std::string& data_name) const {
  1088. SCTL_ASSERT(data_pt_name.find(data_name) != data_pt_name.end());
  1089. const std::string& particle_name = data_pt_name.find(data_name)->second;
  1090. SCTL_ASSERT(scatter_idx.find(particle_name) != scatter_idx.end());
  1091. const auto& scatter_idx_ = scatter_idx.find(particle_name)->second;
  1092. const Long Nlocal_ = Nlocal.find(particle_name)->second;
  1093. const auto& mins = this->GetPartitionMID();
  1094. const auto& node_mid = this->GetNodeMID();
  1095. const auto& comm = this->GetComm();
  1096. Long dof;
  1097. Vector<Long> dsp;
  1098. Vector<Long> cnt_;
  1099. Vector<Real> data_;
  1100. this->GetData(data_, cnt_, data_name);
  1101. SCTL_ASSERT(cnt_.Dim() == node_mid.Dim());
  1102. BaseTree::scan(dsp, cnt_);
  1103. { // Set dof
  1104. Long Nn = node_mid.Dim();
  1105. StaticArray<Long,2> Ng, Nl = {data_.Dim(), dsp[Nn-1]+cnt_[Nn-1]};
  1106. comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
  1107. dof = Ng[0] / std::max<Long>(Ng[1],1);
  1108. }
  1109. { // Set data
  1110. Integer np = comm.Size();
  1111. Integer rank = comm.Rank();
  1112. Long N0 = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
  1113. Long N1 = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
  1114. Long start = dsp[N0] * dof;
  1115. Long end = (N1<dsp.Dim() ? dsp[N1] : dsp[N1-1]+cnt_[N1-1]) * dof;
  1116. data.ReInit(end-start, data_.begin()+start, true);
  1117. comm.ScatterReverse(data, scatter_idx_, Nlocal_ * dof);
  1118. }
  1119. }
  1120. void DeleteParticleData(const std::string& data_name) {
  1121. SCTL_ASSERT(data_pt_name.find(data_name) != data_pt_name.end());
  1122. auto particle_name = data_pt_name[data_name];
  1123. if (data_name == particle_name) {
  1124. std::vector<std::string> data_name_lst;
  1125. for (auto& pair : data_pt_name) {
  1126. if (pair.second == particle_name) {
  1127. data_name_lst.push_back(pair.first);
  1128. }
  1129. }
  1130. for (auto x : data_name_lst) {
  1131. if (x != particle_name) {
  1132. DeleteParticleData(x);
  1133. }
  1134. }
  1135. Nlocal.erase(particle_name);
  1136. }
  1137. this->DeleteData(data_name);
  1138. data_pt_name.erase(data_name);
  1139. }
  1140. void WriteParticleVTK(std::string fname, std::string data_name, bool show_ghost = false) const {
  1141. typedef typename VTUData::VTKReal VTKReal;
  1142. const auto& node_mid = this->GetNodeMID();
  1143. const auto& node_attr = this->GetNodeAttr();
  1144. VTUData vtu_data;
  1145. if (DIM <= 3) { // Set vtu data
  1146. SCTL_ASSERT(data_pt_name.find(data_name) != data_pt_name.end());
  1147. std::string particle_name = data_pt_name.find(data_name)->second;
  1148. Vector<Real> pt_coord;
  1149. Vector<Real> pt_value;
  1150. Vector<Long> pt_cnt;
  1151. Vector<Long> pt_dsp;
  1152. Long value_dof = 0;
  1153. { // Set pt_coord, pt_cnt, pt_dsp
  1154. this->GetData(pt_coord, pt_cnt, particle_name);
  1155. Tree<DIM>::scan(pt_dsp, pt_cnt);
  1156. }
  1157. if (particle_name != data_name) { // Set pt_value, value_dof
  1158. Vector<Long> pt_cnt;
  1159. this->GetData(pt_value, pt_cnt, data_name);
  1160. Long Npt = omp_par::reduce(pt_cnt.begin(), pt_cnt.Dim());
  1161. value_dof = pt_value.Dim() / std::max<Long>(Npt,1);
  1162. }
  1163. Vector<VTKReal> &coord = vtu_data.coord;
  1164. Vector<VTKReal> &value = vtu_data.value;
  1165. Vector<int32_t> &connect = vtu_data.connect;
  1166. Vector<int32_t> &offset = vtu_data.offset;
  1167. Vector<uint8_t> &types = vtu_data.types;
  1168. Long point_cnt = coord.Dim() / DIM;
  1169. Long connect_cnt = connect.Dim();
  1170. value.ReInit(point_cnt * value_dof);
  1171. value.SetZero();
  1172. SCTL_ASSERT(node_mid.Dim() == node_attr.Dim());
  1173. SCTL_ASSERT(node_mid.Dim() == pt_cnt.Dim());
  1174. for (Long i = 0; i < node_mid.Dim(); i++) {
  1175. if (!show_ghost && node_attr[i].Ghost) continue;
  1176. if (!node_attr[i].Leaf) continue;
  1177. for (Long j = 0; j < pt_cnt[i]; j++) {
  1178. ConstIterator<Real> pt_coord_ = pt_coord.begin() + (pt_dsp[i] + j) * DIM;
  1179. ConstIterator<Real> pt_value_ = (value_dof ? pt_value.begin() + (pt_dsp[i] + j) * value_dof : NullIterator<Real>());
  1180. for (Integer k = 0; k < DIM; k++) coord.PushBack((VTKReal)pt_coord_[k]);
  1181. for (Integer k = DIM; k < 3; k++) coord.PushBack(0);
  1182. for (Integer k = 0; k < value_dof; k++) value.PushBack((VTKReal)pt_value_[k]);
  1183. connect.PushBack(point_cnt);
  1184. connect_cnt++;
  1185. point_cnt++;
  1186. offset.PushBack(connect_cnt);
  1187. types.PushBack(1);
  1188. }
  1189. }
  1190. }
  1191. vtu_data.WriteVTK(fname, this->GetComm());
  1192. }
  1193. private:
  1194. std::map<std::string, Long> Nlocal;
  1195. std::map<std::string, Vector<Morton<DIM>>> pt_mid;
  1196. std::map<std::string, Vector<Long>> scatter_idx;
  1197. std::map<std::string, std::string> data_pt_name;
  1198. };
  1199. }
  1200. #endif //_SCTL_TREE_