tree.hpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #ifndef _SCTL_TREE_
  2. #define _SCTL_TREE_
#include <sctl/common.hpp>
#include SCTL_INCLUDE(comm.hpp)
#include SCTL_INCLUDE(morton.hpp)
#include SCTL_INCLUDE(vtudata.hpp)
#include SCTL_INCLUDE(ompUtils.hpp)
#include <algorithm>
#include <map>
#include <string>
#include <vector>
  11. namespace SCTL_NAMESPACE {
  12. template <Integer DIM> class Tree {
  13. public:
  14. struct NodeAttr {
  15. unsigned char Leaf : 1, Ghost : 1;
  16. };
  17. struct NodeLists {
  18. Long p2n;
  19. Long parent;
  20. Long child[1 << DIM];
  21. Long nbr[sctl::pow<DIM,Integer>(3)];
  22. };
  23. static constexpr Integer Dim();
  24. Tree(const Comm& comm_ = Comm::Self());
  25. ~Tree();
  26. const Vector<Morton<DIM>>& GetPartitionMID() const;
  27. const Vector<Morton<DIM>>& GetNodeMID() const;
  28. const Vector<NodeAttr>& GetNodeAttr() const;
  29. const Vector<NodeLists>& GetNodeLists() const;
  30. const Comm& GetComm() const;
  31. template <class Real> void UpdateRefinement(const Vector<Real>& coord, Long M = 1, bool balance21 = 0, bool periodic = 0);
  32. template <class ValueType> void AddData(const std::string& name, const Vector<ValueType>& data, const Vector<Long>& cnt);
  33. template <class ValueType> void GetData(Vector<ValueType>& data, Vector<Long>& cnt, const std::string& name) const;
  34. template <class ValueType> void ReduceBroadcast(const std::string& name);
  35. template <class ValueType> void Broadcast(const std::string& name);
  36. void DeleteData(const std::string& name);
  37. void WriteTreeVTK(std::string fname, bool show_ghost = false) const;
  38. protected:
  39. void GetData_(Iterator<Vector<char>>& data, Iterator<Vector<Long>>& cnt, const std::string& name);
  40. static void scan(Vector<Long>& dsp, const Vector<Long>& cnt);
  41. template <typename A, typename B> struct SortPair {
  42. int operator<(const SortPair<A, B> &p1) const { return key < p1.key; }
  43. A key;
  44. B data;
  45. };
  46. private:
  47. Vector<Morton<DIM>> mins;
  48. Vector<Morton<DIM>> node_mid;
  49. Vector<NodeAttr> node_attr;
  50. Vector<NodeLists> node_lst;
  51. std::map<std::string, Vector<char>> node_data;
  52. std::map<std::string, Vector<Long>> node_cnt;
  53. Vector<Morton<DIM>> user_mid;
  54. Vector<Long> user_cnt;
  55. Comm comm;
  56. };
  57. template <class Real, Integer DIM, class BaseTree = Tree<DIM>> class PtTree : public BaseTree {
  58. public:
  59. PtTree(const Comm& comm = Comm::Self());
  60. ~PtTree();
  61. void UpdateRefinement(const Vector<Real>& coord, Long M = 1, bool balance21 = 0, bool periodic = 0);
  62. void AddParticles(const std::string& name, const Vector<Real>& coord);
  63. void AddParticleData(const std::string& data_name, const std::string& particle_name, const Vector<Real>& data);
  64. void GetParticleData(Vector<Real>& data, const std::string& data_name) const;
  65. void DeleteParticleData(const std::string& data_name);
  66. void WriteParticleVTK(std::string fname, std::string data_name, bool show_ghost = false) const;
  67. static void test() {
  68. Long N = 100000;
  69. Vector<Real> X(N*DIM), f(N);
  70. for (Long i = 0; i < N; i++) { // Set coordinates (X), and values (f)
  71. f[i] = 0;
  72. for (Integer k = 0; k < DIM; k++) {
  73. X[i*DIM+k] = pow<3>(drand48()*2-1.0)*0.5+0.5;
  74. f[i] += X[i*DIM+k]*k;
  75. }
  76. }
  77. PtTree<Real,DIM> tree;
  78. tree.AddParticles("pt", X);
  79. tree.AddParticleData("pt-value", "pt", f);
  80. tree.UpdateRefinement(X, 1000); // refine tree with max 1000 points per box.
  81. { // manipulate tree node data
  82. const auto& node_lst = tree.GetNodeLists(); // Get interaction lists
  83. //const auto& node_mid = tree.GetNodeMID();
  84. //const auto& node_attr = tree.GetNodeAttr();
  85. // get point values and count for each node
  86. Vector<Real> value;
  87. Vector<Long> cnt, dsp;
  88. tree.GetData(value, cnt, "pt-value");
  89. // compute the dsp (the point offset) for each node
  90. dsp.ReInit(cnt.Dim()); dsp = 0;
  91. omp_par::scan(cnt.begin(), dsp.begin(), cnt.Dim());
  92. Long node_idx = 0;
  93. for (Long i = 0; i < cnt.Dim(); i++) { // find the tree node with maximum points
  94. if (cnt[node_idx] < cnt[i]) node_idx = i;
  95. }
  96. for (Long j = 0; j < cnt[node_idx]; j++) { // for this node, set all pt-value to -1
  97. value[dsp[node_idx]+j] = -1;
  98. }
  99. for (const Long nbr_idx : node_lst[node_idx].nbr) { // loop over the neighbors and set pt-value to 2
  100. if (nbr_idx >= 0 && nbr_idx != node_idx) {
  101. for (Long j = 0; j < cnt[nbr_idx]; j++) {
  102. value[dsp[nbr_idx]+j] = 2;
  103. }
  104. }
  105. }
  106. }
  107. // Generate visualization
  108. tree.WriteParticleVTK("pt", "pt-value");
  109. tree.WriteTreeVTK("tree");
  110. }
  111. private:
  112. std::map<std::string, Long> Nlocal;
  113. std::map<std::string, Vector<Morton<DIM>>> pt_mid;
  114. std::map<std::string, Vector<Long>> scatter_idx;
  115. std::map<std::string, std::string> data_pt_name;
  116. };
  117. }
  118. #include SCTL_INCLUDE(tree.txx)
  119. #endif //_SCTL_TREE_