fft_wrapper.hpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. /**
  2. * \file fft_wrapper.hpp
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains FFTW3 wrapper functions.
  6. */
  7. #include <cmath>
  8. #include <cassert>
  9. #include <vector>
  10. #include <fftw3.h>
  11. #ifdef FFTW3_MKL
  12. #include <fftw3_mkl.h>
  13. #endif
  14. #include <pvfmm_common.hpp>
  15. #include <mem_utils.hpp>
  16. #include <matrix.hpp>
  17. #ifndef _PVFMM_FFT_WRAPPER_
  18. #define _PVFMM_FFT_WRAPPER_
  19. namespace pvfmm{
  20. template<class T>
  21. struct FFTW_t{
  22. struct plan{
  23. std::vector<size_t> dim;
  24. std::vector<Matrix<T> > M;
  25. size_t howmany;
  26. };
  27. struct cplx{
  28. T real;
  29. T imag;
  30. };
  31. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  32. T *in, const int *inembed, int istride, int idist,
  33. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  34. assert(inembed==NULL);
  35. assert(onembed==NULL);
  36. assert(istride==1);
  37. assert(ostride==1);
  38. plan p;
  39. p.howmany=howmany;
  40. { // r2c
  41. p.dim.push_back(n[rank-1]);
  42. p.M.push_back(fft_r2c(n[rank-1]));
  43. }
  44. for(int i=rank-2;i>=0;i--){ // c2c
  45. p.dim.push_back(n[i]);
  46. p.M.push_back(fft_c2c(n[i]));
  47. }
  48. size_t N1=1, N2=1;
  49. for(size_t i=0;i<p.dim.size();i++){
  50. N1*=p.dim[i];
  51. N2*=p.M[i].Dim(1)/2;
  52. }
  53. assert(idist==N1);
  54. assert(odist==N2);
  55. return p;
  56. }
  57. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  58. cplx *in, const int *inembed, int istride, int idist,
  59. T *out, const int *onembed, int ostride, int odist, unsigned flags){
  60. assert(inembed==NULL);
  61. assert(onembed==NULL);
  62. assert(istride==1);
  63. assert(ostride==1);
  64. plan p;
  65. p.howmany=howmany;
  66. for(size_t i=0;i<rank-1;i++){ // c2c
  67. p.dim.push_back(n[i]);
  68. p.M.push_back(fft_c2c(n[i]));
  69. }
  70. { // c2r
  71. p.dim.push_back(n[rank-1]);
  72. p.M.push_back(fft_c2r(n[rank-1]));
  73. }
  74. size_t N1=1, N2=1;
  75. for(size_t i=0;i<p.dim.size();i++){
  76. N1*=p.dim[i];
  77. N2*=p.M[i].Dim(0)/2;
  78. }
  79. assert(idist==N2);
  80. assert(odist==N1);
  81. return p;
  82. }
  83. static void fft_execute_dft_r2c(const plan p, T *in, cplx *out){
  84. size_t N1=p.howmany, N2=p.howmany;
  85. for(size_t i=0;i<p.dim.size();i++){
  86. N1*=p.dim[i];
  87. N2*=p.M[i].Dim(1)/2;
  88. }
  89. std::vector<T> buff_(N1+2*N2);
  90. T* buff=&buff_[0];
  91. { // r2c
  92. size_t i=0;
  93. const Matrix<T>& M=p.M[i];
  94. assert(2*N2/M.Dim(1)==N1/M.Dim(0));
  95. Matrix<T> x( N1/M.Dim(0),M.Dim(0), in,false);
  96. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  97. Matrix<T>::DGEMM(y, x, M);
  98. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  99. }
  100. for(size_t i=1;i<p.dim.size();i++){ // c2c
  101. const Matrix<T>& M=p.M[i];
  102. assert(M.Dim(0)==M.Dim(1));
  103. Matrix<T> x(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  104. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  105. Matrix<T>::DGEMM(y, x, M);
  106. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  107. }
  108. { // howmany
  109. transpose<cplx>(N2/p.howmany, p.howmany, (cplx*)buff);
  110. mem::memcopy(out,buff,2*N2*sizeof(T));
  111. }
  112. }
  113. static void fft_execute_dft_c2r(const plan p, cplx *in, T *out){
  114. size_t N1=p.howmany, N2=p.howmany;
  115. for(size_t i=0;i<p.dim.size();i++){
  116. N1*=p.dim[i];
  117. N2*=p.M[i].Dim(0)/2;
  118. }
  119. std::vector<T> buff_(N1+2*N2);
  120. T* buff=&buff_[0];
  121. { // howmany
  122. mem::memcopy(buff,in,2*N2*sizeof(T));
  123. transpose<cplx>(p.howmany, N2/p.howmany, (cplx*)buff);
  124. }
  125. for(size_t i=0;i<p.dim.size()-1;i++){ // c2c
  126. Matrix<T> M=p.M[i];
  127. assert(M.Dim(0)==M.Dim(1));
  128. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  129. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  130. Matrix<T> x(2*N2/M.Dim(1),M.Dim(1),buff,false);
  131. Matrix<T>::DGEMM(x, y, M.Transpose());
  132. }
  133. { // r2c
  134. size_t i=p.dim.size()-1;
  135. const Matrix<T>& M=p.M[i];
  136. assert(2*N2/M.Dim(0)==N1/M.Dim(1));
  137. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  138. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff,false);
  139. Matrix<T> x( N1/M.Dim(1),M.Dim(1), out,false);
  140. Matrix<T>::DGEMM(x, y, M);
  141. }
  142. }
  143. static void fft_destroy_plan(plan p){
  144. p.dim.clear();
  145. p.M.clear();
  146. p.howmany=0;
  147. }
  148. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  149. *add=0;
  150. *mul=0;
  151. *fma=0;
  152. }
  153. private:
  154. static Matrix<T> fft_r2c(size_t N1){
  155. size_t N2=(N1/2+1);
  156. Matrix<T> M(N1,2*N2);
  157. for(size_t j=0;j<N1;j++)
  158. for(size_t i=0;i<N2;i++){
  159. M[j][2*i+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  160. M[j][2*i+1]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  161. }
  162. return M;
  163. }
  164. static Matrix<T> fft_c2c(size_t N1){
  165. Matrix<T> M(2*N1,2*N1);
  166. for(size_t i=0;i<N1;i++)
  167. for(size_t j=0;j<N1;j++){
  168. M[2*i+0][2*j+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  169. M[2*i+1][2*j+0]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  170. M[2*i+0][2*j+1]=-sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  171. M[2*i+1][2*j+1]= cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  172. }
  173. return M;
  174. }
  175. static Matrix<T> fft_c2r(size_t N1){
  176. size_t N2=(N1/2+1);
  177. Matrix<T> M(2*N2,N1);
  178. for(size_t i=0;i<N2;i++)
  179. for(size_t j=0;j<N1;j++){
  180. M[2*i+0][j]=2*cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  181. M[2*i+1][j]=2*sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  182. }
  183. if(N2>0){
  184. for(size_t j=0;j<N1;j++){
  185. M[0][j]=M[0][j]*0.5;
  186. M[1][j]=M[1][j]*0.5;
  187. }
  188. }
  189. if(N1%2==0){
  190. for(size_t j=0;j<N1;j++){
  191. M[2*N2-2][j]=M[2*N2-2][j]*0.5;
  192. M[2*N2-1][j]=M[2*N2-1][j]*0.5;
  193. }
  194. }
  195. return M;
  196. }
  197. template <class Y>
  198. static void transpose(size_t dim1, size_t dim2, Y* A){
  199. Matrix<Y> M(dim1, dim2, A);
  200. Matrix<Y> Mt(dim2, dim1, A, false);
  201. Mt=M.Transpose();
  202. }
  203. };
  204. #ifdef PVFMM_HAVE_FFTW
  205. template<>
  206. struct FFTW_t<double>{
  207. typedef fftw_plan plan;
  208. typedef fftw_complex cplx;
  209. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  210. double *in, const int *inembed, int istride, int idist,
  211. fftw_complex *out, const int *onembed, int ostride, int odist, unsigned flags){
  212. #ifdef FFTW3_MKL
  213. int omp_p0=omp_get_num_threads();
  214. int omp_p1=omp_get_max_threads();
  215. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  216. #endif
  217. return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  218. idist, out, onembed, ostride, odist, flags);
  219. }
  220. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  221. cplx *in, const int *inembed, int istride, int idist,
  222. double *out, const int *onembed, int ostride, int odist, unsigned flags){
  223. #ifdef FFTW3_MKL
  224. int omp_p0=omp_get_num_threads();
  225. int omp_p1=omp_get_max_threads();
  226. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  227. #endif
  228. return fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  229. out, onembed, ostride, odist, flags);
  230. }
  231. static void fft_execute_dft_r2c(const plan p, double *in, cplx *out){
  232. fftw_execute_dft_r2c(p, in, out);
  233. }
  234. static void fft_execute_dft_c2r(const plan p, cplx *in, double *out){
  235. fftw_execute_dft_c2r(p, in, out);
  236. }
  237. static void fft_destroy_plan(plan p){
  238. fftw_destroy_plan(p);
  239. }
  240. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  241. ::fftw_flops(p, add, mul, fma);
  242. }
  243. };
  244. #endif
  245. #ifdef PVFMM_HAVE_FFTWF
  246. template<>
  247. struct FFTW_t<float>{
  248. typedef fftwf_plan plan;
  249. typedef fftwf_complex cplx;
  250. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  251. float *in, const int *inembed, int istride, int idist,
  252. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  253. return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  254. idist, out, onembed, ostride, odist, flags);
  255. }
  256. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  257. cplx *in, const int *inembed, int istride, int idist,
  258. float *out, const int *onembed, int ostride, int odist, unsigned flags){
  259. return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  260. out, onembed, ostride, odist, flags);
  261. }
  262. static void fft_execute_dft_r2c(const plan p, float *in, cplx *out){
  263. fftwf_execute_dft_r2c(p, in, out);
  264. }
  265. static void fft_execute_dft_c2r(const plan p, cplx *in, float *out){
  266. fftwf_execute_dft_c2r(p, in, out);
  267. }
  268. static void fft_destroy_plan(plan p){
  269. fftwf_destroy_plan(p);
  270. }
  271. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  272. ::fftwf_flops(p, add, mul, fma);
  273. }
  274. };
  275. #endif
  276. }//end namespace
  277. #endif //_PVFMM_FFT_WRAPPER_