fft_wrapper.hpp 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. /**
  2. * \file fft_wrapper.hpp
  3. * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
  4. * \date 2-11-2011
  5. * \brief This file contains FFTW3 wrapper functions.
  6. */
  7. #ifndef _PVFMM_FFT_WRAPPER_
  8. #define _PVFMM_FFT_WRAPPER_
  9. #include <fftw3.h>
  10. #ifdef FFTW3_MKL
  11. #include <fftw3_mkl.h>
  12. #endif
  13. #include <pvfmm_common.hpp>
  14. #include <matrix.hpp>
  15. namespace pvfmm{
  16. template<class T>
  17. struct FFTW_t{
  18. struct plan{
  19. std::vector<size_t> dim;
  20. std::vector<Matrix<T> > M;
  21. size_t howmany;
  22. };
  23. struct cplx{
  24. T real;
  25. T imag;
  26. };
  27. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  28. T *in, const int *inembed, int istride, int idist,
  29. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  30. assert(inembed==NULL);
  31. assert(onembed==NULL);
  32. assert(istride==1);
  33. assert(ostride==1);
  34. plan p;
  35. p.howmany=howmany;
  36. { // r2c
  37. p.dim.push_back(n[rank-1]);
  38. p.M.push_back(fft_r2c(n[rank-1]));
  39. }
  40. for(int i=rank-2;i>=0;i--){ // c2c
  41. p.dim.push_back(n[i]);
  42. p.M.push_back(fft_c2c(n[i]));
  43. }
  44. size_t N1=1, N2=1;
  45. for(size_t i=0;i<p.dim.size();i++){
  46. N1*=p.dim[i];
  47. N2*=p.M[i].Dim(1)/2;
  48. }
  49. assert(idist==N1);
  50. assert(odist==N2);
  51. return p;
  52. }
  53. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  54. cplx *in, const int *inembed, int istride, int idist,
  55. T *out, const int *onembed, int ostride, int odist, unsigned flags){
  56. assert(inembed==NULL);
  57. assert(onembed==NULL);
  58. assert(istride==1);
  59. assert(ostride==1);
  60. plan p;
  61. p.howmany=howmany;
  62. for(size_t i=0;i<rank-1;i++){ // c2c
  63. p.dim.push_back(n[i]);
  64. p.M.push_back(fft_c2c(n[i]));
  65. }
  66. { // c2r
  67. p.dim.push_back(n[rank-1]);
  68. p.M.push_back(fft_c2r(n[rank-1]));
  69. }
  70. size_t N1=1, N2=1;
  71. for(size_t i=0;i<p.dim.size();i++){
  72. N1*=p.dim[i];
  73. N2*=p.M[i].Dim(0)/2;
  74. }
  75. assert(idist==N2);
  76. assert(odist==N1);
  77. return p;
  78. }
  79. static void fft_execute_dft_r2c(const plan p, T *in, cplx *out){
  80. size_t N1=p.howmany, N2=p.howmany;
  81. for(size_t i=0;i<p.dim.size();i++){
  82. N1*=p.dim[i];
  83. N2*=p.M[i].Dim(1)/2;
  84. }
  85. std::vector<T> buff_(N1+2*N2);
  86. T* buff=&buff_[0];
  87. { // r2c
  88. size_t i=0;
  89. const Matrix<T>& M=p.M[i];
  90. assert(2*N2/M.Dim(1)==N1/M.Dim(0));
  91. Matrix<T> x( N1/M.Dim(0),M.Dim(0), in,false);
  92. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  93. Matrix<T>::DGEMM(y, x, M);
  94. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  95. }
  96. for(size_t i=1;i<p.dim.size();i++){ // c2c
  97. const Matrix<T>& M=p.M[i];
  98. assert(M.Dim(0)==M.Dim(1));
  99. Matrix<T> x(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  100. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  101. Matrix<T>::DGEMM(y, x, M);
  102. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  103. }
  104. { // howmany
  105. transpose<cplx>(N2/p.howmany, p.howmany, (cplx*)buff);
  106. mem::memcopy(out,buff,2*N2*sizeof(T));
  107. }
  108. }
  109. static void fft_execute_dft_c2r(const plan p, cplx *in, T *out){
  110. size_t N1=p.howmany, N2=p.howmany;
  111. for(size_t i=0;i<p.dim.size();i++){
  112. N1*=p.dim[i];
  113. N2*=p.M[i].Dim(0)/2;
  114. }
  115. std::vector<T> buff_(N1+2*N2);
  116. T* buff=&buff_[0];
  117. { // howmany
  118. mem::memcopy(buff,in,2*N2*sizeof(T));
  119. transpose<cplx>(p.howmany, N2/p.howmany, (cplx*)buff);
  120. }
  121. for(size_t i=0;i<p.dim.size()-1;i++){ // c2c
  122. Matrix<T> M=p.M[i];
  123. assert(M.Dim(0)==M.Dim(1));
  124. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  125. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  126. Matrix<T> x(2*N2/M.Dim(1),M.Dim(1),buff,false);
  127. Matrix<T>::DGEMM(x, y, M.Transpose());
  128. }
  129. { // r2c
  130. size_t i=p.dim.size()-1;
  131. const Matrix<T>& M=p.M[i];
  132. assert(2*N2/M.Dim(0)==N1/M.Dim(1));
  133. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  134. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff,false);
  135. Matrix<T> x( N1/M.Dim(1),M.Dim(1), out,false);
  136. Matrix<T>::DGEMM(x, y, M);
  137. }
  138. }
  139. static void fft_destroy_plan(plan p){
  140. p.dim.clear();
  141. p.M.clear();
  142. p.howmany=0;
  143. }
  144. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  145. *add=0;
  146. *mul=0;
  147. *fma=0;
  148. }
  149. private:
  150. static Matrix<T> fft_r2c(size_t N1){
  151. size_t N2=(N1/2+1);
  152. Matrix<T> M(N1,2*N2);
  153. for(size_t j=0;j<N1;j++)
  154. for(size_t i=0;i<N2;i++){
  155. M[j][2*i+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  156. M[j][2*i+1]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  157. }
  158. return M;
  159. }
  160. static Matrix<T> fft_c2c(size_t N1){
  161. Matrix<T> M(2*N1,2*N1);
  162. for(size_t i=0;i<N1;i++)
  163. for(size_t j=0;j<N1;j++){
  164. M[2*i+0][2*j+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  165. M[2*i+1][2*j+0]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  166. M[2*i+0][2*j+1]=-sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  167. M[2*i+1][2*j+1]= cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  168. }
  169. return M;
  170. }
  171. static Matrix<T> fft_c2r(size_t N1){
  172. size_t N2=(N1/2+1);
  173. Matrix<T> M(2*N2,N1);
  174. for(size_t i=0;i<N2;i++)
  175. for(size_t j=0;j<N1;j++){
  176. M[2*i+0][j]=2*cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  177. M[2*i+1][j]=2*sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  178. }
  179. if(N2>0){
  180. for(size_t j=0;j<N1;j++){
  181. M[0][j]=M[0][j]*0.5;
  182. M[1][j]=M[1][j]*0.5;
  183. }
  184. }
  185. if(N1%2==0){
  186. for(size_t j=0;j<N1;j++){
  187. M[2*N2-2][j]=M[2*N2-2][j]*0.5;
  188. M[2*N2-1][j]=M[2*N2-1][j]*0.5;
  189. }
  190. }
  191. return M;
  192. }
  193. template <class Y>
  194. static void transpose(size_t dim1, size_t dim2, Y* A){
  195. Matrix<Y> M(dim1, dim2, A);
  196. Matrix<Y> Mt(dim2, dim1, A, false);
  197. Mt=M.Transpose();
  198. }
  199. };
  200. #ifdef PVFMM_HAVE_FFTW
  201. template<>
  202. struct FFTW_t<double>{
  203. typedef fftw_plan plan;
  204. typedef fftw_complex cplx;
  205. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  206. double *in, const int *inembed, int istride, int idist,
  207. fftw_complex *out, const int *onembed, int ostride, int odist, unsigned flags){
  208. #ifdef FFTW3_MKL
  209. int omp_p0=omp_get_num_threads();
  210. int omp_p1=omp_get_max_threads();
  211. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  212. #endif
  213. return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  214. idist, out, onembed, ostride, odist, flags);
  215. }
  216. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  217. cplx *in, const int *inembed, int istride, int idist,
  218. double *out, const int *onembed, int ostride, int odist, unsigned flags){
  219. #ifdef FFTW3_MKL
  220. int omp_p0=omp_get_num_threads();
  221. int omp_p1=omp_get_max_threads();
  222. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  223. #endif
  224. return fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  225. out, onembed, ostride, odist, flags);
  226. }
  227. static void fft_execute_dft_r2c(const plan p, double *in, cplx *out){
  228. fftw_execute_dft_r2c(p, in, out);
  229. }
  230. static void fft_execute_dft_c2r(const plan p, cplx *in, double *out){
  231. fftw_execute_dft_c2r(p, in, out);
  232. }
  233. static void fft_destroy_plan(plan p){
  234. fftw_destroy_plan(p);
  235. }
  236. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  237. ::fftw_flops(p, add, mul, fma);
  238. }
  239. };
  240. #endif
  241. #ifdef PVFMM_HAVE_FFTWF
  242. template<>
  243. struct FFTW_t<float>{
  244. typedef fftwf_plan plan;
  245. typedef fftwf_complex cplx;
  246. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  247. float *in, const int *inembed, int istride, int idist,
  248. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  249. return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  250. idist, out, onembed, ostride, odist, flags);
  251. }
  252. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  253. cplx *in, const int *inembed, int istride, int idist,
  254. float *out, const int *onembed, int ostride, int odist, unsigned flags){
  255. return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  256. out, onembed, ostride, odist, flags);
  257. }
  258. static void fft_execute_dft_r2c(const plan p, float *in, cplx *out){
  259. fftwf_execute_dft_r2c(p, in, out);
  260. }
  261. static void fft_execute_dft_c2r(const plan p, cplx *in, float *out){
  262. fftwf_execute_dft_c2r(p, in, out);
  263. }
  264. static void fft_destroy_plan(plan p){
  265. fftwf_destroy_plan(p);
  266. }
  267. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  268. ::fftwf_flops(p, add, mul, fma);
  269. }
  270. };
  271. #endif
  272. }//end namespace
  273. #endif //_PVFMM_FFT_WRAPPER_