Dhairya Malhotra 7 éve
szülő
commit
b638e4b7d6
2 módosított fájl, 298 hozzáadás és 32 törlés
  1. 7 1
      Makefile
  2. 291 31
      include/sctl/fft_wrapper.hpp

+ 7 - 1
Makefile

@@ -14,11 +14,17 @@ endif
 
 CXXFLAGS += -DSCTL_MEMDEBUG # Enable memory checks
 
+CXXFLAGS += -DSCTL_QUAD_T=__float128 -Wfloat-conversion
+
+#CXXFLAGS += -DSCTL_HAVE_MPI #use MPI
+
 CXXFLAGS += -lblas -DSCTL_HAVE_BLAS # use BLAS
 CXXFLAGS += -llapack -DSCTL_HAVE_LAPACK # use LAPACK
 #CXXFLAGS += -mkl -DSCTL_HAVE_BLAS -DSCTL_HAVE_LAPACK # use MKL BLAS and LAPACK
 
-#CXXFLAGS += -DSCTL_HAVE_MPI #use MPI
+CXXFLAGS += -lfftw3 -DSCTL_HAVE_FFTW
+CXXFLAGS += -lfftw3f -DSCTL_HAVE_FFTWF
+CXXFLAGS += -lfftw3l -DSCTL_HAVE_FFTWL
 
 
 RM = rm -f

+ 291 - 31
include/sctl/fft_wrapper.hpp

@@ -5,6 +5,12 @@
 #include <cassert>
 #include <cstdlib>
 #include <vector>
+#if defined(SCTL_HAVE_FFTW) || defined(SCTL_HAVE_FFTWF) || defined(SCTL_HAVE_FFTWL)  // fftw3.h declares all three precisions
+#include <fftw3.h>
+#ifdef SCTL_FFTW3_MKL
+#include <fftw3_mkl.h>
+#endif
+#endif
 
 #include SCTL_INCLUDE(common.hpp)
 #include SCTL_INCLUDE(mem_mgr.hpp)
