!20885 add dtypes & fft kernels for SPONGE

Merge pull request !20885 from huangmengxi/sponge_ccsrc
This commit is contained in:
i-robot 2021-07-31 03:31:08 +00:00 committed by Gitee
commit 22e9299c17
22 changed files with 512 additions and 1 deletions

View File

@ -66,6 +66,12 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
GatherGpuFwdKernel, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherGpuFwdKernel, uint, int)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherGpuFwdKernel, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherGpuFwdKernel, uchar, int)

View File

@ -80,6 +80,14 @@ MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
GatherGradGpuKernel, int64_t, uchar)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
GatherGradGpuKernel, int, uint)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
GatherGradGpuKernel, int64_t, uint)
MS_REG_GPU_KERNEL_TWO(
GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
GatherGradGpuKernel, int, bool)

View File

@ -36,6 +36,10 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
GatherNdGpuFwdKernel, short, int) // NOLINT
MS_REG_GPU_KERNEL_TWO(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherNdGpuFwdKernel, uint, int)
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherNdGpuFwdKernel, uchar, int)
@ -60,6 +64,10 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
GatherNdGpuFwdKernel, short, int64_t) // NOLINT
MS_REG_GPU_KERNEL_TWO(
GatherNd,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherNdGpuFwdKernel, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
GatherNd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
GatherNdGpuFwdKernel, uchar, int64_t)

View File

@ -60,6 +60,12 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
GatherV2GpuFwdKernel, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
GatherV2GpuFwdKernel, uint, int)
MS_REG_GPU_KERNEL_TWO(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
GatherV2GpuFwdKernel, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
GatherV2GpuFwdKernel, uint8_t, int)

View File

@ -67,5 +67,19 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
TensorScatterUpdateGpuFwdKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
TensorScatterUpdateGpuFwdKernel, bool, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeBool)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeBool)
.AddOutputAttr(kNumberTypeBool),
TensorScatterUpdateGpuFwdKernel, bool, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -30,5 +30,9 @@ MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutpu
int)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileGpuKernel,
int64_t)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
TileGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileGpuKernel,
bool)
} // namespace kernel
} // namespace mindspore

View File

