fft_wrapper.hpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
/**
 * \file fft_wrapper.hpp
 * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
 * \date 2-11-2011
 * \brief This file contains FFTW3 wrapper functions, plus a matrix-multiplication
 * (dense DFT) fallback for value types that have no FFTW specialization.
 */
  7. #ifndef _PVFMM_FFT_WRAPPER_
  8. #define _PVFMM_FFT_WRAPPER_
  9. #ifdef __INTEL_OFFLOAD
  10. #pragma offload_attribute(push,target(mic))
  11. #endif
  12. #include <fftw3.h>
  13. #ifdef FFTW3_MKL
  14. #include <fftw3_mkl.h>
  15. #endif
  16. #include <pvfmm_common.hpp>
  17. #include <matrix.hpp>
  18. namespace pvfmm{
  19. template<class T>
  20. struct FFTW_t{
  21. struct plan{
  22. std::vector<size_t> dim;
  23. std::vector<Matrix<T> > M;
  24. size_t howmany;
  25. };
  26. struct cplx{
  27. T real;
  28. T imag;
  29. };
  30. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  31. T *in, const int *inembed, int istride, int idist,
  32. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  33. assert(inembed==NULL);
  34. assert(onembed==NULL);
  35. assert(istride==1);
  36. assert(ostride==1);
  37. plan p;
  38. p.howmany=howmany;
  39. { // r2c
  40. p.dim.push_back(n[rank-1]);
  41. p.M.push_back(fft_r2c(n[rank-1]));
  42. }
  43. for(int i=rank-2;i>=0;i--){ // c2c
  44. p.dim.push_back(n[i]);
  45. p.M.push_back(fft_c2c(n[i]));
  46. }
  47. size_t N1=1, N2=1;
  48. for(size_t i=0;i<p.dim.size();i++){
  49. N1*=p.dim[i];
  50. N2*=p.M[i].Dim(1)/2;
  51. }
  52. assert(idist==N1);
  53. assert(odist==N2);
  54. return p;
  55. }
  56. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  57. cplx *in, const int *inembed, int istride, int idist,
  58. T *out, const int *onembed, int ostride, int odist, unsigned flags){
  59. assert(inembed==NULL);
  60. assert(onembed==NULL);
  61. assert(istride==1);
  62. assert(ostride==1);
  63. plan p;
  64. p.howmany=howmany;
  65. for(size_t i=0;i<rank-1;i++){ // c2c
  66. p.dim.push_back(n[i]);
  67. p.M.push_back(fft_c2c(n[i]));
  68. }
  69. { // c2r
  70. p.dim.push_back(n[rank-1]);
  71. p.M.push_back(fft_c2r(n[rank-1]));
  72. }
  73. size_t N1=1, N2=1;
  74. for(size_t i=0;i<p.dim.size();i++){
  75. N1*=p.dim[i];
  76. N2*=p.M[i].Dim(0)/2;
  77. }
  78. assert(idist==N2);
  79. assert(odist==N1);
  80. return p;
  81. }
  82. static void fft_execute_dft_r2c(const plan p, T *in, cplx *out){
  83. size_t N1=p.howmany, N2=p.howmany;
  84. for(size_t i=0;i<p.dim.size();i++){
  85. N1*=p.dim[i];
  86. N2*=p.M[i].Dim(1)/2;
  87. }
  88. std::vector<T> buff_(N1+2*N2);
  89. T* buff=&buff_[0];
  90. { // r2c
  91. size_t i=0;
  92. const Matrix<T>& M=p.M[i];
  93. assert(2*N2/M.Dim(1)==N1/M.Dim(0));
  94. Matrix<T> x( N1/M.Dim(0),M.Dim(0), in,false);
  95. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  96. Matrix<T>::DGEMM(y, x, M);
  97. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  98. }
  99. for(size_t i=1;i<p.dim.size();i++){ // c2c
  100. const Matrix<T>& M=p.M[i];
  101. assert(M.Dim(0)==M.Dim(1));
  102. Matrix<T> x(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  103. Matrix<T> y(2*N2/M.Dim(1),M.Dim(1),buff,false);
  104. Matrix<T>::DGEMM(y, x, M);
  105. transpose<cplx>(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff);
  106. }
  107. { // howmany
  108. transpose<cplx>(N2/p.howmany, p.howmany, (cplx*)buff);
  109. mem::memcopy(out,buff,2*N2*sizeof(T));
  110. }
  111. }
  112. static void fft_execute_dft_c2r(const plan p, cplx *in, T *out){
  113. size_t N1=p.howmany, N2=p.howmany;
  114. for(size_t i=0;i<p.dim.size();i++){
  115. N1*=p.dim[i];
  116. N2*=p.M[i].Dim(0)/2;
  117. }
  118. std::vector<T> buff_(N1+2*N2);
  119. T* buff=&buff_[0];
  120. { // howmany
  121. mem::memcopy(buff,in,2*N2*sizeof(T));
  122. transpose<cplx>(p.howmany, N2/p.howmany, (cplx*)buff);
  123. }
  124. for(size_t i=0;i<p.dim.size()-1;i++){ // c2c
  125. Matrix<T> M=p.M[i];
  126. assert(M.Dim(0)==M.Dim(1));
  127. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  128. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this
  129. Matrix<T> x(2*N2/M.Dim(1),M.Dim(1),buff,false);
  130. Matrix<T>::DGEMM(x, y, M.Transpose());
  131. }
  132. { // r2c
  133. size_t i=p.dim.size()-1;
  134. const Matrix<T>& M=p.M[i];
  135. assert(2*N2/M.Dim(0)==N1/M.Dim(1));
  136. transpose<cplx>(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff);
  137. Matrix<T> y(2*N2/M.Dim(0),M.Dim(0),buff,false);
  138. Matrix<T> x( N1/M.Dim(1),M.Dim(1), out,false);
  139. Matrix<T>::DGEMM(x, y, M);
  140. }
  141. }
  142. static void fft_destroy_plan(plan p){
  143. p.dim.clear();
  144. p.M.clear();
  145. p.howmany=0;
  146. }
  147. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  148. *add=0;
  149. *mul=0;
  150. *fma=0;
  151. }
  152. private:
  153. static Matrix<T> fft_r2c(size_t N1){
  154. size_t N2=(N1/2+1);
  155. Matrix<T> M(N1,2*N2);
  156. for(size_t j=0;j<N1;j++)
  157. for(size_t i=0;i<N2;i++){
  158. M[j][2*i+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  159. M[j][2*i+1]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  160. }
  161. return M;
  162. }
  163. static Matrix<T> fft_c2c(size_t N1){
  164. Matrix<T> M(2*N1,2*N1);
  165. for(size_t i=0;i<N1;i++)
  166. for(size_t j=0;j<N1;j++){
  167. M[2*i+0][2*j+0]=cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  168. M[2*i+1][2*j+0]=sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  169. M[2*i+0][2*j+1]=-sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  170. M[2*i+1][2*j+1]= cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  171. }
  172. return M;
  173. }
  174. static Matrix<T> fft_c2r(size_t N1){
  175. size_t N2=(N1/2+1);
  176. Matrix<T> M(2*N2,N1);
  177. for(size_t i=0;i<N2;i++)
  178. for(size_t j=0;j<N1;j++){
  179. M[2*i+0][j]=2*cos(j*i*(1.0/N1)*2.0*const_pi<T>());
  180. M[2*i+1][j]=2*sin(j*i*(1.0/N1)*2.0*const_pi<T>());
  181. }
  182. if(N2>0){
  183. for(size_t j=0;j<N1;j++){
  184. M[0][j]=M[0][j]*0.5;
  185. M[1][j]=M[1][j]*0.5;
  186. }
  187. }
  188. if(N1%2==0){
  189. for(size_t j=0;j<N1;j++){
  190. M[2*N2-2][j]=M[2*N2-2][j]*0.5;
  191. M[2*N2-1][j]=M[2*N2-1][j]*0.5;
  192. }
  193. }
  194. return M;
  195. }
  196. template <class Y>
  197. static void transpose(size_t dim1, size_t dim2, Y* A){
  198. Matrix<Y> M(dim1, dim2, A);
  199. Matrix<Y> Mt(dim2, dim1, A, false);
  200. Mt=M.Transpose();
  201. }
  202. };
  203. #ifdef PVFMM_HAVE_FFTW
  204. template<>
  205. struct FFTW_t<double>{
  206. typedef fftw_plan plan;
  207. typedef fftw_complex cplx;
  208. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  209. double *in, const int *inembed, int istride, int idist,
  210. fftw_complex *out, const int *onembed, int ostride, int odist, unsigned flags){
  211. #ifdef FFTW3_MKL
  212. int omp_p0=omp_get_num_threads();
  213. int omp_p1=omp_get_max_threads();
  214. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  215. #endif
  216. return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  217. idist, out, onembed, ostride, odist, flags);
  218. }
  219. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  220. cplx *in, const int *inembed, int istride, int idist,
  221. double *out, const int *onembed, int ostride, int odist, unsigned flags){
  222. #ifdef FFTW3_MKL
  223. int omp_p0=omp_get_num_threads();
  224. int omp_p1=omp_get_max_threads();
  225. fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1);
  226. #endif
  227. return fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  228. out, onembed, ostride, odist, flags);
  229. }
  230. static void fft_execute_dft_r2c(const plan p, double *in, cplx *out){
  231. fftw_execute_dft_r2c(p, in, out);
  232. }
  233. static void fft_execute_dft_c2r(const plan p, cplx *in, double *out){
  234. fftw_execute_dft_c2r(p, in, out);
  235. }
  236. static void fft_destroy_plan(plan p){
  237. fftw_destroy_plan(p);
  238. }
  239. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  240. ::fftw_flops(p, add, mul, fma);
  241. }
  242. };
  243. #endif
  244. #ifdef PVFMM_HAVE_FFTWF
  245. template<>
  246. struct FFTW_t<float>{
  247. typedef fftwf_plan plan;
  248. typedef fftwf_complex cplx;
  249. static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany,
  250. float *in, const int *inembed, int istride, int idist,
  251. cplx *out, const int *onembed, int ostride, int odist, unsigned flags){
  252. return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride,
  253. idist, out, onembed, ostride, odist, flags);
  254. }
  255. static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany,
  256. cplx *in, const int *inembed, int istride, int idist,
  257. float *out, const int *onembed, int ostride, int odist, unsigned flags){
  258. return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist,
  259. out, onembed, ostride, odist, flags);
  260. }
  261. static void fft_execute_dft_r2c(const plan p, float *in, cplx *out){
  262. fftwf_execute_dft_r2c(p, in, out);
  263. }
  264. static void fft_execute_dft_c2r(const plan p, cplx *in, float *out){
  265. fftwf_execute_dft_c2r(p, in, out);
  266. }
  267. static void fft_destroy_plan(plan p){
  268. fftwf_destroy_plan(p);
  269. }
  270. static void fftw_flops(const plan& p, double* add, double* mul, double* fma){
  271. ::fftwf_flops(p, add, mul, fma);
  272. }
  273. };
  274. #endif
  275. }//end namespace
  276. #ifdef __INTEL_OFFLOAD
  277. #pragma offload_attribute(pop)
  278. #endif
  279. #endif //_PVFMM_FFT_WRAPPER_