/**
 * \file mat_utils.hpp
 * \author Dhairya Malhotra, dhairya.malhotra@gmail.com
 * \date 2-11-2011
 * \brief This file contains FFTW3 wrapper functions.
 *
 * NOTE(review): the \file tag says mat_utils.hpp, but the include guard is
 * _PVFMM_FFT_WRAPPER_ and the brief describes FFT wrappers. The tag looks
 * stale; confirm against the real file name.
 *
 * NOTE(review): this copy of the file looks damaged. Line breaks are gone
 * (preprocessor directives and line comments share a line with code), every
 * #include has lost its header name, and template argument lists are missing
 * (for example "template struct FFTW_t" and "std::vector dim"). Several loops
 * stop right after "i=0;i" or "j=0;j", so their conditions, their bodies and
 * possibly whole definitions that followed them are not present. Restore the
 * file from version control before building. The notes below describe only
 * what is still visible.
 *
 * Contents of namespace pvfmm, as far as they can be read here:
 *
 *  - FFTW_t, generic template. Holds two nested types: plan (members dim, M
 *    and howmany) and cplx (members real and imag of type T).
 *      - fft_plan_many_dft_r2c asserts inembed==NULL, onembed==NULL,
 *        istride==1 and ostride==1, stores howmany, then appends n[rank-1]
 *        together with fft_r2c(n[rank-1]), followed by n[i] together with
 *        fft_c2c(n[i]) for i running from rank-2 down to 0. The rest of the
 *        function is lost.
 *      - The text after that point reads from "in", works in a temporary
 *        buffer through Matrix::DGEMM and transpose, and copies the result to
 *        "out" with mem::memcopy. It appears to be the body of a forward
 *        execute routine whose signature is missing from this copy.
 *      - fft_execute_dft_c2r starts N1 and N2 at p.howmany, copies
 *        2*N2*sizeof(T) bytes from "in" into a temporary buffer, transposes
 *        it, runs a loop over plan entries that multiplies by M.Transpose()
 *        (loop bounds lost), and finally uses entry p.dim.size()-1 to write
 *        into "out" with Matrix::DGEMM.
 *      - fft_destroy_plan clears dim and M and zeroes howmany on its
 *        parameter, which is passed by value.
 *      - fftw_flops stores 0 in *add, *mul and *fma.
 *      - Private helpers fft_r2c, fft_c2c and fft_c2r allocate matrices of
 *        size (N1, 2*N2), (2*N1, 2*N1) and (2*N2, N1) with N2 = N1/2+1, and
 *        fill them with cos and sin of j*i*(1.0/N1)*2.0*const_pi(). The tail
 *        of fft_c2r is lost.
 *      - Private helper transpose(dim1, dim2, A) builds Matrix M(dim1, dim2, A)
 *        and Matrix Mt(dim2, dim1, A, false) and assigns Mt=M.Transpose().
 *        Presumably the trailing false makes Mt a non-owning view of A, so the
 *        result lands back in A; verify against the Matrix class.
 *
 *  - Specialization compiled when PVFMM_HAVE_FFTW is defined: plan is
 *    fftw_plan and cplx is fftw_complex. Every function forwards its
 *    arguments unchanged to the fftw_ function of the matching name
 *    (fftw_plan_many_dft_r2c, fftw_plan_many_dft_c2r, fftw_execute_dft_r2c,
 *    fftw_execute_dft_c2r, fftw_destroy_plan, fftw_flops). When FFTW3_MKL is
 *    defined, both planners first set fftw3_mkl.number_of_user_threads to the
 *    larger of omp_get_num_threads() and omp_get_max_threads(). The real
 *    pointers are double, so this is presumably the double specialization.
 *
 *  - Specialization compiled when PVFMM_HAVE_FFTWF is defined: the same
 *    forwarding layer over the fftwf_ functions with fftwf_plan and
 *    fftwf_complex. The real pointers are float, so this is presumably the
 *    float specialization. It has no FFTW3_MKL branch.
 */ #ifndef _PVFMM_FFT_WRAPPER_ #define _PVFMM_FFT_WRAPPER_ #include #ifdef FFTW3_MKL #include #endif #include #include namespace pvfmm{ /* Generic implementation: applies stored matrices through Matrix::DGEMM instead of calling the FFTW library. */ template struct FFTW_t{ /* Plan data: dim and M receive one entry each per transform dimension; howmany is copied from the planner argument. */ struct plan{ std::vector dim; std::vector > M; size_t howmany; }; /* Complex value stored as two members of type T. */ struct cplx{ T real; T imag; }; /* Builds a plan; only inembed==NULL, onembed==NULL, istride==1 and ostride==1 are accepted (asserted). */ static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany, T *in, const int *inembed, int istride, int idist, cplx *out, const int *onembed, int ostride, int odist, unsigned flags){ assert(inembed==NULL); assert(onembed==NULL); assert(istride==1); assert(ostride==1); plan p; p.howmany=howmany; { // r2c p.dim.push_back(n[rank-1]); p.M.push_back(fft_r2c(n[rank-1])); } for(int i=rank-2;i>=0;i--){ // c2c p.dim.push_back(n[i]); p.M.push_back(fft_c2c(n[i])); } size_t N1=1, N2=1; for(size_t i=0;i /* NOTE(review): source text is missing here; the loop condition and everything that followed it up to this point were lost. */ buff_(N1+2*N2); T* buff=&buff_[0]; { // r2c size_t i=0; const Matrix& M=p.M[i]; assert(2*N2/M.Dim(1)==N1/M.Dim(0)); Matrix x( N1/M.Dim(0),M.Dim(0), in,false); Matrix y(2*N2/M.Dim(1),M.Dim(1),buff,false); Matrix::DGEMM(y, x, M); transpose(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff); } for(size_t i=1;i& M=p.M[i]; assert(M.Dim(0)==M.Dim(1)); Matrix x(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this Matrix y(2*N2/M.Dim(1),M.Dim(1),buff,false); Matrix::DGEMM(y, x, M); transpose(2*N2/M.Dim(1), M.Dim(1)/2, (cplx*)buff); } { // howmany transpose(N2/p.howmany, p.howmany, (cplx*)buff); mem::memcopy(out,buff,2*N2*sizeof(T)); } } /* Applies plan p to the complex input in and writes real output to out, working in the temporary buffer buff_. */ static void fft_execute_dft_c2r(const plan p, cplx *in, T *out){ size_t N1=p.howmany, N2=p.howmany; for(size_t i=0;i buff_(N1+2*N2); T* buff=&buff_[0]; { // howmany mem::memcopy(buff,in,2*N2*sizeof(T)); transpose(p.howmany, N2/p.howmany, (cplx*)buff); } for(size_t i=0;i M=p.M[i]; assert(M.Dim(0)==M.Dim(1)); transpose(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff); Matrix y(2*N2/M.Dim(0),M.Dim(0),buff); // TODO: optimize this Matrix 
