fft_wrapper.hpp 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
/**
 * \file fft_wrapper.hpp
 * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
 * \date 2-11-2011
 * \brief Thin, precision-templated wrappers around the FFTW3 (double and float) functions.
 */
  7. #ifndef _PVFMM_FFT_WRAPPER_
  8. #define _PVFMM_FFT_WRAPPER_
  9. #ifdef __INTEL_OFFLOAD
  10. #pragma offload_attribute(push,target(mic))
  11. #endif
#include <fftw3.h>
#ifdef FFTW3_MKL
#include <omp.h>
#include <fftw3_mkl.h>
#endif
  16. namespace pvfmm{
  17. template<class T>
  18. struct FFTW_t{};
  19. #ifdef PVFMM_HAVE_FFTW
  20. template<>
  21. struct FFTW_t<double>{
  22. typedef fftw_plan plan;
  23. typedef fftw_complex cplx;
  24. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  25. double *in, const int *inembed, int istride, int idist,
  26. fftw_complex *out, const int *onembed, int ostride, int odist, unsigned flags){
  27. #ifdef FFTW3_MKL
  28. int omp_p0=omp_get_num_threads();
  29. int omp_p1=omp_get_max_threads();
  30. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  31. #endif
  32. return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  33. idist, out, onembed, ostride, odist, flags);
  34. }
  35. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  36. fftw_complex *in, const int *inembed, int istride, int idist,
  37. double *out, const int *onembed, int ostride, int odist, unsigned flags){
  38. #ifdef FFTW3_MKL
  39. int omp_p0=omp_get_num_threads();
  40. int omp_p1=omp_get_max_threads();
  41. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  42. #endif
  43. return fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  44. out, onembed, ostride, odist, flags);
  45. }
  46. static void fft_execute_dft_r2c(const fftw_plan p, double *in, fftw_complex *out){
  47. fftw_execute_dft_r2c(p, in, out);
  48. }
  49. static void fft_execute_dft_c2r(const fftw_plan p, fftw_complex *in, double *out){
  50. fftw_execute_dft_c2r(p, in, out);
  51. }
  52. static void fft_destroy_plan(fftw_plan plan){
  53. fftw_destroy_plan(plan);
  54. }
  55. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  56. ::fftw_flops(p, add, mul, fma);
  57. }
  58. };
  59. #endif
  60. #ifdef PVFMM_HAVE_FFTWF
  61. template<>
  62. struct FFTW_t<float>{
  63. typedef fftwf_plan plan;
  64. typedef fftwf_complex cplx;
  65. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  66. float *in, const int *inembed, int istride, int idist,
  67. fftwf_complex *out, const int *onembed, int ostride, int odist, unsigned flags){
  68. return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  69. idist, out, onembed, ostride, odist, flags);
  70. }
  71. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  72. fftwf_complex *in, const int *inembed, int istride, int idist,
  73. float *out, const int *onembed, int ostride, int odist, unsigned flags){
  74. return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  75. out, onembed, ostride, odist, flags);
  76. }
  77. static void fft_execute_dft_r2c(const fftwf_plan p, float *in, fftwf_complex *out){
  78. fftwf_execute_dft_r2c(p, in, out);
  79. }
  80. static void fft_execute_dft_c2r(const fftwf_plan p, fftwf_complex *in, float *out){
  81. fftwf_execute_dft_c2r(p, in, out);
  82. }
  83. static void fft_destroy_plan(fftwf_plan plan){
  84. fftwf_destroy_plan(plan);
  85. }
  86. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  87. ::fftwf_flops(p, add, mul, fma);
  88. }
  89. };
  90. #endif
  91. }//end namespace
  92. #ifdef __INTEL_OFFLOAD
  93. #pragma offload_attribute(pop)
  94. #endif
  95. #endif //_PVFMM_FFT_WRAPPER_