@ -69,6 +69,10 @@ template void GatherNd<int, int>(int *input, int *indices, int *output, const si
template void GatherNd<short, int>(short *input, int *indices, short *output, const size_t &output_dim0, // NOLINT
const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
int *batch_strides, cudaStream_t stream);
template void GatherNd<unsigned int, int>(unsigned int *input, int *indices, unsigned int *output,
const size_t &output_dim0, const size_t &output_dim1,
const size_t &indices_dim1, int *batch_indices, int *batch_strides,
cudaStream_t stream);
template void GatherNd<unsigned char, int>(unsigned char *input, int *indices, unsigned char *output,
const size_t &output_dim0, const size_t &output_dim1,
const size_t &indices_dim1, int *batch_indices, int *batch_strides,
@ -91,6 +95,10 @@ template void GatherNd<int, int64_t>(int *input, int64_t *indices, int *output,
template void GatherNd<short, int64_t>(short *input, int64_t *indices, short *output, // NOLINT
const size_t &output_dim0, const size_t &output_dim1, const size_t &indices_dim1,
int64_t *batch_indices, int64_t *batch_strides, cudaStream_t stream);
template void GatherNd<unsigned int, int64_t>(unsigned int *input, int64_t *indices, unsigned int *output,
const size_t &output_dim0, const size_t &output_dim1,
const size_t &indices_dim1, int64_t *batch_indices,
int64_t *batch_strides, cudaStream_t stream);
template void GatherNd<unsigned char, int64_t>(unsigned char *input, int64_t *indices, unsigned char *output,
const size_t &output_dim0, const size_t &output_dim1,
const size_t &indices_dim1, int64_t *batch_indices,

View File

@ -72,6 +72,11 @@ template void GatherV2<int8_t, int>(int8_t *input, int *indices, int8_t *output,
size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<uint32_t, int>(uint32_t *input, int *indices, uint32_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<uint32_t, int64_t>(uint32_t *input, int64_t *indices, uint32_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1,
cudaStream_t stream);
template void GatherV2<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0,
size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output, size_t output_dim0,

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/fft_3d_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"
template <typename T>
__global__ static void Split_Complex(const int element_numbers, T *real_part, T *imag_part,
const cufftComplex *complex_element) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
real_part[i] = complex_element[i].x;
imag_part[i] = complex_element[i].y;
}
}
template <typename T>
void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
const cufftHandle &FFT_plan_r2c, cudaStream_t stream) {
cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
cufftExecR2C(FFT_plan_r2c, input_tensor, COMPLEX_FQ);
Split_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, output_real, output_imag, COMPLEX_FQ);
return;
}
template void FFT3D<float>(int Nfft, float *input_tensor, float *complex_fq, float *output_real,
float *output_imag, const cufftHandle &FFT_plan_r2c, cudaStream_t stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_FFT_3D_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_FFT_3D_IMPL_H_
#include <cufft.h>
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
#endif

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/ifft_3d_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"
template <typename T>
__global__ static void Merge_Complex(const int element_numbers, T *real_part, T *imag_part,
cufftComplex *complex_element) {
int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < element_numbers) {
complex_element[i].x = real_part[i];
complex_element[i].y = imag_part[i];
}
}
template <typename T>
void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
const cufftHandle &FFT_plan_c2r, cudaStream_t stream) {
cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
Merge_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, input_real, input_imag, COMPLEX_FQ);
cufftExecC2R(FFT_plan_c2r, COMPLEX_FQ, output_tensor);
return;
}
template void IFFT3D<float>(int Nfft, float *input_real, float *input_imag, float *complex_fq,
float *output_tensor, const cufftHandle &FFT_plan_c2r, cudaStream_t stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_IFFT_3D_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_IFFT_3D_IMPL_H_
#include <cufft.h>
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
#endif

View File

@ -86,6 +86,16 @@ template void TensorScatterUpdate<int, int>(int *input, int *indices, int *updat
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void TensorScatterUpdate<bool, int>(bool *input, int *indices, bool *update, bool *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int *indices_stride, int *work_shape,
cudaStream_t stream);
template void TensorScatterUpdate<bool, int64_t>(bool *input, int64_t *indices, bool *update, bool *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,
const size_t &indices_dim_1, int64_t *indices_stride,
int64_t *work_shape, cudaStream_t stream);
template void TensorScatterUpdate<double, int64_t>(double *input, int64_t *indices, double *update, double *output,
const size_t &block_size, const size_t &input_size,
const size_t &output_size, const size_t &indices_dim_0,

View File

@ -72,3 +72,6 @@ template void CalTile<int>(const size_t output_size, const size_t input_size, co
template void CalTile<int64_t>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const int64_t *input,
int64_t *output, cudaStream_t cuda_stream);
template void CalTile<bool>(const size_t output_size, const size_t input_size, const size_t shape_size,
const size_t *input_shape, const size_t *output_shape, const bool *input,
bool *output, cudaStream_t cuda_stream);

View File

@ -352,6 +352,14 @@ MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
BroadcastOpGpuKernel, int8_t)
// uint32
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
BroadcastOpGpuKernel, uint)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
BroadcastOpGpuKernel, uint)
// uint8
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),

View File

@ -60,6 +60,15 @@ MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOu
FlattenGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
FlattenGpuFwdKernel, int16_t)
// uint32
MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
FlattenGpuFwdKernel, uint)
MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
FlattenGpuFwdKernel, uint)
MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
FlattenGpuFwdKernel, uint)
// uint8
MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
FlattenGpuFwdKernel, uchar)

View File

@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/sponge/pme/fft_3d_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
FFT3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
FFT3DGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_FFT_3D_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_FFT_3D_KERNEL_H_
#include <cuda_runtime_api.h>
#include <cufft.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/fft_3d_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FFT3DGpuKernel : public GpuKernel {
public:
FFT3DGpuKernel() = default;
~FFT3DGpuKernel() override = default;
bool Init(const CNodePtr &kernel_node) override {
fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx"));
ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty"));
fftz = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftz"));
Nall = fftx * ffty * fftz;
Nfft = fftx * ffty * (fftz / 2 + 1);
cufftPlan3d(&FFT_plan_r2c, fftx, ffty, fftz, CUFFT_R2C);
InitSizeLists();
return true;
}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto input_tensor = GetDeviceAddress<T>(inputs, 0);
auto complex_fq = GetDeviceAddress<T>(workspace, 0);
auto output_real = GetDeviceAddress<T>(outputs, 0);
auto output_imag = GetDeviceAddress<T>(outputs, 1);
cufftSetStream(FFT_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
FFT3D<T>(Nfft, input_tensor, complex_fq, output_real, output_imag, FFT_plan_r2c,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(Nall * sizeof(T));
workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
output_size_list_.push_back(Nfft * sizeof(T));
output_size_list_.push_back(Nfft * sizeof(T));
}
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int fftx;
int ffty;
int fftz;
int Nall;
int Nfft;
cufftHandle FFT_plan_r2c;
};
} // namespace kernel
} // namespace mindspore
#endif

View File