x(2*N2/M.Dim(1),M.Dim(1),buff,false); Matrix::DGEMM(x, y, M.Transpose()); } { // r2c size_t i=p.dim.size()-1; const Matrix& M=p.M[i]; assert(2*N2/M.Dim(0)==N1/M.Dim(1)); transpose(M.Dim(0)/2, 2*N2/M.Dim(0), (cplx*)buff); Matrix y(2*N2/M.Dim(0),M.Dim(0),buff,false); Matrix x( N1/M.Dim(1),M.Dim(1), out,false); Matrix::DGEMM(x, y, M); } } /* NOTE(review): p is taken by value, so these clears act on a copy and the caller's plan is left unchanged. */ static void fft_destroy_plan(plan p){ p.dim.clear(); p.M.clear(); p.howmany=0; } /* Reports zero for all three operation counts. */ static void fftw_flops(const plan& p, double* add, double* mul, double* fma){ *add=0; *mul=0; *fma=0; } /* Private helpers: builders for matrices of cos and sin terms, plus transpose. */ private: static Matrix fft_r2c(size_t N1){ size_t N2=(N1/2+1); Matrix M(N1,2*N2); for(size_t j=0;j()); M[j][2*i+1]=sin(j*i*(1.0/N1)*2.0*const_pi()); } return M; } static Matrix fft_c2c(size_t N1){ Matrix M(2*N1,2*N1); for(size_t i=0;i()); M[2*i+1][2*j+0]=sin(j*i*(1.0/N1)*2.0*const_pi()); M[2*i+0][2*j+1]=-sin(j*i*(1.0/N1)*2.0*const_pi()); M[2*i+1][2*j+1]= cos(j*i*(1.0/N1)*2.0*const_pi()); } return M; } static Matrix fft_c2r(size_t N1){ size_t N2=(N1/2+1); Matrix M(2*N2,N1); for(size_t i=0;i()); M[2*i+1][j]=2*sin(j*i*(1.0/N1)*2.0*const_pi()); } if(N2>0){ for(size_t j=0;j static void transpose(size_t dim1, size_t dim2, Y* A){ Matrix M(dim1, dim2, A); Matrix Mt(dim2, dim1, A, false); Mt=M.Transpose(); } }; /* Specialization that forwards to the fftw_ API (real pointers are double); compiled only when PVFMM_HAVE_FFTW is defined. */ #ifdef PVFMM_HAVE_FFTW template<> struct FFTW_t{ typedef fftw_plan plan; typedef fftw_complex cplx; static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany, double *in, const int *inembed, int istride, int idist, fftw_complex *out, const int *onembed, int ostride, int odist, unsigned flags){ /* With FFTW3_MKL, set the MKL user-thread count to the larger of the two OpenMP thread counts before planning. */ #ifdef FFTW3_MKL int omp_p0=omp_get_num_threads(); int omp_p1=omp_get_max_threads(); fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1); #endif return fftw_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, out, onembed, ostride, odist, flags); } static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany, cplx *in, const int *inembed, int istride, int idist, double *out, const int *onembed, int ostride, int odist, 
unsigned flags){ #ifdef FFTW3_MKL int omp_p0=omp_get_num_threads(); int omp_p1=omp_get_max_threads(); fftw3_mkl.number_of_user_threads = (omp_p0>omp_p1?omp_p0:omp_p1); #endif return fftw_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, out, onembed, ostride, odist, flags); } static void fft_execute_dft_r2c(const plan p, double *in, cplx *out){ fftw_execute_dft_r2c(p, in, out); } static void fft_execute_dft_c2r(const plan p, cplx *in, double *out){ fftw_execute_dft_c2r(p, in, out); } static void fft_destroy_plan(plan p){ fftw_destroy_plan(p); } static void fftw_flops(const plan& p, double* add, double* mul, double* fma){ ::fftw_flops(p, add, mul, fma); } }; #endif /* Specialization that forwards to the fftwf_ API (real pointers are float); compiled only when PVFMM_HAVE_FFTWF is defined. */ #ifdef PVFMM_HAVE_FFTWF template<> struct FFTW_t{ typedef fftwf_plan plan; typedef fftwf_complex cplx; static plan fft_plan_many_dft_r2c(int rank, const int *n, int howmany, float *in, const int *inembed, int istride, int idist, cplx *out, const int *onembed, int ostride, int odist, unsigned flags){ return fftwf_plan_many_dft_r2c(rank, n, howmany, in, inembed, istride, idist, out, onembed, ostride, odist, flags); } static plan fft_plan_many_dft_c2r(int rank, const int *n, int howmany, cplx *in, const int *inembed, int istride, int idist, float *out, const int *onembed, int ostride, int odist, unsigned flags){ return fftwf_plan_many_dft_c2r(rank, n, howmany, in, inembed, istride, idist, out, onembed, ostride, odist, flags); } static void fft_execute_dft_r2c(const plan p, float *in, cplx *out){ fftwf_execute_dft_r2c(p, in, out); } static void fft_execute_dft_c2r(const plan p, cplx *in, float *out){ fftwf_execute_dft_c2r(p, in, out); } static void fft_destroy_plan(plan p){ fftwf_destroy_plan(p); } static void fftw_flops(const plan& p, double* add, double* mul, double* fma){ ::fftwf_flops(p, add, mul, fma); } }; #endif }//end namespace #endif //_PVFMM_FFT_WRAPPER_