fft_wrapper.hpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. /**
  2. * \file fft_wrapper.hpp
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains FFTW3 wrapper functions.
  6. */
  7. #include <cmath>
  8. #include <cassert>
  9. #include <cstdlib>
  10. #include <vector>
  11. #if defined(PVFMM_HAVE_FFTW) || defined(PVFMM_HAVE_FFTWF)
  12. #include <fftw3.h>
  13. #ifdef FFTW3_MKL
  14. #include <fftw3_mkl.h>
  15. #endif
  16. #endif
  17. #include <pvfmm_common.hpp>
  18. #include <mem_mgr.hpp>
  19. #include <matrix.hpp>
  20. #ifndef _PVFMM_FFT_WRAPPER_
  21. #define _PVFMM_FFT_WRAPPER_
  22. namespace pvfmm{
  23. template<class T>
  24. struct FFTW_t{
  25. struct plan{
  26. std::vector<size_t> dim;
  27. std::vector<Matrix<T> > M;
  28. size_t howmany;
  29. };
  30. struct cplx{
  31. T real;
  32. T imag;
  33. };
  34. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  35. T *in, const int *inembed, int istride, int idist,
  36. cplx *out, const int *onembed, int ostride, int odist){
  37. assert(inembed==NULL);
  38. assert(onembed==NULL);
  39. assert(istride==1);
  40. assert(ostride==1);
  41. plan p;
  42. p.howmany=howmany;
  43. { // r2c
  44. p.dim.push_back(n[rank-1]);
  45. p.M.push_back(fft_r2c(n[rank-1]));
  46. }
  47. for(int i=rank-2;i>=0;i--){ // c2c
  48. p.dim.push_back(n[i]);
  49. p.M.push_back(fft_c2c(n[i]));
  50. }
  51. size_t N1=1, N2=1;
  52. for(size_t i=0;i<p.dim.size();i++){
  53. N1*=p.dim[i];
  54. N2*=p.M[i].Dim(1)/2;
  55. }
  56. assert(idist==N1);
  57. assert(odist==N2);
  58. return p;
  59. }
  60. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  61. cplx *in, const int *inembed, int istride, int idist,
  62. T *out, const int *onembed, int ostride, int odist){
  63. assert(inembed==NULL);
  64. assert(onembed==NULL);
  65. assert(istride==1);
  66. assert(ostride==1);
  67. plan p;
  68. p.howmany=howmany;
  69. for(size_t i=0;i<rank-1;i++){ // c2c
  70. p.dim.push_back(n[i]);
  71. p.M.push_back(fft_c2c(n[i]));
  72. }
  73. { // c2r
  74. p.dim.push_back(n[rank-1]);
  75. p.M.push_back(fft_c2r(n[rank-1]));
  76. }
  77. size_t N1=1, N2=1;
  78. for(size_t i=0;i<p.dim.size();i++){
  79. N1*=p.dim[i];
  80. N2*=p.M[i].Dim(0)/2;
  81. }
  82. assert(idist==N2);
  83. assert(odist==N1);
  84. return p;
  85. }
  86. static void fft_execute_dft_r2c(const plan p, T *in, cplx *out){
  87. size_t N1=p.howmany, N2=p.howmany;
  88. for(size_t i=0;i<p.dim.size();i++){
  89. N1*=p.dim[i];
  90. N2*=p.M[i].Dim(1)/2;
  91. }
  92. std::vector<T> buff_(N1+2*N2);
  93. T* buff=&buff_[0];
  94. { // r2c
  95. size_t i=0;
  96. const Matrix<T>& M=p.M[i];
  97. assert(2*N2/M.Dim(1)==N1/M.Dim(0));
  98. Matrix<T> x( N1/M.Dim(0),M.Dim(0), in,false);
  99. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  100. Matrix<T>::GEMM(y, x, M);
  101. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  102. }
  103. for(size_t i=1;i<p.dim.size();i++){ // c2c
  104. const Matrix<T>& M=p.M[i];
  105. assert(M.Dim(0)==M.Dim(1));
  106. Matrix<T> x(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  107. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  108. Matrix<T>::GEMM(y, x, M);
  109. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  110. }
  111. { // howmany
  112. transpose<cplx>(N2/p.howmany, p.howmany, (cplx*)buff);
  113. mem::memcopy(out,buff,2*N2*sizeof(T));
  114. }
  115. }
  116. static void fft_execute_dft_c2r(const plan p, cplx *in, T *out){
  117. size_t N1=p.howmany, N2=p.howmany;
  118. for(size_t i=0;i<p.dim.size();i++){
  119. N1*=p.dim[i];
  120. N2*=p.M[i].Dim(0)/2;
  121. }
  122. std::vector<T> buff_(N1+2*N2);
  123. T* buff=&buff_[0];
  124. { // howmany
  125. mem::memcopy(buff,in,2*N2*sizeof(T));
  126. transpose<cplx>(p.howmany, N2/p.howmany, (cplx*)buff);
  127. }
  128. for(size_t i=0;i<p.dim.size()-1;i++){ // c2c
  129. Matrix<T> M=p.M[i];
  130. assert(M.Dim(0)==M.Dim(1));
  131. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  132. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  133. Matrix<T> x(2*N2/M.Dim(1),M.Dim(1),buff,false);
  134. Matrix<T>::GEMM(x, y, M.Transpose());
  135. }
  136. { // r2c
  137. size_t i=p.dim.size()-1;
  138. const Matrix<T>& M=p.M[i];
  139. assert(2*N2/M.Dim(0)==N1/M.Dim(1));
  140. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  141. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff,false);
  142. Matrix<T> x( N1/M.Dim(1),M.Dim(1), out,false);
  143. Matrix<T>::GEMM(x, y, M);
  144. }
  145. }
  146. static void fft_destroy_plan(plan p){
  147. p.dim.clear();
  148. p.M.clear();
  149. p.howmany=0;
  150. }
  151. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  152. *add=0;
  153. *mul=0;
  154. *fma=0;
  155. }
  156. private:
  157. static Matrix<T> fft_r2c(size_t N1){
  158. size_t N2=(N1/2+1);
  159. Matrix<T> M(N1,2*N2);
  160. for(size_t j=0;j<N1;j++)
  161. for(size_t i=0;i<N2;i++){
  162. M[j][2*i+0]=pvfmm::cos<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  163. M[j][2*i+1]=pvfmm::sin<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  164. }
  165. return M;
  166. }
  167. static Matrix<T> fft_c2c(size_t N1){
  168. Matrix<T> M(2*N1,2*N1);
  169. for(size_t i=0;i<N1;i++)
  170. for(size_t j=0;j<N1;j++){
  171. M[2*i+0][2*j+0]=pvfmm::cos<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  172. M[2*i+1][2*j+0]=pvfmm::sin<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  173. M[2*i+0][2*j+1]=-pvfmm::sin<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  174. M[2*i+1][2*j+1]= pvfmm::cos<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  175. }
  176. return M;
  177. }
  178. static Matrix<T> fft_c2r(size_t N1){
  179. size_t N2=(N1/2+1);
  180. Matrix<T> M(2*N2,N1);
  181. for(size_t i=0;i<N2;i++)
  182. for(size_t j=0;j<N1;j++){
  183. M[2*i+0][j]=2*pvfmm::cos<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  184. M[2*i+1][j]=2*pvfmm::sin<T>(j*i*(1.0/N1)*2.0*const_pi<T>());
  185. }
  186. if(N2>0){
  187. for(size_t j=0;j<N1;j++){
  188. M[0][j]=M[0][j]*0.5;
  189. M[1][j]=M[1][j]*0.5;
  190. }
  191. }
  192. if(N1%2==0){
  193. for(size_t j=0;j<N1;j++){
  194. M[2*N2-2][j]=M[2*N2-2][j]*0.5;
  195. M[2*N2-1][j]=M[2*N2-1][j]*0.5;
  196. }
  197. }
  198. return M;
  199. }
  200. template <class Y>
  201. static void transpose(size_t dim1, size_t dim2, Y* A){
  202. Matrix<Y> M(dim1, dim2, A);
  203. Matrix<Y> Mt(dim2, dim1, A, false);
  204. Mt=M.Transpose();
  205. }
  206. };
  #ifdef PVFMM_HAVE_FFTW
  /**
   * \brief Double-precision specialization: a thin forwarding layer over the
   * FFTW3 fftw_* API. All plans are created with the FFTW_ESTIMATE flag.
   */
  template<>
  struct FFTW_t<double>{
    typedef fftw_plan plan;
    typedef fftw_complex cplx;

    /// Batched real-to-complex plan; all arguments are forwarded unchanged.
    static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
        double *in, const int *inembed, int istride, int idist,
        fftw_complex *out, const int *onembed, int ostride, int odist){
      sync_mkl_user_threads();
      const plan result=fftw_plan_many_dft_r2c(rank, n, howmany,
          in, inembed, istride, idist,
          out, onembed, ostride, odist,
          FFTW_ESTIMATE);
      return result;
    }

    /// Batched complex-to-real plan; all arguments are forwarded unchanged.
    static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
        cplx *in, const int *inembed, int istride, int idist,
        double *out, const int *onembed, int ostride, int odist){
      sync_mkl_user_threads();
      const plan result=fftw_plan_many_dft_c2r(rank, n, howmany,
          in, inembed, istride, idist,
          out, onembed, ostride, odist,
          FFTW_ESTIMATE);
      return result;
    }

    /// Run plan p on the given buffers (new-array execute interface).
    static void fft_execute_dft_r2c(const plan p, double *in, cplx *out){
      fftw_execute_dft_r2c(p, in, out);
    }

    /// Run plan p on the given buffers (new-array execute interface).
    static void fft_execute_dft_c2r(const plan p, cplx *in, double *out){
      fftw_execute_dft_c2r(p, in, out);
    }

    /// Release a plan created by this wrapper.
    static void fft_destroy_plan(plan p){
      fftw_destroy_plan(p);
    }

    /// Forward to FFTW's global ::fftw_flops for the given plan.
    static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
      ::fftw_flops(p, add, mul, fma);
    }

  private:
    /// Under FFTW3_MKL: set fftw3_mkl.number_of_user_threads to the larger of
    /// omp_get_num_threads() and omp_get_max_threads() before planning.
    /// Without FFTW3_MKL this does nothing.
    static void sync_mkl_user_threads(){
  #ifdef FFTW3_MKL
      const int active_threads=omp_get_num_threads();
      int user_threads=omp_get_max_threads();
      if(active_threads>user_threads) user_threads=active_threads;
      fftw3_mkl.number_of_user_threads = user_threads;
  #endif
    }
  };
  #endif
  #ifdef PVFMM_HAVE_FFTWF
  /**
   * \brief Single-precision specialization: a thin forwarding layer over the
   * FFTW3 fftwf_* API. All plans are created with the FFTW_ESTIMATE flag.
   *
   * Under FFTW3_MKL, each planning call first sets
   * fftw3_mkl.number_of_user_threads, as the double-precision specialization
   * does. Previously the float plans skipped this step even though they are
   * created through the same MKL FFTW3 wrapper layer.
   */
  template<>
  struct FFTW_t<float>{
    typedef fftwf_plan plan;
    typedef fftwf_complex cplx;

    /// Batched real-to-complex plan; all arguments are forwarded unchanged.
    static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
        float *in, const int *inembed, int istride, int idist,
        cplx *out, const int *onembed, int ostride, int odist){
      set_mkl_user_threads();
      return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
          idist, out, onembed, ostride, odist, FFTW_ESTIMATE);
    }

    /// Batched complex-to-real plan; all arguments are forwarded unchanged.
    static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
        cplx *in, const int *inembed, int istride, int idist,
        float *out, const int *onembed, int ostride, int odist){
      set_mkl_user_threads();
      return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
          out, onembed, ostride, odist, FFTW_ESTIMATE);
    }

    /// Run plan p on the given buffers (new-array execute interface).
    static void fft_execute_dft_r2c(const plan p, float *in, cplx *out){
      fftwf_execute_dft_r2c(p, in, out);
    }

    /// Run plan p on the given buffers (new-array execute interface).
    static void fft_execute_dft_c2r(const plan p, cplx *in, float *out){
      fftwf_execute_dft_c2r(p, in, out);
    }

    /// Release a plan created by this wrapper.
    static void fft_destroy_plan(plan p){
      fftwf_destroy_plan(p);
    }

    /// Forward to FFTW's global ::fftwf_flops for the given plan.
    static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
      ::fftwf_flops(p, add, mul, fma);
    }

  private:
    /// Under FFTW3_MKL: set fftw3_mkl.number_of_user_threads to the larger of
    /// omp_get_num_threads() and omp_get_max_threads() before planning.
    /// Without FFTW3_MKL this does nothing.
    static void set_mkl_user_threads(){
  #ifdef FFTW3_MKL
      int omp_p0=omp_get_num_threads();
      int omp_p1=omp_get_max_threads();
      fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  #endif
    }
  };
  #endif
  279. }//end namespace
  280. #endif //_PVFMM_FFT_WRAPPER_