Fix issue with Kokkos FFT_CUFFT

This commit is contained in:
Stan Moore 2020-01-30 13:27:36 -07:00
parent a50563d582
commit 9fade740fb
2 changed files with 10 additions and 9 deletions

View File

@ -227,7 +227,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
else
FFTW_API(execute_dft)(plan->plan_fast_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
#elif defined(FFT_CUFFT)
cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExec(plan->plan_fast,d_data.data(),d_data.data(),flag);
#else
typename FFT_AT::t_FFT_DATA_1d d_tmp =
typename FFT_AT::t_FFT_DATA_1d(Kokkos::view_alloc("fft_3d:tmp",Kokkos::WithoutInitializing),d_in.dimension_0());
@ -273,7 +273,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
else
FFTW_API(execute_dft)(plan->plan_mid_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
#elif defined(FFT_CUFFT)
cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExec(plan->plan_mid,d_data.data(),d_data.data(),flag);
#else
if (flag == -1)
f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_mid_forward,length);
@ -315,7 +315,7 @@ void FFT3dKokkos<DeviceType>::fft_3d_kokkos(typename FFT_AT::t_FFT_DATA_1d d_in,
else
FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
#elif defined(FFT_CUFFT)
cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExec(plan->plan_slow,d_data.data(),d_data.data(),flag);
#else
if (flag == -1)
f = kiss_fft_functor<DeviceType>(d_data,d_tmp,plan->cfg_slow_forward,length);
@ -859,9 +859,9 @@ void FFT3dKokkos<DeviceType>::fft_3d_1d_only_kokkos(typename FFT_AT::t_FFT_DATA_
FFTW_API(execute_dft)(plan->plan_slow_backward,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data());
}
#elif defined(FFT_CUFFT)
cufftExec(plan->plan_fast,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExec(plan->plan_mid,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExec(plan->plan_slow,(FFT_DATA*)d_data.data(),(FFT_DATA*)d_data.data(),flag);
cufftExec(plan->plan_fast,d_data.data(),d_data.data(),flag);
cufftExec(plan->plan_mid,d_data.data(),d_data.data(),flag);
cufftExec(plan->plan_slow,d_data.data(),d_data.data(),flag);
#else
kiss_fft_functor<DeviceType> f;
typename FFT_AT::t_FFT_DATA_1d d_tmp = typename FFT_AT::t_FFT_DATA_1d("fft_3d:tmp",d_data.dimension_0());

View File

@ -11,6 +11,9 @@
See the README file in the top-level LAMMPS directory.
------------------------------------------------------------------------- */
#include "kokkos_type.h"
#define MAX(A,B) ((A) > (B) ? (A) : (B))
// data types for 2d/3d FFTs
@ -121,15 +124,13 @@ typedef double FFT_SCALAR;
#endif
// (double[2]*) is not a 1D pointer
#if defined(FFT_FFTW3) || defined(FFT_CUFFT)
#if defined(FFT_FFTW3)
typedef FFT_SCALAR* FFT_DATA_POINTER;
#else
typedef FFT_DATA* FFT_DATA_POINTER;
#endif
#include "kokkos_type.h"
template <class DeviceType>
struct FFTArrayTypes;