Dhairya Malhotra vor 6 Jahren
Ursprung
Commit
7e4b9b8ac3
3 geänderte Dateien mit 33 neuen und 9 gelöschten Zeilen
  1. 7 0
      include/sctl/comm.txx
  2. 19 4
      include/sctl/fft_wrapper.hpp
  3. 7 5
      include/sctl/parallel_solver.hpp

+ 7 - 0
include/sctl/comm.txx

@@ -38,6 +38,7 @@ inline Comm Comm::World() {
 
 inline Comm& Comm::operator=(const Comm& c) {
 #ifdef SCTL_HAVE_MPI
+  #pragma omp critical(SCTL_COMM_DUP)
   MPI_Comm_free(&mpi_comm_);
   Init(c.mpi_comm_);
 #endif
@@ -50,6 +51,7 @@ inline Comm::~Comm() {
     delete (Vector<MPI_Request>*)req.top();
     req.pop();
   }
+  #pragma omp critical(SCTL_COMM_DUP)
   MPI_Comm_free(&mpi_comm_);
 #endif
 }
@@ -57,8 +59,10 @@ inline Comm::~Comm() {
 inline Comm Comm::Split(Integer clr) const {
 #ifdef SCTL_HAVE_MPI
   MPI_Comm new_comm;
+  #pragma omp critical(SCTL_COMM_DUP)
   MPI_Comm_split(mpi_comm_, clr, mpi_rank_, &new_comm);
   Comm c(new_comm);
+  #pragma omp critical(SCTL_COMM_DUP)
   MPI_Comm_free(&new_comm);
   return c;
 #else
@@ -1136,7 +1140,9 @@ template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector
 
     {  // Split comm.  O( log(p) ) ??
       MPI_Comm scomm;
+      #pragma omp critical(SCTL_COMM_DUP)
       MPI_Comm_split(comm, myrank <= split_id, myrank, &scomm);
+      #pragma omp critical(SCTL_COMM_DUP)
       if (free_comm) MPI_Comm_free(&comm);
       comm = scomm;
       free_comm = true;
@@ -1145,6 +1151,7 @@ template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector
       myrank = (myrank <= split_id ? myrank : myrank - split_id - 1);
     }
   }
+  #pragma omp critical(SCTL_COMM_DUP)
   if (free_comm) MPI_Comm_free(&comm);
 
   SortedElem = arr;

+ 19 - 4
include/sctl/fft_wrapper.hpp

@@ -192,7 +192,7 @@ template <class ValueType, class FFT_Derived> class FFT_Generic {
     return dim[i];
   }
 
-  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec, Integer Nthreads = 1) {
     Long rank = dim_vec.Dim();
     fft_type = fft_type_;
     howmany = howmany_;
@@ -377,6 +377,18 @@ template <class ValueType, class FFT_Derived> class FFT_Generic {
 
 template <class ValueType> class FFT : public FFT_Generic<ValueType, FFT<ValueType>> {};
 
+static inline void FFTWInitThreads(Integer Nthreads) {  // One-time FFTW thread-support init, then sets the thread count (Nthreads) for plans created afterwards; no-op unless SCTL_FFTW_THREADS is defined.
+#ifdef SCTL_FFTW_THREADS
+  static bool first_time = true;  // NOTE(review): only the double-precision fftw_* thread API is set up here, yet the float / long double Setup() overloads call this too -- confirm fftwf_* / fftwl_* variants are not needed.
+  #pragma omp critical
+  {  // The braces put the whole body under the lock: previously only the `if` was guarded and fftw_plan_with_nthreads() (planner state, documented by FFTW as not thread-safe) ran unsynchronized.
+    if (first_time) fftw_init_threads();
+    first_time = false;
+    fftw_plan_with_nthreads(Nthreads);  // Plan creation in the caller is still outside this lock -- TODO confirm callers serialize Setup().
+  }
+#endif
+}
+
 #ifdef SCTL_HAVE_FFTW
 template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
 
@@ -386,7 +398,8 @@ template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
 
   ~FFT() { if (this->Dim(0) && this->Dim(1)) fftw_destroy_plan(plan); }  // Destroys the plan only when Dim(0) and Dim(1) are both nonzero -- the same guard Setup() uses before replacing a plan.
 
-  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec, Integer Nthreads = 1) {
+    FFTWInitThreads(Nthreads);
     if (Dim(0) && Dim(1)) fftw_destroy_plan(plan);
     this->fft_type = fft_type_;
     this->howmany = howmany_;
@@ -497,7 +510,8 @@ template <> class FFT<float> : public FFT_Generic<float, FFT<float>> {
 
   ~FFT() { if (this->Dim(0) && this->Dim(1)) fftwf_destroy_plan(plan); }  // Destroys the single-precision plan only when Dim(0) and Dim(1) are both nonzero -- the same guard Setup() uses before replacing a plan.
 
-  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec, Integer Nthreads = 1) {
+    FFTWInitThreads(Nthreads);
     if (Dim(0) && Dim(1)) fftwf_destroy_plan(plan);
     this->fft_type = fft_type_;
     this->howmany = howmany_;
@@ -608,7 +622,8 @@ template <> class FFT<long double> : public FFT_Generic<long double, FFT<long do
 
   ~FFT() { if (this->Dim(0) && this->Dim(1)) fftwl_destroy_plan(plan); }  // Destroys the long-double plan only when Dim(0) and Dim(1) are both nonzero -- the same guard Setup() uses before replacing a plan.
 
-  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec, Integer Nthreads = 1) {
+    FFTWInitThreads(Nthreads);
     if (Dim(0) && Dim(1)) fftwl_destroy_plan(plan);
     this->fft_type = fft_type_;
     this->howmany = howmany_;

+ 7 - 5
include/sctl/parallel_solver.hpp

@@ -15,7 +15,7 @@ template <class Real> class ParallelSolver {
 
   ParallelSolver(const Comm& comm = Comm::Self(), bool verbose = true) : comm_(comm), verbose_(verbose) {}
 
-  void operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter = -1);
+  void operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter = -1, bool use_abs_tol = false);  // use_abs_tol: treat tol as an absolute residual tolerance instead of one relative to the norm of b; max_iter < 0 means b.Dim() iterations. NOTE(review): the PETSc path keeps rtol at PETSC_DEFAULT when use_abs_tol is set -- confirm that is intended.
 
  private:
   Comm comm_;
@@ -62,7 +62,7 @@ template <class Real> int ParallelSolverMatVec(Mat M_, Vec x_, Vec Mx_) {
   return 0;
 }
 
-template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter) {
+template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter, bool use_abs_tol) {
   PetscInt N = b.Dim();
   if (max_iter < 0) max_iter = N;
   MPI_Comm comm = comm_.GetMPI_Comm();
@@ -100,7 +100,8 @@ template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>*
   // Set runtime options
   KSPSetType(ksp, KSPGMRES);
   KSPSetNormType(ksp, KSP_NORM_UNPRECONDITIONED);
-  KSPSetTolerances(ksp, tol, PETSC_DEFAULT, PETSC_DEFAULT, max_iter);
+  if (use_abs_tol) KSPSetTolerances(ksp, PETSC_DEFAULT, tol, PETSC_DEFAULT, max_iter);
+  else KSPSetTolerances(ksp, tol, PETSC_DEFAULT, PETSC_DEFAULT, max_iter);
   KSPGMRESSetOrthogonalization(ksp, KSPGMRESModifiedGramSchmidtOrthogonalization);
   //if (verbose_) KSPMonitorSet(ksp, KSPMonitorDefault, nullptr, nullptr); // Doesn't work for some versions of PETSc!! WTH!!
   KSPGMRESSetRestart(ksp, max_iter);
@@ -158,7 +159,7 @@ template <class Real> static Real inner_prod(const Vector<Real>& x, const Vector
   return x_dot_y_glb;
 }
 
-template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter) {
+template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>* x, const ParallelOp& A, const Vector<Real>& b, Real tol, Integer max_iter, bool use_abs_tol) {
   Long N = b.Dim();
   if (max_iter < 0) max_iter = N;
 
@@ -181,10 +182,11 @@ template <class Real> inline void ParallelSolver<Real>::operator()(Vector<Real>*
 
   Matrix<Real> H_;
   Vector<Real> Aq(N), y, h, r = b;
+  Real abs_tol = tol * (use_abs_tol ? 1 : b_norm);
   while (1) {
     Real r_norm = sqrt(inner_prod(r, r, comm_));
     if (verbose_ && !comm_.Rank()) printf("%3lld KSP Residual norm %.12e\n", (long long)H.Dim(), r_norm);
-    if (r_norm < tol * b_norm || H.Dim() == max_iter) break;
+    if (r_norm < abs_tol || H.Dim() == max_iter) break;
 
     A(&Aq, q);
     q = Aq;