Dhairya Malhotra 5 gadi atpakaļ
vecāks
revīzija
c8de7e7717
4 mainītis faili ar 183 papildinājumiem un 83 dzēšanām
  1. 77 7
      include/sctl/tree.hpp
  2. 105 34
      include/tree.f90
  3. 0 42
      include/tree.h
  4. 1 0
      src/test-fortran.f90

+ 77 - 7
include/sctl/tree.hpp

@@ -458,6 +458,74 @@ template <class Real, Integer DIM, class BaseTree = Tree<Real,DIM>> class PtTree
       #endif
     }
 
+    void UpdateRefinement(const Vector<Real>& coord, Long M = 1) {
+      const auto& comm = this->GetComm();
+      Long start_node_idx, end_node_idx;
+      { // Set start_node_idx, end_node_idx
+        const auto& mins = this->GetPartitionMID();
+        const auto& node_mid = this->GetNodeMID();
+        Integer np = comm.Size();
+        Integer rank = comm.Rank();
+        start_node_idx = std::lower_bound(node_mid.begin(), node_mid.end(), mins[rank]) - node_mid.begin();
+        end_node_idx = std::lower_bound(node_mid.begin(), node_mid.end(), (rank+1==np ? Morton<DIM>().Next() : mins[rank+1])) - node_mid.begin();
+      }
+
+      BaseTree::UpdateRefinement(coord, M);
+
+      const auto& mins = this->GetPartitionMID();
+      const auto& node_mid = this->GetNodeMID();
+      for (const auto& pair : pt_mid) {
+        const auto& pt_name = pair.first;
+        auto& pt_mid_ = pt_mid[pt_name];
+        auto& scatter_idx_ = scatter_idx[pt_name];
+        comm.PartitionS(pt_mid_, mins[comm.Rank()]);
+        comm.PartitionN(scatter_idx_, pt_mid_.Dim());
+
+        Vector<Long> pt_cnt(node_mid.Dim());
+        for (Long i = 0; i < node_mid.Dim(); i++) { // Set pt_cnt
+          Long start = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), node_mid[i]) - pt_mid_.begin();
+          Long end = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), (i+1==node_mid.Dim() ? Morton<DIM>().Next() : node_mid[i+1])) - pt_mid_.begin();
+          if (i == 0) SCTL_ASSERT(start == 0);
+          if (i+1 == node_mid.Dim()) SCTL_ASSERT(end == pt_mid_.Dim());
+          pt_cnt[i] = end - start;
+        }
+
+        for (const auto& pair : data_pt_name) {
+          if (pair.second == pt_name) {
+            const auto& data_name = pair.first;
+
+            Iterator<Vector<Real>> data;
+            Iterator<Vector<Long>> cnt;
+            this->GetData(data, cnt, data_name);
+
+            { // Update data
+              Long dof = 0;
+              { // Set dof
+                StaticArray<Long,2> Nl = {0, 0}, Ng;
+                Nl[0] = data->Dim();
+                for (Long i = 0; i < cnt->Dim(); i++) Nl[1] += cnt[0][i];
+                comm.Allreduce((ConstIterator<Long>)Nl, (Iterator<Long>)Ng, 2, Comm::CommOp::SUM);
+                dof = Ng[0] / std::max<Long>(Ng[1],1);
+              }
+              Long offset = 0, count = 0;
+              SCTL_ASSERT(0 <= start_node_idx);
+              SCTL_ASSERT(start_node_idx <= end_node_idx);
+              SCTL_ASSERT(end_node_idx <= cnt->Dim());
+              for (Long i = 0; i < start_node_idx; i++) offset += cnt[0][i];
+              for (Long i = start_node_idx; i < end_node_idx; i++) count += cnt[0][i];
+              offset *= dof;
+              count *= dof;
+
+              Vector<Real> data_(count, data->begin() + offset);
+              comm.PartitionN(data_, pt_mid_.Dim());
+              data->Swap(data_);
+            }
+            cnt[0] = pt_cnt;
+          }
+        }
+      }
+    }
+
     void AddParticles(const std::string& name, const Vector<Real>& coord) {
       const auto& mins = this->GetPartitionMID();
       const auto& node_mid = this->GetNodeMID();
@@ -470,12 +538,13 @@ template <class Real, Integer DIM, class BaseTree = Tree<Real,DIM>> class PtTree
       SCTL_ASSERT(coord.Dim() == N * DIM);
       Nlocal[name] = N;
 
-      Vector<Morton<DIM>> pt_mid(N);
+      Vector<Morton<DIM>>& pt_mid_ = pt_mid[name];
+      if (pt_mid_.Dim() != N) pt_mid_.ReInit(N);
       for (Long i = 0; i < N; i++) {
-        pt_mid[i] = Morton<DIM>(coord.begin() + i*DIM);
+        pt_mid_[i] = Morton<DIM>(coord.begin() + i*DIM);
       }
-      comm.SortScatterIndex(pt_mid, scatter_idx_, &mins[comm.Rank()]);
-      comm.ScatterForward(pt_mid, scatter_idx_);
+      comm.SortScatterIndex(pt_mid_, scatter_idx_, &mins[comm.Rank()]);
+      comm.ScatterForward(pt_mid_, scatter_idx_);
       AddParticleData(name, name, coord);
 
       { // Set node_cnt
@@ -484,10 +553,10 @@ template <class Real, Integer DIM, class BaseTree = Tree<Real,DIM>> class PtTree
         this->GetData(data_,cnt_,name);
         cnt_[0].ReInit(node_mid.Dim());
         for (Long i = 0; i < node_mid.Dim(); i++) {
-          Long start = std::lower_bound(pt_mid.begin(), pt_mid.end(), node_mid[i]) - pt_mid.begin();
-          Long end = std::lower_bound(pt_mid.begin(), pt_mid.end(), (i+1==node_mid.Dim()? Morton<DIM>().Next():node_mid[i+1])) - pt_mid.begin();
+          Long start = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), node_mid[i]) - pt_mid_.begin();
+          Long end = std::lower_bound(pt_mid_.begin(), pt_mid_.end(), (i+1==node_mid.Dim() ? Morton<DIM>().Next() : node_mid[i+1])) - pt_mid_.begin();
           if (i == 0) SCTL_ASSERT(start == 0);
-          if (i+1 == node_mid.Dim()) SCTL_ASSERT(end == pt_mid.Dim());
+          if (i+1 == node_mid.Dim()) SCTL_ASSERT(end == pt_mid_.Dim());
           cnt_[0][i] = end - start;
         }
       }
