fft_wrapper.hpp 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. /**
  2. * \file fft_wrapper.hpp
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains FFTW3 wrapper functions.
  6. */
  7. #include <cmath>
  8. #include <cassert>
  9. #include <cstdlib>
  10. #include <vector>
  11. #include <fftw3.h>
  12. #ifdef FFTW3_MKL
  13. #include <fftw3_mkl.h>
  14. #endif
  15. #include <pvfmm_common.hpp>
  16. #include <mem_utils.hpp>
  17. #include <matrix.hpp>
  18. #ifndef _PVFMM_FFT_WRAPPER_
  19. #define _PVFMM_FFT_WRAPPER_
  20. namespace pvfmm{
  21. template<class T>
  22. struct FFTW_t{
  23. struct plan{
  24. std::vector<size_t> dim;
  25. std::vector<Matrix<T> > M;
  26. size_t howmany;
  27. };
  28. struct cplx{
  29. T real;
  30. T imag;
  31. };
  32. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  33. T *in, const int *inembed, int istride, int idist,
  34. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  35. assert(inembed==NULL);
  36. assert(onembed==NULL);
  37. assert(istride==1);
  38. assert(ostride==1);
  39. plan p;
  40. p.howmany=howmany;
  41. { // r2c
  42. p.dim.push_back(n[rank-1]);
  43. p.M.push_back(fft_r2c(n[rank-1]));
  44. }
  45. for(int i=rank-2;i>=0;i--){ // c2c
  46. p.dim.push_back(n[i]);
  47. p.M.push_back(fft_c2c(n[i]));
  48. }
  49. size_t N1=1, N2=1;
  50. for(size_t i=0;i<p.dim.size();i++){
  51. N1*=p.dim[i];
  52. N2*=p.M[i].Dim(1)/2;
  53. }
  54. assert(idist==N1);
  55. assert(odist==N2);
  56. return p;
  57. }
  58. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  59. cplx *in, const int *inembed, int istride, int idist,
  60. T *out, const int *onembed, int ostride, int odist, unsigned flags){
  61. assert(inembed==NULL);
  62. assert(onembed==NULL);
  63. assert(istride==1);
  64. assert(ostride==1);
  65. plan p;
  66. p.howmany=howmany;
  67. for(size_t i=0;i<rank-1;i++){ // c2c
  68. p.dim.push_back(n[i]);
  69. p.M.push_back(fft_c2c(n[i]));
  70. }
  71. { // c2r
  72. p.dim.push_back(n[rank-1]);
  73. p.M.push_back(fft_c2r(n[rank-1]));
  74. }
  75. size_t N1=1, N2=1;
  76. for(size_t i=0;i<p.dim.size();i++){
  77. N1*=p.dim[i];
  78. N2*=p.M[i].Dim(0)/2;
  79. }
  80. assert(idist==N2);
  81. assert(odist==N1);
  82. return p;
  83. }
  84. static void fft_execute_dft_r2c(const plan p, T *in, cplx *out){
  85. size_t N1=p.howmany, N2=p.howmany;
  86. for(size_t i=0;i<p.dim.size();i++){
  87. N1*=p.dim[i];
  88. N2*=p.M[i].Dim(1)/2;
  89. }
  90. std::vector<T> buff_(N1+2*N2);
  91. T* buff=&buff_[0];
  92. { // r2c
  93. size_t i=0;
  94. const Matrix<T>& M=p.M[i];
  95. assert(2*N2/M.Dim(1)==N1/M.Dim(0));
  96. Matrix<T> x( N1/M.Dim(0),M.Dim(0), in,false);
  97. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  98. Matrix<T>::DGEMM(y, x, M);
  99. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  100. }
  101. for(size_t i=1;i<p.dim.size();i++){ // c2c
  102. const Matrix<T>& M=p.M[i];
  103. assert(M.Dim(0)==M.Dim(1));
  104. Matrix<T> x(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  105. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  106. Matrix<T>::DGEMM(y, x, M);
  107. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  108. }
  109. { // howmany
  110. transpose<cplx>(N2/p.howmany, p.howmany, (cplx*)buff);
  111. mem::memcopy(out,buff,2*N2*sizeof(T));
  112. }
  113. }
  114. static void fft_execute_dft_c2r(const plan p, cplx *in, T *out){
  115. size_t N1=p.howmany, N2=p.howmany;
  116. for(size_t i=0;i<p.dim.size();i++){
  117. N1*=p.dim[i];
  118. N2*=p.M[i].Dim(0)/2;
  119. }
  120. std::vector<T> buff_(N1+2*N2);
  121. T* buff=&buff_[0];
  122. { // howmany
  123. mem::memcopy(buff,in,2*N2*sizeof(T));
  124. transpose<cplx>(p.howmany, N2/p.howmany, (cplx*)buff);
  125. }
  126. for(size_t i=0;i<p.dim.size()-1;i++){ // c2c
  127. Matrix<T> M=p.M[i];
  128. assert(M.Dim(0)==M.Dim(1));
  129. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  130. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  131. Matrix<T> x(2*N2/M.Dim(1),M.Dim(1),buff,false);
  132. Matrix<T>::DGEMM(x, y, M.Transpose());
  133. }
  134. { // r2c
  135. size_t i=p.dim.size()-1;
  136. const Matrix<T>& M=p.M[i];
  137. assert(2*N2/M.Dim(0)==N1/M.Dim(1));
  138. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  139. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff,false);
  140. Matrix<T> x( N1/M.Dim(1),M.Dim(1), out,false);
  141. Matrix<T>::DGEMM(x, y, M);
  142. }
  143. }
  144. static void fft_destroy_plan(plan p){
  145. p.dim.clear();
  146. p.M.clear();
  147. p.howmany=0;
  148. }
  149. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  150. *add=0;
  151. *mul=0;
  152. *fma=0;
  153. }
  154. private:
  155. static Matrix<T> fft_r2c(size_t N1){
  156. size_t N2=(N1/2+1);
  157. Matrix<T> M(N1,2*N2);
  158. for(size_t j=0;j<N1;j++)
  159. for(size_t i=0;i<N2;i++){
  160. M[j][2*i+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  161. M[j][2*i+1]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  162. }
  163. return M;
  164. }
  165. static Matrix<T> fft_c2c(size_t N1){
  166. Matrix<T> M(2*N1,2*N1);
  167. for(size_t i=0;i<N1;i++)
  168. for(size_t j=0;j<N1;j++){
  169. M[2*i+0][2*j+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  170. M[2*i+1][2*j+0]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  171. M[2*i+0][2*j+1]=-sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  172. M[2*i+1][2*j+1]= cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  173. }
  174. return M;
  175. }
  176. static Matrix<T> fft_c2r(size_t N1){
  177. size_t N2=(N1/2+1);
  178. Matrix<T> M(2*N2,N1);
  179. for(size_t i=0;i<N2;i++)
  180. for(size_t j=0;j<N1;j++){
  181. M[2*i+0][j]=2*cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  182. M[2*i+1][j]=2*sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  183. }
  184. if(N2>0){
  185. for(size_t j=0;j<N1;j++){
  186. M[0][j]=M[0][j]*0.5;
  187. M[1][j]=M[1][j]*0.5;
  188. }
  189. }
  190. if(N1%2==0){
  191. for(size_t j=0;j<N1;j++){
  192. M[2*N2-2][j]=M[2*N2-2][j]*0.5;
  193. M[2*N2-1][j]=M[2*N2-1][j]*0.5;
  194. }
  195. }
  196. return M;
  197. }
  198. template <class Y>
  199. static void transpose(size_t dim1, size_t dim2, Y* A){
  200. Matrix<Y> M(dim1, dim2, A);
  201. Matrix<Y> Mt(dim2, dim1, A, false);
  202. Mt=M.Transpose();
  203. }
  204. };
  205. #ifdef PVFMM_HAVE_FFTW
  206. template<>
  207. struct FFTW_t<double>{
  208. typedef fftw_plan plan;
  209. typedef fftw_complex cplx;
  210. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  211. double *in, const int *inembed, int istride, int idist,
  212. fftw_complex *out, const int *onembed, int ostride, int odist, unsigned flags){
  213. #ifdef FFTW3_MKL
  214. int omp_p0=omp_get_num_threads();
  215. int omp_p1=omp_get_max_threads();
  216. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  217. #endif
  218. return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  219. idist, out, onembed, ostride, odist, flags);
  220. }
  221. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  222. cplx *in, const int *inembed, int istride, int idist,
  223. double *out, const int *onembed, int ostride, int odist, unsigned flags){
  224. #ifdef FFTW3_MKL
  225. int omp_p0=omp_get_num_threads();
  226. int omp_p1=omp_get_max_threads();
  227. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  228. #endif
  229. return fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  230. out, onembed, ostride, odist, flags);
  231. }
  232. static void fft_execute_dft_r2c(const plan p, double *in, cplx *out){
  233. fftw_execute_dft_r2c(p, in, out);
  234. }
  235. static void fft_execute_dft_c2r(const plan p, cplx *in, double *out){
  236. fftw_execute_dft_c2r(p, in, out);
  237. }
  238. static void fft_destroy_plan(plan p){
  239. fftw_destroy_plan(p);
  240. }
  241. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  242. ::fftw_flops(p, add, mul, fma);
  243. }
  244. };
  245. #endif
  246. #ifdef PVFMM_HAVE_FFTWF
  247. template<>
  248. struct FFTW_t<float>{
  249. typedef fftwf_plan plan;
  250. typedef fftwf_complex cplx;
  251. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  252. float *in, const int *inembed, int istride, int idist,
  253. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  254. return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  255. idist, out, onembed, ostride, odist, flags);
  256. }
  257. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  258. cplx *in, const int *inembed, int istride, int idist,
  259. float *out, const int *onembed, int ostride, int odist, unsigned flags){
  260. return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  261. out, onembed, ostride, odist, flags);
  262. }
  263. static void fft_execute_dft_r2c(const plan p, float *in, cplx *out){
  264. fftwf_execute_dft_r2c(p, in, out);
  265. }
  266. static void fft_execute_dft_c2r(const plan p, cplx *in, float *out){
  267. fftwf_execute_dft_c2r(p, in, out);
  268. }
  269. static void fft_destroy_plan(plan p){
  270. fftwf_destroy_plan(p);
  271. }
  272. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  273. ::fftwf_flops(p, add, mul, fma);
  274. }
  275. };
  276. #endif
  277. }//end namespace
  278. #endif //_PVFMM_FFT_WRAPPER_