Dhairya Malhotra преди 5 години
родител
ревизия
d541c1e5e1
променени са 2 файла, в които са добавени 97 реда и са изтрити 12 реда
  1. 85 6
      include/sctl/tree.hpp
  2. 12 6
      src/test-cpp.cpp

+ 85 - 6
include/sctl/tree.hpp

@@ -243,6 +243,27 @@ template <class Real, Integer DIM> class Tree {
       Integer np = comm.Size();
       Integer rank = comm.Rank();
 
+      Vector<Morton<DIM>> node_mid_orig;
+      Long start_idx_orig, end_idx_orig;
+      if (mins.Dim()) { // Set start_idx_orig, end_idx_orig
+        start_idx_orig = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
+        end_idx_orig = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
+        node_mid_orig.ReInit(end_idx_orig - start_idx_orig, node_mid.begin() + start_idx_orig, true);
+      } else {
+        start_idx_orig = 0;
+        end_idx_orig = 0;
+      }
+
+      auto coarsest_ancestor_mid = [](const Morton<DIM>& m0) {
+        Morton<DIM> md;
+        Integer d0 = m0.Depth();
+        for (Integer d = 0; d <= d0; d++) {
+          md = m0.Ancestor(d);
+          if (md.Ancestor(d0) == m0) break;
+        }
+        return md;
+      };
+
       Morton<DIM> pt_mid0;
       Vector<Morton<DIM>> pt_mid;
       { // Construct sorted pt_mid
@@ -300,8 +321,9 @@ template <class Real, Integer DIM> class Tree {
       { // Set mins
         mins.ReInit(np);
         Long min_idx = std::lower_bound(node_mid.begin(), node_mid.end(), pt_mid0) - node_mid.begin() - 1;
-        if (min_idx < 0) min_idx = 0;
-        comm.Allgather(node_mid.begin() + min_idx, 1, mins.begin(), 1);
+        if (!rank || min_idx < 0) min_idx = 0;
+        Morton<DIM> m0 = coarsest_ancestor_mid(node_mid[min_idx]);
+        comm.Allgather(Ptr2ConstItr<Morton<DIM>>(&m0,1), 1, mins.begin(), 1);
       }
       { // Set node_mid, node_attr
         Morton<DIM> m0 = (rank      ? mins[rank]   : Morton<DIM>()       );
@@ -318,7 +340,64 @@ template <class Real, Integer DIM> class Tree {
         // TODO
       }
       { // Update node_data, node_cnt
-        // TODO
+        Long start_idx, end_idx;
+        { // Set start_idx, end_idx
+          start_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
+          end_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
+        }
+
+        comm.PartitionS(node_mid_orig, mins[comm.Rank()]);
+
+        Vector<Long> cnt_tmp;
+        Vector<Real> data_tmp;
+        for (const auto& pair : node_data) {
+          const std::string& data_name = pair.first;
+
+          Long dof = 0;
+          Iterator<Vector<Real>> data_;
+          Iterator<Vector<Long>> cnt_;
+          GetData(data_, cnt_, data_name);
+          { // Set dof
+            StaticArray<Long,2> Nl, Ng;
+            Nl[0] = data_->Dim();
+            Nl[1] = omp_par::reduce(cnt_->begin(), cnt_->Dim());
+            comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
+            if (Ng[1]) dof = Ng[0] / Ng[1];
+            SCTL_ASSERT(Nl[0] == Nl[1] * dof);
+            SCTL_ASSERT(Ng[0] == Ng[1] * dof);
+          }
+
+          Long data_dsp = omp_par::reduce(cnt_->begin(), start_idx_orig);
+          Long data_cnt = omp_par::reduce(cnt_->begin() + start_idx_orig, end_idx_orig - start_idx_orig);
+          data_tmp.ReInit(data_cnt * dof, data_->begin() + data_dsp * dof, true);
+
+          cnt_tmp.ReInit(end_idx_orig - start_idx_orig, cnt_->begin() + start_idx_orig, true);
+          comm.PartitionN(cnt_tmp, node_mid_orig.Dim());
+
+          cnt_->ReInit(node_mid.Dim());
+          for (Long i = 0; i < start_idx; i++) {
+            cnt_[0][i] = 0;
+          }
+          for (Long i = start_idx; i < end_idx; i++) {
+            auto m0 = coarsest_ancestor_mid(node_mid[i+0]);
+            auto m1 = (i+1==end_idx ? Morton<DIM>().Next() : coarsest_ancestor_mid(node_mid[i+1]));
+            Long a = std::lower_bound(node_mid_orig.begin(), node_mid_orig.begin() + node_mid_orig.Dim(), m0) - node_mid_orig.begin();
+            Long b = std::lower_bound(node_mid_orig.begin(), node_mid_orig.begin() + node_mid_orig.Dim(), m1) - node_mid_orig.begin();
+            // TODO: precompute a and b
+
+            cnt_[0][i] = 0;
+            for (Long j = a; j < b; j++) cnt_[0][i] += cnt_tmp[j];
+          }
+          for (Long i = end_idx; i < node_mid.Dim(); i++) {
+            cnt_[0][i] = 0;
+          }
+          SCTL_ASSERT(omp_par::reduce(cnt_->begin(), cnt_->Dim()) == omp_par::reduce(cnt_tmp.begin(), cnt_tmp.Dim()));
+
+          Long Ndata = omp_par::reduce(cnt_->begin(), cnt_->Dim()) * dof;
+          comm.PartitionN(data_tmp, Ndata);
+          SCTL_ASSERT(data_tmp.Dim() == Ndata);
+          data_->Swap(data_tmp);
+        }
       }
     }
 
@@ -460,6 +539,8 @@ template <class Real, Integer DIM, class BaseTree = Tree<Real,DIM>> class PtTree
 
     void UpdateRefinement(const Vector<Real>& coord, Long M = 1) {
       const auto& comm = this->GetComm();
+      BaseTree::UpdateRefinement(coord, M);
+
       Long start_node_idx, end_node_idx;
       { // Set start_node_idx, end_node_idx
         const auto& mins = this->GetPartitionMID();
@@ -470,8 +551,6 @@ template <class Real, Integer DIM, class BaseTree = Tree<Real,DIM>> class PtTree
         end_node_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
       }
 
-      BaseTree::UpdateRefinement(coord, M);
-
       const auto& mins = this->GetPartitionMID();
       const auto& node_mid = this->GetNodeMID();
       for (const auto& pair : pt_mid) {
@@ -614,7 +693,7 @@ template <class Real, Integer DIM, class BaseTree = Tree<Real,DIM>> class PtTree
         Long N1 = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
         Long start = dsp[N0] * dof;
         Long end = (N1<dsp.Dim() ? dsp[N1] : dsp[N1-1]+cnt_[0][N1-1]) * dof;
-        data.ReInit(end-start, (Iterator<Real>)data_->begin()+start);
+        data.ReInit(end-start, (Iterator<Real>)data_->begin()+start, true);
         comm.ScatterReverse(data, scatter_idx_, Nlocal_ * dof);
       }
     }

+ 12 - 6
src/test-cpp.cpp

@@ -87,8 +87,8 @@ int main(int argc, char** argv) {
     sctl::Comm comm = sctl::Comm::World();
     srand48(comm.Rank());
 
-    sctl::Vector<Real> coord;
-    { // Set coord
+    sctl::Vector<Real> coord, coord_unif;
+    { // Set coord, coord_unif
       long N_total = 2e5;
       int np = comm.Size();
       int myrank = comm.Rank();
@@ -103,24 +103,30 @@ int main(int argc, char** argv) {
         coord[i*DIM+1] = 0.5 + 0.1125 * sin(theta) * sin(phi);
         coord[i*DIM+2] = 0.5 + 0.45 * cos(theta);
       }
+
+      coord_unif.ReInit(1000);
+      for (auto& x : coord_unif) x = drand48();
     }
 
     sctl::Profile::Tic("Refine",&comm);
     sctl::PtTree<Real,DIM> t(comm);
-    t.UpdateRefinement(coord, 100);
+    t.UpdateRefinement(coord_unif, 1000);
     sctl::Profile::Toc();
 
     sctl::Profile::Tic("AddPts",&comm);
     t.AddParticles("src_pts", coord);
     sctl::Profile::Toc();
 
+    sctl::Profile::Tic("UpdateRefine",&comm);
+    t.UpdateRefinement(coord, 100);
+    sctl::Profile::Toc();
+
     sctl::Profile::Tic("GetPts",&comm);
     { // Verify GetParticleData
       sctl::Vector<Real> data;
       t.GetParticleData(data, "src_pts");
-      Real err = 0;
-      for (long i = 0; i < coord.Dim(); i++) err = std::max(err, fabs(data[i] - coord[i]));
-      SCTL_ASSERT(err == 0);
+      SCTL_ASSERT(data.Dim() == coord.Dim());
+      for (long i = 0; i < coord.Dim(); i++) SCTL_ASSERT(data[i] == coord[i]);
     }
     sctl::Profile::Toc();