@@ -637,6 +706,7 @@ template <class Real, Integer DIM, class BaseTree = Tree<Real,DIM>> class PtTree
   private:
 
     std::map<std::string, Long> Nlocal;
+    std::map<std::string, Vector<Morton<DIM>>> pt_mid;
     std::map<std::string, Vector<Long>> scatter_idx;
     std::map<std::string, std::string> data_pt_name;
 };

+ 105 - 34
include/tree.f90

@@ -1,34 +1,105 @@
-!interface
-!
-!  subroutine CreateTree(tree)
-!    use, intrinsic :: ISO_C_BINDING
-!    type(c_ptr) :: tree
-!  end subroutine
-!
-!  subroutine DeleteTree(tree)
-!    use, intrinsic :: ISO_C_BINDING
-!    type(c_ptr) :: tree
-!  end subroutine
-!
-!  subroutine myalloc(A, N) !bind(C)
-!    use, intrinsic :: ISO_C_BINDING
-!    implicit none
-!    type(C_PTR), intent(inout), allocatable :: A(:)
-!    integer *4, intent(inout)  :: N
-!  end subroutine
-!
-!  subroutine myprint(A, N) !bind(C)
-!    use, intrinsic :: ISO_C_BINDING
-!    implicit none
-!    real    *8, intent(inout), allocatable :: A(:)
-!    integer *4, intent(inout) :: N
-!  end subroutine
-!
-!  subroutine myfree(A) !bind(C)
-!    use, intrinsic :: ISO_C_BINDING
-!    implicit none
-!    real    *8, intent(inout), allocatable :: A(:)
-!  end subroutine
-!
-!end interface
-!
+interface
+
+  subroutine Createtree_(tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    type(c_ptr), intent(out) :: tree_ctx
+  end subroutine
+
+  subroutine DeleteTree(tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+  subroutine GetTree(node_coord, node_depth, node_ghost, node_leaf, Nnodes, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    type(c_ptr), intent(out) :: node_coord ! real*8, dimension(Nnode*3)
+    type(c_ptr), intent(out) :: node_depth ! integer*1, dimension(Nnode)
+    type(c_ptr), intent(out) :: node_ghost ! integer*1, dimension(Nnode)
+    type(c_ptr), intent(out) :: node_leaf  ! integer*1, dimension(Nnode)
+    integer*4  , intent(out) :: Nnodes
+    type(c_ptr), intent(in)  :: tree_ctx
+  end subroutine
+
+  subroutine UpdateRefinement(pt_coord, Npt, max_pts, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    real*8     , intent(in) :: pt_coord(*)
+    integer*4  , intent(in) :: Npt
+    integer*4  , intent(in) :: max_pts
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+  subroutine AddData(data_name, node_data, Ndata, cnt, Ncnt, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    character  , intent(in) :: data_name(*)
+    real*8     , intent(in) :: node_data(*)
+    integer*4  , intent(in) :: Ndata
+    integer*4  , intent(in) :: cnt(*)
+    integer*4  , intent(in) :: Ncnt
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+  subroutine GetData(node_data, Ndata, cnt, Ncnt, data_name, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    type(c_ptr), intent(out) :: node_data  ! real*8   , dimension(Ndata)
+    integer*4  , intent(out) :: Ndata
+    type(c_ptr), intent(out) :: cnt        ! integer*4, dimension(Ncnt)
+    integer*4  , intent(out) :: Ncnt
+    character  , intent(in)  :: data_name(*)
+    type(c_ptr), intent(in)  :: tree_ctx
+  end subroutine
+
+  subroutine DeleteData(data_name, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    character  , intent(in) :: data_name(*)
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+  subroutine WriteTreeVTK(fname, show_ghost, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    character  , intent(in) :: fname(*)
+    logical    , intent(in) :: show_ghost
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+
+  subroutine AddParticles(pt_name, coord, Npt, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    character  , intent(in) :: pt_name
+    real*8     , intent(in) :: coord(*)
+    integer*4  , intent(in) :: Npt
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+  subroutine AddParticleData(data_name, pt_name, pt_data, Ndata, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    character  , intent(in) :: data_name(*)
+    character  , intent(in) :: pt_name(*)
+    real*8     , intent(in) :: pt_data(*)
+    integer*4  , intent(in) :: Ndata
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+  subroutine GetParticleData(pt_data, N, data_name, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    type(c_ptr), intent(out) :: pt_data    ! real*8, dimension(N)
+    integer*4  , intent(out) :: N
+    character  , intent(in)  :: data_name(*)
+    type(c_ptr), intent(in)  :: tree_ctx
+  end subroutine
+
+  subroutine DeleteParticleData(data_name, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    character  , intent(in) :: data_name
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+  subroutine WriteParticleVTK(fname, data_name, show_ghost, tree_ctx)
+    use, intrinsic :: ISO_C_BINDING
+    character  , intent(in) :: fname(*)
+    character  , intent(in) :: data_name(*)
+    logical    , intent(in) :: show_ghost
+    type(c_ptr), intent(in) :: tree_ctx
+  end subroutine
+
+end interface
+

+ 0 - 42
include/tree.h

@@ -1,42 +0,0 @@
-#ifndef _TEST_H_
-#define _TEST_H_
-
-#include <stdint.h>
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-void createtree_(void** tree);
-
-void deletetree_(void** tree);
-
-void gettree_(const double** node_coord, const int8_t** node_depth, const int8_t** node_ghost, const int8_t** node_leaf, int32_t* Nnodes, const void** tree_ctx);
-
-void updaterefinement_(const double* pt_coord_, const int32_t* Npt, const int32_t* max_pts, void** tree_ctx);
-
-void adddata_(const char* name, const double* data, const int32_t* Ndata, const int32_t* cnt, const int32_t* Ncnt, void** tree_ctx);
-
-void getdata_(double** data, int32_t* Ndata, const int32_t** cnt, int32_t* Ncnt, const char* name, void** tree_ctx);
-
-void deletedata_(const char* name, void** tree_ctx);
-
-void writetreevtk_(const char* fname, const bool* show_ghost, const void** tree_ctx);
-
-
-void addparticles_(const char* name, const double* coord, const int32_t* Npt, void** tree_ctx);
-
-void addparticledata_(const char* name, const char* particle_name, const double* data, const int32_t* Ndata, void** tree_ctx);
-
-void getparticledata_(double** data, int32_t* N, const char* name, const void** tree_ctx);
-
-void deleteparticledata_(const char* name, void** tree_ctx);
-
-void writeparticlevtk_(const char* fname, const char* data_name, bool* show_ghost, const void** tree_ctx);
-
-#ifdef __cplusplus
-}
-#endif
-
-
-#endif //_TEST_H_

+ 1 - 0
src/test-fortran.f90

@@ -2,6 +2,7 @@ program main
   use iso_c_binding
   implicit none
   include 'mpif.h'
+  include 'tree.f90'
   type(c_ptr) :: tree_ctx
   real*8 :: pt_coord(100000*3)
   integer*4 :: Npt, mpi_rank, mpi_size, i, ierr