@@ -12,7 +18,6 @@
 
 namespace SCTL_NAMESPACE {
 
-
 template <class ValueType> class Complex {
   public:
 
@@ -83,26 +88,33 @@ template <class ValueType> Complex<ValueType> operator-(const ValueType& x, cons
   return z;
 }
 
-
-
 enum class FFT_Type {R2C, C2C, C2C_INV, C2R};
 
-template <class ValueType> class FFT {
+template <class ValueType, class FFT_Derived> class FFT_Generic {
 
   typedef Complex<ValueType> ComplexType;
 
   struct FFTPlan {
     std::vector<Matrix<ValueType>> M;
-    FFT_Type fft_type;
-    Long howmany;
   };
 
  public:
 
-  void Setup(FFT_Type fft_type, Long howmany, const Vector<Long>& dim_vec) {
+  FFT_Generic() {
+    dim[0] = 0;
+    dim[1] = 0;
+  }
+  FFT_Generic(const FFT_Generic&) = delete;
+  FFT_Generic& operator=(const FFT_Generic&) = delete;
+
+  Long Dim(Integer i) const {
+    return dim[i];
+  }
+
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
     Long rank = dim_vec.Dim();
-    plan.fft_type = fft_type;
-    plan.howmany = howmany;
+    fft_type = fft_type_;
+    howmany = howmany_;
     plan.M.resize(0);
 
     if (fft_type == FFT_Type::R2C) {
@@ -123,17 +135,11 @@ template <class ValueType> class FFT {
       N0 = N0 * M.Dim(0) / 2;
       N1 = N1 * M.Dim(1) / 2;
     }
-  }
-
-  Long Dim(Integer i) const {
-    Long N = plan.howmany * 2;
-    for (const auto M : plan.M) N = N * M.Dim(i) / 2;
-    return N;
+    dim[0] = N0;
+    dim[1] = N1;
   }
 
   void Execute(const Vector<ValueType>& in, Vector<ValueType>& out) const {
-
-    Long howmany = plan.howmany;
     Long N0 = Dim(0);
     Long N1 = Dim(1);
     SCTL_ASSERT_MSG(in.Dim() == N0, "FFT: Wrong input size.");
@@ -145,7 +151,7 @@ template <class ValueType> class FFT {
     if (rank <= 0) return;
     Long N = N0;
 
-    if (plan.fft_type == FFT_Type::C2R) {
+    if (fft_type == FFT_Type::C2R) {
       const Matrix<ValueType>& M = plan.M[rank - 1];
       transpose<ComplexType>(buff0.begin(), in.begin(), N / M.Dim(0), M.Dim(0) / 2);
 
@@ -181,11 +187,12 @@ template <class ValueType> class FFT {
     fft_dim.PushBack(2);
     fft_dim.PushBack(5);
     fft_dim.PushBack(3);
+    Long howmany = 3;
 
     if (1){ // R2C, C2R
-      FFT<ValueType> myfft0, myfft1;
-      myfft0.Setup(FFT_Type::R2C, 1, fft_dim);
-      myfft1.Setup(FFT_Type::C2R, 1, fft_dim);
+      FFT_Derived myfft0, myfft1;
+      myfft0.Setup(FFT_Type::R2C, howmany, fft_dim);
+      myfft1.Setup(FFT_Type::C2R, howmany, fft_dim);
       Vector<ValueType> v0(myfft0.Dim(0)), v1, v2;
       for (int i = 0; i < v0.Dim(); i++) v0[i] = 1 + i;
       myfft0.Execute(v0, v1);
@@ -193,15 +200,15 @@ template <class ValueType> class FFT {
       { // Print error
         ValueType err = 0;
         SCTL_ASSERT(v0.Dim() == v2.Dim());
-        for (Long i=0;i<v0.Dim();i++) err = std::max(err, fabs(v0[i] - v2[i]));
+        for (Long i = 0; i < v0.Dim(); i++) err = std::max(err, fabs(v0[i] - v2[i]));
         std::cout<<"Error : "<<err<<'\n';
       }
     }
     std::cout<<'\n';
     { // C2C, C2C_INV
-      FFT<ValueType> myfft0, myfft1;
-      myfft0.Setup(FFT_Type::C2C, 1, fft_dim);
-      myfft1.Setup(FFT_Type::C2C_INV, 1, fft_dim);
+      FFT_Derived myfft0, myfft1;
+      myfft0.Setup(FFT_Type::C2C, howmany, fft_dim);
+      myfft1.Setup(FFT_Type::C2C_INV, howmany, fft_dim);
       Vector<ValueType> v0(myfft0.Dim(0)), v1, v2;
       for (int i = 0; i < v0.Dim(); i++) v0[i] = 1 + i;
       myfft0.Execute(v0, v1);
@@ -209,13 +216,14 @@ template <class ValueType> class FFT {
       { // Print error
         ValueType err = 0;
         SCTL_ASSERT(v0.Dim() == v2.Dim());
-        for (Long i=0;i<v0.Dim();i++) err = std::max(err, fabs(v0[i] - v2[i]));
+        for (Long i = 0; i < v0.Dim(); i++) err = std::max(err, fabs(v0[i] - v2[i]));
         std::cout<<"Error : "<<err<<'\n';
       }
     }
+    std::cout<<'\n';
   }
 
- private:
+ protected:
 
   static Matrix<ValueType> fft_r2c(Long N0) {
     ValueType s = 1 / sqrt<ValueType>(N0);
@@ -223,8 +231,8 @@ template <class ValueType> class FFT {
     Matrix<ValueType> M(N0, 2 * N1);
     for (Long j = 0; j < N0; j++)
       for (Long i = 0; i < N1; i++) {
-        M[j][2 * i + 0] = cos<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
-        M[j][2 * i + 1] = sin<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
+        M[j][2 * i + 0] =  cos<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
+        M[j][2 * i + 1] = -sin<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
       }
     return M;
   }
@@ -248,8 +256,8 @@ template <class ValueType> class FFT {
     Matrix<ValueType> M(2 * N1, N0);
     for (Long i = 0; i < N1; i++) {
       for (Long j = 0; j < N0; j++) {
-        M[2 * i + 0][j] = 2 * cos<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
-        M[2 * i + 1][j] = 2 * sin<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
+        M[2 * i + 0][j] =  2 * cos<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
+        M[2 * i + 1][j] = -2 * sin<ValueType>(2 * const_pi<ValueType>() * j * i / N0)*s;
       }
     }
     if (N1 > 0) {
@@ -273,10 +281,262 @@ template <class ValueType> class FFT {
     M1 = M0.Transpose();
   }
 
+  StaticArray<Long,2> dim;
+  FFT_Type fft_type;
+  Long howmany;
   FFTPlan plan;
 };
 
+template <class ValueType> class FFT : public FFT_Generic<ValueType, FFT<ValueType>> {};  // Default FFT: adds nothing to FFT_Generic; used for any ValueType without an explicit specialization
+
+#ifdef SCTL_HAVE_FFTW
+template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
+  // Double-precision FFT: hides the generic Setup()/Execute() with versions that call FFTW.
+  typedef double ValueType;
+
+ public:
+  // Setup() creates a plan only when Dim(0) * Dim(1) is non-zero, so destroy under the same condition.
+  ~FFT() { if (this->Dim(0) * this->Dim(1)) fftw_destroy_plan(plan); }
+  // Plans howmany_ transforms of shape dim_vec; dim[0]/dim[1] receive the input/output lengths counted in ValueType entries.
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
+    if (Dim(0) * Dim(1)) fftw_destroy_plan(plan);  // release the plan of a previous Setup()
+    Long rank = dim_vec.Dim();
+    this->fft_type = fft_type_;
+    this->howmany = howmany_;
+    Long N0 = 0, N1 = 0;  // stay zero (no plan) if fft_type matches no branch below
+    { // Set N0, N1
+      Long N = howmany;
+      for (auto ni : dim_vec) N *= ni;
+      if (fft_type == FFT_Type::R2C) {
+        N0 = N;
+        N1 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
+      } else if (fft_type == FFT_Type::C2C) {
+        N0 = N * 2;
+        N1 = N * 2;
+      } else if (fft_type == FFT_Type::C2C_INV) {
+        N0 = N * 2;
+        N1 = N * 2;
+      } else if (fft_type == FFT_Type::C2R) {
+        N0 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
+        N1 = N;
+      }
+      this->dim[0] = N0;
+      this->dim[1] = N1;
+    }
+    if (!(N0 * N1)) return;  // empty transform: skip planning (was `!N0 * N1`, which parses as `(!N0) * N1`)
+
+    in .ReInit(N0);
+    out.ReInit(N1);
+    Vector<int> dim_vec_(rank);
+    for (Integer i = 0; i < rank; i++) dim_vec_[i] = dim_vec[i];
+
+    if (fft_type == FFT_Type::R2C) {
+      plan = fftw_plan_many_dft_r2c(rank, &dim_vec_[0], howmany_, (double*)&in[0], NULL, 1, N0 / howmany, (fftw_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2C) {
+      plan = fftw_plan_many_dft(rank, &dim_vec_[0], howmany_, (fftw_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (fftw_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_FORWARD, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2C_INV) {
+      plan = fftw_plan_many_dft(rank, &dim_vec_[0], howmany_, (fftw_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (fftw_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_BACKWARD, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2R) {
+      plan = fftw_plan_many_dft_c2r(rank, &dim_vec_[0], howmany_, (fftw_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (double*)&out[0], NULL, 1, N1 / howmany, FFTW_ESTIMATE);
+    }
+  }
+  // Runs the planned transform on `in` (length Dim(0)), resizes `out` to Dim(1) if needed, and scales it by s = 1/sqrt(points per transform).
+  void Execute(const Vector<ValueType>& in, Vector<ValueType>& out) const {
+    Long N0 = this->Dim(0);
+    Long N1 = this->Dim(1);
+    if (!(N0 * N1)) return;  // no plan exists for an empty transform (was `!N0 * N1`)
+    SCTL_ASSERT_MSG(in.Dim() == N0, "FFT: Wrong input size.");
+    if (out.Dim() != N1) out.ReInit(N1);
+
+    ValueType s = 0;
+    if (fft_type == FFT_Type::R2C) {
+      s = 1 / sqrt<ValueType>(N0 / howmany);
+      fftw_execute_dft_r2c(plan, (double*)&in[0], (fftw_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2C) {
+      s = 1 / sqrt<ValueType>(N0 / howmany * (ValueType)0.5);
+      fftw_execute_dft(plan, (fftw_complex*)&in[0], (fftw_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2C_INV) {
+      s = 1 / sqrt<ValueType>(N1 / howmany * (ValueType)0.5);
+      fftw_execute_dft(plan, (fftw_complex*)&in[0], (fftw_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2R) {
+      s = 1 / sqrt<ValueType>(N1 / howmany);
+      fftw_execute_dft_c2r(plan, (fftw_complex*)&in[0], (double*)&out[0]);  // NOTE(review): FFTW c2r may overwrite its input although `in` is const here - confirm callers tolerate this
+    }
+    for (auto& x : out) x *= s;
+  }
+
+ private:
+
+  Vector<ValueType> in, out;  // buffers handed to the planner in Setup(); Execute() uses its own arguments instead
+  fftw_plan plan = nullptr;  // a live plan only while Dim(0) * Dim(1) != 0
+};
+#endif
+
+#ifdef SCTL_HAVE_FFTWF
+template <> class FFT<float> : public FFT_Generic<float, FFT<float>> {
+  // Single-precision FFT: hides the generic Setup()/Execute() with versions that call FFTW (fftwf_*).
+  typedef float ValueType;
+
+ public:
+  // Setup() creates a plan only when Dim(0) * Dim(1) is non-zero, so destroy under the same condition.
+  ~FFT() { if (this->Dim(0) * this->Dim(1)) fftwf_destroy_plan(plan); }
+  // Plans howmany_ transforms of shape dim_vec; dim[0]/dim[1] receive the input/output lengths counted in ValueType entries.
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
+    if (Dim(0) * Dim(1)) fftwf_destroy_plan(plan);  // release the plan of a previous Setup()
+    Long rank = dim_vec.Dim();
+    this->fft_type = fft_type_;
+    this->howmany = howmany_;
+    Long N0 = 0, N1 = 0;  // stay zero (no plan) if fft_type matches no branch below
+    { // Set N0, N1
+      Long N = howmany;
+      for (auto ni : dim_vec) N *= ni;
+      if (fft_type == FFT_Type::R2C) {
+        N0 = N;
+        N1 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
+      } else if (fft_type == FFT_Type::C2C) {
+        N0 = N * 2;
+        N1 = N * 2;
+      } else if (fft_type == FFT_Type::C2C_INV) {
+        N0 = N * 2;
+        N1 = N * 2;
+      } else if (fft_type == FFT_Type::C2R) {
+        N0 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
+        N1 = N;
+      }
+      this->dim[0] = N0;
+      this->dim[1] = N1;
+    }
+    if (!(N0 * N1)) return;  // empty transform: skip planning (was `!N0 * N1`, which parses as `(!N0) * N1`)
 
+    in .ReInit(N0);
+    out.ReInit(N1);
+    Vector<int> dim_vec_(rank);
+    for (Integer i = 0; i < rank; i++) dim_vec_[i] = dim_vec[i];
+
+    if (fft_type == FFT_Type::R2C) {
+      plan = fftwf_plan_many_dft_r2c(rank, &dim_vec_[0], howmany_, (float*)&in[0], NULL, 1, N0 / howmany, (fftwf_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2C) {
+      plan = fftwf_plan_many_dft(rank, &dim_vec_[0], howmany_, (fftwf_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (fftwf_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_FORWARD, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2C_INV) {
+      plan = fftwf_plan_many_dft(rank, &dim_vec_[0], howmany_, (fftwf_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (fftwf_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_BACKWARD, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2R) {
+      plan = fftwf_plan_many_dft_c2r(rank, &dim_vec_[0], howmany_, (fftwf_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (float*)&out[0], NULL, 1, N1 / howmany, FFTW_ESTIMATE);
+    }
+  }
+  // Runs the planned transform on `in` (length Dim(0)), resizes `out` to Dim(1) if needed, and scales it by s = 1/sqrt(points per transform).
+  void Execute(const Vector<ValueType>& in, Vector<ValueType>& out) const {
+    Long N0 = this->Dim(0);
+    Long N1 = this->Dim(1);
+    if (!(N0 * N1)) return;  // no plan exists for an empty transform (was `!N0 * N1`)
+    SCTL_ASSERT_MSG(in.Dim() == N0, "FFT: Wrong input size.");
+    if (out.Dim() != N1) out.ReInit(N1);
+
+    ValueType s = 0;
+    if (fft_type == FFT_Type::R2C) {
+      s = 1 / sqrt<ValueType>(N0 / howmany);
+      fftwf_execute_dft_r2c(plan, (float*)&in[0], (fftwf_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2C) {
+      s = 1 / sqrt<ValueType>(N0 / howmany * (ValueType)0.5);
+      fftwf_execute_dft(plan, (fftwf_complex*)&in[0], (fftwf_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2C_INV) {
+      s = 1 / sqrt<ValueType>(N1 / howmany * (ValueType)0.5);
+      fftwf_execute_dft(plan, (fftwf_complex*)&in[0], (fftwf_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2R) {
+      s = 1 / sqrt<ValueType>(N1 / howmany);
+      fftwf_execute_dft_c2r(plan, (fftwf_complex*)&in[0], (float*)&out[0]);  // NOTE(review): FFTW c2r may overwrite its input although `in` is const here - confirm callers tolerate this
+    }
+    for (auto& x : out) x *= s;
+  }
+
+ private:
+
+  Vector<ValueType> in, out;  // buffers handed to the planner in Setup(); Execute() uses its own arguments instead
+  fftwf_plan plan = nullptr;  // a live plan only while Dim(0) * Dim(1) != 0
+};
+#endif
+
+#ifdef SCTL_HAVE_FFTWL
+template <> class FFT<long double> : public FFT_Generic<long double, FFT<long double>> {
+  // Long-double FFT: hides the generic Setup()/Execute() with versions that call FFTW (fftwl_*).
+  typedef long double ValueType;
+
+ public:
+  // Setup() creates a plan only when Dim(0) * Dim(1) is non-zero, so destroy under the same condition.
+  ~FFT() { if (this->Dim(0) * this->Dim(1)) fftwl_destroy_plan(plan); }
+  // Plans howmany_ transforms of shape dim_vec; dim[0]/dim[1] receive the input/output lengths counted in ValueType entries.
+  void Setup(FFT_Type fft_type_, Long howmany_, const Vector<Long>& dim_vec) {
+    if (Dim(0) * Dim(1)) fftwl_destroy_plan(plan);  // release the plan of a previous Setup()
+    Long rank = dim_vec.Dim();
+    this->fft_type = fft_type_;
+    this->howmany = howmany_;
+    Long N0 = 0, N1 = 0;  // stay zero (no plan) if fft_type matches no branch below
+    { // Set N0, N1
+      Long N = howmany;
+      for (auto ni : dim_vec) N *= ni;
+      if (fft_type == FFT_Type::R2C) {
+        N0 = N;
+        N1 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
+      } else if (fft_type == FFT_Type::C2C) {
+        N0 = N * 2;
+        N1 = N * 2;
+      } else if (fft_type == FFT_Type::C2C_INV) {
+        N0 = N * 2;
+        N1 = N * 2;
+      } else if (fft_type == FFT_Type::C2R) {
+        N0 = (N / dim_vec[rank - 1]) * (dim_vec[rank - 1] / 2 + 1) * 2;
+        N1 = N;
+      }
+      this->dim[0] = N0;
+      this->dim[1] = N1;
+    }
+    if (!(N0 * N1)) return;  // empty transform: skip planning (was `!N0 * N1`, which parses as `(!N0) * N1`)
+
+    in .ReInit(N0);
+    out.ReInit(N1);
+    Vector<int> dim_vec_(rank);
+    for (Integer i = 0; i < rank; i++) dim_vec_[i] = dim_vec[i];
+
+    if (fft_type == FFT_Type::R2C) {
+      plan = fftwl_plan_many_dft_r2c(rank, &dim_vec_[0], howmany_, (long double*)&in[0], NULL, 1, N0 / howmany, (fftwl_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2C) {
+      plan = fftwl_plan_many_dft(rank, &dim_vec_[0], howmany_, (fftwl_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (fftwl_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_FORWARD, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2C_INV) {
+      plan = fftwl_plan_many_dft(rank, &dim_vec_[0], howmany_, (fftwl_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (fftwl_complex*)&out[0], NULL, 1, N1 / 2 / howmany, FFTW_BACKWARD, FFTW_ESTIMATE);
+    } else if (fft_type == FFT_Type::C2R) {
+      plan = fftwl_plan_many_dft_c2r(rank, &dim_vec_[0], howmany_, (fftwl_complex*)&in[0], NULL, 1, N0 / 2 / howmany, (long double*)&out[0], NULL, 1, N1 / howmany, FFTW_ESTIMATE);
+    }
+  }
+  // Runs the planned transform on `in` (length Dim(0)), resizes `out` to Dim(1) if needed, and scales it by s = 1/sqrt(points per transform).
+  void Execute(const Vector<ValueType>& in, Vector<ValueType>& out) const {
+    Long N0 = this->Dim(0);
+    Long N1 = this->Dim(1);
+    if (!(N0 * N1)) return;  // no plan exists for an empty transform (was `!N0 * N1`)
+    SCTL_ASSERT_MSG(in.Dim() == N0, "FFT: Wrong input size.");
+    if (out.Dim() != N1) out.ReInit(N1);
+
+    ValueType s = 0;
+    if (fft_type == FFT_Type::R2C) {
+      s = 1 / sqrt<ValueType>(N0 / howmany);
+      fftwl_execute_dft_r2c(plan, (long double*)&in[0], (fftwl_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2C) {
+      s = 1 / sqrt<ValueType>(N0 / howmany * (ValueType)0.5);
+      fftwl_execute_dft(plan, (fftwl_complex*)&in[0], (fftwl_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2C_INV) {
+      s = 1 / sqrt<ValueType>(N1 / howmany * (ValueType)0.5);
+      fftwl_execute_dft(plan, (fftwl_complex*)&in[0], (fftwl_complex*)&out[0]);
+    } else if (fft_type == FFT_Type::C2R) {
+      s = 1 / sqrt<ValueType>(N1 / howmany);
+      fftwl_execute_dft_c2r(plan, (fftwl_complex*)&in[0], (long double*)&out[0]);  // NOTE(review): FFTW c2r may overwrite its input although `in` is const here - confirm callers tolerate this
+    }
+    for (auto& x : out) x *= s;
+  }
+
+ private:
+
+  Vector<ValueType> in, out;  // buffers handed to the planner in Setup(); Execute() uses its own arguments instead
+  fftwl_plan plan = nullptr;  // a live plan only while Dim(0) * Dim(1) != 0
+};
+#endif
 
 }  // end namespace