@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/sponge/pme/ifft_3d_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
IFFT3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
IFFT3DGpuKernel, float)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,86 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_IFFT_3D_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_IFFT_3D_KERNEL_H_
#include <cuda_runtime_api.h>
#include <cufft.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/ifft_3d_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class IFFT3DGpuKernel : public GpuKernel {
public:
IFFT3DGpuKernel() = default;
~IFFT3DGpuKernel() override = default;
bool Init(const CNodePtr &kernel_node) override {
fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx"));
ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty"));
fftz = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftz"));
Nfft = fftx * ffty * fftz;
Nall = fftx * ffty * (fftz - 1) * 2;
cufftPlan3d(&FFT_plan_c2r, fftx, ffty, (fftz - 1) * 2, CUFFT_C2R);
InitSizeLists();
return true;
}
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto input_real = GetDeviceAddress<T>(inputs, 0);
auto input_imag = GetDeviceAddress<T>(inputs, 1);
auto complex_fq = GetDeviceAddress<T>(workspace, 0);
auto output_tensor = GetDeviceAddress<T>(outputs, 0);
cufftSetStream(FFT_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
IFFT3D<T>(Nfft, input_real, input_imag, complex_fq, output_tensor, FFT_plan_c2r,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(Nfft * sizeof(T));
input_size_list_.push_back(Nfft * sizeof(T));
workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
output_size_list_.push_back(Nall * sizeof(T));
}
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int fftx;
int ffty;
int fftz;
int Nall;
int Nfft;
cufftHandle FFT_plan_c2r;
};
} // namespace kernel
} // namespace mindspore
#endif

View File

@ -109,7 +109,7 @@ from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAto
Dihedral14LJForceWithDirectCF, Dihedral14LJEnergy, Dihedral14LJCFForceWithAtomEnergy,
Dihedral14LJAtomEnergy, Dihedral14CFEnergy, Dihedral14CFAtomEnergy, MDIterationLeapFrog,
GetCenterOfGeometry, MDTemperature, NeighborListUpdate, MDIterationLeapFrogLiujian,
CrdToUintCrd, MDIterationSetupRandState, TransferCrd)
CrdToUintCrd, MDIterationSetupRandState, TransferCrd, FFT3D, IFFT3D)
__all__ = [
'Unique',
@ -484,6 +484,8 @@ __all__ = [
"TensorScatterMin",
"TensorScatterSub",
"SoftShrink",
"FFT3D",
"IFFT3D"
]
__all__.sort()

View File

@ -3279,3 +3279,71 @@ class TransferCrd(PrimitiveWithInfer):
validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name)
return mstype.float32, mstype.float32, mstype.float32, mstype.int32
class FFT3D(PrimitiveWithInfer):
"""
Forward FFT with Three-Dimensional Input.
Inputs:
- **input_tensor** (Tensor, float32) - [fftx, ffty, fftz]
Outputs:
- **output_real** (float32)
- **output_imag** (float32)
Supported Platforms:
``GPU``
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(
inputs=['input_tensor'],
outputs=['output_real', 'output_imag'])
def infer_shape(self, input_shape):
self.add_prim_attr('fftx', input_shape[0])
self.add_prim_attr('ffty', input_shape[1])
self.add_prim_attr('fftz', input_shape[2])
return [input_shape[0], input_shape[1], int(input_shape[2]/2)+1],\
[input_shape[0], input_shape[1], int(input_shape[2]/2)+1]
def infer_dtype(self, input_dtype):
validator.check_tensor_dtype_valid('input_tensor', input_dtype, mstype.number_type, self.name)
return input_dtype, input_dtype
class IFFT3D(PrimitiveWithInfer):
"""
Inverse FFT with Three-Dimensional Input.
Inputs:
- **input_real** (Tensor, float32) - [fftx, ffty, fftz]
- **input_imag** (Tensor, float32) - [fftx, ffty, fftz]
Outputs:
- **output_tensor** (float32)
Supported Platforms:
``GPU``
"""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(
inputs=['input_real', 'input_imag'],
outputs=['output_tensor'])
def infer_shape(self, input_shape1, input_shape2):
for i in range(len(input_shape1)):
validator.check_int(input_shape1[i], input_shape2[i], Rel.EQ, "input_shape", self.name)
self.add_prim_attr('fftx', input_shape1[0])
self.add_prim_attr('ffty', input_shape1[1])
self.add_prim_attr('fftz', input_shape1[2])
return [input_shape1[0], input_shape1[1], (input_shape1[2]-1)*2]
def infer_dtype(self, input_real_dtype, input_imag_dtype):
validator.check_tensor_dtype_valid('input_real', input_real_dtype, mstype.number_type, self.name)
validator.check_tensor_dtype_valid('input_imag', input_imag_dtype, mstype.number_type, self.name)
return input_real_dtype