Prechádzať zdrojové kódy

In Cheb_Node replace function pointer with functor.

Users can now use derived classes of this functor class (Function_t) to
pass custom evaluators for tree construction.
Dhairya Malhotra 10 rokov pred
rodič
commit
b57f57a1ee
2 zmenil súbory, kde vykonal 29 pridanie a 6 odobranie
  1. 27 3
      include/cheb_node.hpp
  2. 2 3
      include/cheb_node.txx

+ 27 - 3
include/cheb_node.hpp

@@ -27,7 +27,31 @@ class Cheb_Node: public MPI_Node<Real_t>{
 
  public:
 
-  typedef void (*fn_ptr)(const Real_t* coord, int n, Real_t* out);
+  template<class Real_t>
+  class Function_t{
+
+    typedef void (*fn_ptr)(const Real_t* coord, int n, Real_t* out);
+
+    public:
+
+    Function_t(): fn_(NULL){}
+
+    Function_t(fn_ptr fn): fn_(fn){}
+
+    virtual ~Function_t(){}
+
+    virtual void operator()(const Real_t* coord, int n, Real_t* out){
+      fn_(coord, n, out);
+    }
+
+    virtual bool IsEmpty(){
+      return (fn_==NULL);
+    }
+
+    private:
+
+    fn_ptr fn_;
+  };
 
   /**
    * \brief Base class for node data. Contains initialization data for the node.
@@ -39,7 +63,7 @@ class Cheb_Node: public MPI_Node<Real_t>{
      Vector<Real_t> cheb_coord; //Chebyshev point samples.
      Vector<Real_t> cheb_value;
 
-     fn_ptr input_fn; // Function pointer.
+     Function_t<Real_t> input_fn; // Function pointer.
      int data_dof;    // Dimension of Chebyshev data.
      int cheb_deg;    // Chebyshev degree
      Real_t tol;      // Tolerance for adaptive refinement.
@@ -169,7 +193,7 @@ class Cheb_Node: public MPI_Node<Real_t>{
    */
   void Curl();
 
-  fn_ptr input_fn;
+  Function_t<Real_t> input_fn;
   Vector<Real_t> cheb_coord;   //coordinates of points
   Vector<Real_t> cheb_value;   //value at points
   Vector<size_t> cheb_scatter; //scatter index mapping original data.

+ 2 - 3
include/cheb_node.txx

@@ -38,7 +38,7 @@ void Cheb_Node<Real_t>::Initialize(TreeNode* parent_, int path2node_, TreeNode::
 
   //Compute Chebyshev approximation.
   if(this->IsLeaf() && !this->IsGhost()){
-    if(input_fn!=NULL && data_dof>0){
+    if(!input_fn.IsEmpty() && data_dof>0){
       Real_t s=pow(0.5,this->Depth());
       int n1=(int)(pow((Real_t)(cheb_deg+1),this->Dim())+0.5);
       std::vector<Real_t> coord=cheb_nodes<Real_t>(cheb_deg,this->Dim());
@@ -126,7 +126,7 @@ template <class Real_t>
 void Cheb_Node<Real_t>::Subdivide() {
   if(!this->IsLeaf()) return;
   MPI_Node<Real_t>::Subdivide();
-  if(cheb_deg<0 || cheb_coeff.Dim()==0 || input_fn!=NULL) return;
+  if(cheb_deg<0 || cheb_coeff.Dim()==0 || !input_fn.IsEmpty()) return;
 
   std::vector<Real_t> x(cheb_deg+1);
   std::vector<Real_t> y(cheb_deg+1);
@@ -152,7 +152,6 @@ void Cheb_Node<Real_t>::Subdivide() {
     Cheb_Node<Real_t>* child=static_cast<Cheb_Node<Real_t>*>(this->Child(i));
     child->cheb_coeff=child_cheb_coeff[i];
     assert(child->cheb_deg==cheb_deg);
-    assert(child->input_fn==input_fn);
     assert(child->tol==tol);
   }
 }