Dhairya Malhotra 7 سال پیش
والد
کامیت
caff7ac106
2فایلهای تغییر یافته به همراه27 افزوده شده و 21 حذف شده
  1. 26 20
      include/sctl/fft_wrapper.hpp
  2. 1 1
      include/sctl/intrin_wrapper.hpp

+ 26 - 20
include/sctl/fft_wrapper.hpp

@@ -154,6 +154,12 @@ template <class ScalarType, class ValueType> Complex<ValueType> operator/(const
   return z;
 }
 
+template <class ValueType> std::ostream& operator<<(std::ostream& output, const Complex<ValueType>& V) {
+  output << "(" << V.real <<"," << V.imag << ")";
+  return output;
+}
+
+
 enum class FFT_Type {R2C, C2C, C2C_INV, C2R};
 
 template <class ValueType, class FFT_Derived> class FFT_Generic {
@@ -358,8 +364,8 @@ template <class ValueType, class FFT_Derived> class FFT_Generic {
   }
 
   static void check_align(const Vector<ValueType>& in, const Vector<ValueType>& out) {
-    //SCTL_ASSERT_MSG((((uintptr_t)& in[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Input vector not aligned to " <<SCTL_MEM_ALIGN<<" bits!");
-    //SCTL_ASSERT_MSG((((uintptr_t)&out[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Output vector not aligned to "<<SCTL_MEM_ALIGN<<" bits!");
+    SCTL_ASSERT_MSG((((uintptr_t)& in[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Input vector not aligned to " <<SCTL_MEM_ALIGN<<" bytes!");
+    SCTL_ASSERT_MSG((((uintptr_t)&out[0]) & ((uintptr_t)(SCTL_MEM_ALIGN - 1))) == 0, "sctl::FFT: Output vector not aligned to "<<SCTL_MEM_ALIGN<<" bytes!");
     // TODO: copy to auxiliary array if unaligned
   }
 
@@ -384,6 +390,7 @@ template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
     if (Dim(0) && Dim(1)) fftw_destroy_plan(plan);
     this->fft_type = fft_type_;
     this->howmany = howmany_;
+    copy_input = false;
     plan = NULL;
 
     Long rank = dim_vec.Dim();
@@ -437,9 +444,7 @@ template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
       } else if (fft_type == FFT_Type::C2R) {
         plan = fftw_plan_many_dft_c2r(rank, &dim_vec_[0], howmany_, (fftw_complex*)&in[0], NULL, 1, N0 / 2 / howmany, &out[0], NULL, 1, N1 / howmany, FFTW_ESTIMATE);
       }
-      tmp.Swap(in);
-    } else {
-      tmp.ReInit(0);
+      copy_input = true;
     }
     SCTL_ASSERT(plan);
   }
@@ -453,9 +458,10 @@ template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
     check_align(in, out);
 
     ValueType s = 0;
+    Vector<ValueType> tmp;
     auto in_ptr = in.begin();
-    if (tmp.Dim()) { // Save input
-      assert(tmp.Dim() == N0);
+    if (copy_input) { // Save input
+      tmp.ReInit(N0);
       in_ptr = tmp.begin();
       tmp = in;
     }
@@ -477,7 +483,7 @@ template <> class FFT<double> : public FFT_Generic<double, FFT<double>> {
 
  private:
 
-  mutable Vector<ValueType> tmp;
+  bool copy_input;
   fftw_plan plan;
 };
 #endif
@@ -495,6 +501,7 @@ template <> class FFT<float> : public FFT_Generic<float, FFT<float>> {
     if (Dim(0) && Dim(1)) fftwf_destroy_plan(plan);
     this->fft_type = fft_type_;
     this->howmany = howmany_;
+    copy_input = false;
     plan = NULL;
 
     Long rank = dim_vec.Dim();
@@ -548,9 +555,7 @@ template <> class FFT<float> : public FFT_Generic<float, FFT<float>> {
       } else if (fft_type == FFT_Type::C2R) {
         plan = fftwf_plan_many_dft_c2r(rank, &dim_vec_[0], howmany_, (fftwf_complex*)&in[0], NULL, 1, N0 / 2 / howmany, &out[0], NULL, 1, N1 / howmany, FFTW_ESTIMATE);
       }
-      tmp.Swap(in);
-    } else {
-      tmp.ReInit(0);
+      copy_input = true;
     }
     SCTL_ASSERT(plan);
   }
@@ -564,9 +569,10 @@ template <> class FFT<float> : public FFT_Generic<float, FFT<float>> {
     check_align(in, out);
 
     ValueType s = 0;
+    Vector<ValueType> tmp;
     auto in_ptr = in.begin();
-    if (tmp.Dim()) { // Save input
-      assert(tmp.Dim() == N0);
+    if (copy_input) { // Save input
+      tmp.ReInit(N0);
       in_ptr = tmp.begin();
       tmp = in;
     }
@@ -588,7 +594,7 @@ template <> class FFT<float> : public FFT_Generic<float, FFT<float>> {
 
  private:
 
-  mutable Vector<ValueType> tmp;
+  bool copy_input;
   fftwf_plan plan;
 };
 #endif
@@ -606,6 +612,7 @@ template <> class FFT<long double> : public FFT_Generic<long double, FFT<long do
     if (Dim(0) && Dim(1)) fftwl_destroy_plan(plan);
     this->fft_type = fft_type_;
     this->howmany = howmany_;
+    copy_input = false;
     plan = NULL;
 
     Long rank = dim_vec.Dim();
@@ -657,9 +664,7 @@ template <> class FFT<long double> : public FFT_Generic<long double, FFT<long do
       } else if (fft_type == FFT_Type::C2R) {
         plan = fftwl_plan_many_dft_c2r(rank, &dim_vec_[0], howmany_, (fftwl_complex*)&in[0], NULL, 1, N0 / 2 / howmany, &out[0], NULL, 1, N1 / howmany, FFTW_ESTIMATE);
       }
-      tmp.Swap(in);
-    } else {
-      tmp.ReInit(0);
+      copy_input = true;
     }
     SCTL_ASSERT(plan);
   }
@@ -673,9 +678,10 @@ template <> class FFT<long double> : public FFT_Generic<long double, FFT<long do
     check_align(in, out);
 
     ValueType s = 0;
+    Vector<ValueType> tmp;
     auto in_ptr = in.begin();
-    if (tmp.Dim()) { // Save input
-      assert(tmp.Dim() == N0);
+    if (copy_input) { // Save input
+      tmp.ReInit(N0);
       in_ptr = tmp.begin();
       tmp = in;
     }
@@ -697,7 +703,7 @@ template <> class FFT<long double> : public FFT_Generic<long double, FFT<long do
 
  private:
 
-  mutable Vector<ValueType> tmp;
+  bool copy_input;
   fftwl_plan plan;
 };
 #endif

+ 1 - 1
include/sctl/intrin_wrapper.hpp

@@ -21,7 +21,7 @@
 #include <immintrin.h>
 #endif
 
-// TODO: Check alignment which SCTL_MEMDEBUG is defined
+// TODO: Check alignment when SCTL_MEMDEBUG is defined
 // TODO: Replace pointers with iterators
 
 namespace SCTL_NAMESPACE {