forked from mindspore-Ecosystem/mindspore
!20885 add dtypes & fft kernels for SPONGE
Merge pull request !20885 from huangmengxi/sponge_ccsrc
Commit: 22e9299c17

@@ -66,6 +66,12 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  GatherGpuFwdKernel, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  GatherGpuFwdKernel, uint, int)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  GatherGpuFwdKernel, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherD, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
  GatherGpuFwdKernel, uchar, int)

@@ -80,6 +80,14 @@ MS_REG_GPU_KERNEL_TWO(
  GatherDGrad,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
  GatherGradGpuKernel, int64_t, uchar)
MS_REG_GPU_KERNEL_TWO(
  GatherDGrad,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
  GatherGradGpuKernel, int, uint)
MS_REG_GPU_KERNEL_TWO(
  GatherDGrad,
  KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
  GatherGradGpuKernel, int64_t, uint)
MS_REG_GPU_KERNEL_TWO(
  GatherDGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
  GatherGradGpuKernel, int, bool)

@@ -36,6 +36,10 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
  GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
  GatherNdGpuFwdKernel, short, int)  // NOLINT
MS_REG_GPU_KERNEL_TWO(
  GatherNd,
  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  GatherNdGpuFwdKernel, uint, int)
MS_REG_GPU_KERNEL_TWO(
  GatherNd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
  GatherNdGpuFwdKernel, uchar, int)
@@ -60,6 +64,10 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
  GatherNd, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
  GatherNdGpuFwdKernel, short, int64_t)  // NOLINT
MS_REG_GPU_KERNEL_TWO(
  GatherNd,
  KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  GatherNdGpuFwdKernel, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
  GatherNd, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
  GatherNdGpuFwdKernel, uchar, int64_t)

@@ -60,6 +60,12 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO(
  Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
  GatherV2GpuFwdKernel, int8_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
  GatherV2GpuFwdKernel, uint, int)
MS_REG_GPU_KERNEL_TWO(
  Gather, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
  GatherV2GpuFwdKernel, uint, int64_t)
MS_REG_GPU_KERNEL_TWO(
  Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
  GatherV2GpuFwdKernel, uint8_t, int)

@@ -67,5 +67,19 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddOutputAttr(kNumberTypeFloat64),
                      TensorScatterUpdateGpuFwdKernel, double, int64_t)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeBool)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeBool)
                        .AddOutputAttr(kNumberTypeBool),
                      TensorScatterUpdateGpuFwdKernel, bool, int)
MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeBool)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeBool)
                        .AddOutputAttr(kNumberTypeBool),
                      TensorScatterUpdateGpuFwdKernel, bool, int64_t)
} // namespace kernel
} // namespace mindspore
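
The new bool registrations can be exercised end to end from Python; a minimal sketch (assuming the existing ops.TensorScatterUpdate front end with the usual (input_x, indices, updates) signature; shapes are illustrative):

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    scatter = ops.TensorScatterUpdate()
    x = ms.Tensor(np.zeros(4, np.bool_))
    indices = ms.Tensor(np.array([[1], [3]], np.int32))
    updates = ms.Tensor(np.array([True, True]))
    y = scatter(x, indices, updates)  # [False, True, False, True]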
@@ -30,5 +30,9 @@ MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutpu
                      int)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), TileGpuKernel,
                      int64_t)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      TileGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileGpuKernel,
                      bool)
} // namespace kernel
} // namespace mindspore

@@ -69,6 +69,10 @@ template void GatherNd<int, int>(int *input, int *indices, int *output, const si
template void GatherNd<short, int>(short *input, int *indices, short *output, const size_t &output_dim0,  // NOLINT
                                   const size_t &output_dim1, const size_t &indices_dim1, int *batch_indices,
                                   int *batch_strides, cudaStream_t stream);
template void GatherNd<unsigned int, int>(unsigned int *input, int *indices, unsigned int *output,
                                          const size_t &output_dim0, const size_t &output_dim1,
                                          const size_t &indices_dim1, int *batch_indices, int *batch_strides,
                                          cudaStream_t stream);
template void GatherNd<unsigned char, int>(unsigned char *input, int *indices, unsigned char *output,
                                           const size_t &output_dim0, const size_t &output_dim1,
                                           const size_t &indices_dim1, int *batch_indices, int *batch_strides,
@@ -91,6 +95,10 @@ template void GatherNd<int, int64_t>(int *input, int64_t *indices, int *output,
template void GatherNd<short, int64_t>(short *input, int64_t *indices, short *output,  // NOLINT
                                       const size_t &output_dim0, const size_t &output_dim1, const size_t &indices_dim1,
                                       int64_t *batch_indices, int64_t *batch_strides, cudaStream_t stream);
template void GatherNd<unsigned int, int64_t>(unsigned int *input, int64_t *indices, unsigned int *output,
                                              const size_t &output_dim0, const size_t &output_dim1,
                                              const size_t &indices_dim1, int64_t *batch_indices,
                                              int64_t *batch_strides, cudaStream_t stream);
template void GatherNd<unsigned char, int64_t>(unsigned char *input, int64_t *indices, unsigned char *output,
                                               const size_t &output_dim0, const size_t &output_dim1,
                                               const size_t &indices_dim1, int64_t *batch_indices,

@@ -72,6 +72,11 @@ template void GatherV2<int8_t, int>(int8_t *input, int *indices, int8_t *output,
                                    size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<int8_t, int64_t>(int8_t *input, int64_t *indices, int8_t *output, size_t output_dim0,
                                        size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<uint32_t, int>(uint32_t *input, int *indices, uint32_t *output, size_t output_dim0,
                                      size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<uint32_t, int64_t>(uint32_t *input, int64_t *indices, uint32_t *output, size_t output_dim0,
                                          size_t output_dim1, size_t output_dim2, size_t input_dim1,
                                          cudaStream_t stream);
template void GatherV2<uint8_t, int>(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0,
                                     size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream);
template void GatherV2<uint8_t, int64_t>(uint8_t *input, int64_t *indices, uint8_t *output, size_t output_dim0,

@@ -0,0 +1,39 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/fft_3d_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"

template <typename T>
__global__ static void Split_Complex(const int element_numbers, T *real_part, T *imag_part,
                                     const cufftComplex *complex_element) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i < element_numbers) {
    real_part[i] = complex_element[i].x;
    imag_part[i] = complex_element[i].y;
  }
}

template <typename T>
void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
           const cufftHandle &FFT_plan_r2c, cudaStream_t stream) {
  cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
  cufftExecR2C(FFT_plan_r2c, input_tensor, COMPLEX_FQ);
  Split_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, output_real, output_imag, COMPLEX_FQ);
  return;
}

template void FFT3D<float>(int Nfft, float *input_tensor, float *complex_fq, float *output_real,
                           float *output_imag, const cufftHandle &FFT_plan_r2c, cudaStream_t stream);
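
For orientation: cufftExecR2C is an unnormalized real-to-complex 3D FFT, and Split_Complex simply unpacks the interleaved cufftComplex buffer into two real tensors. A NumPy sketch of the same math (hypothetical sizes; row-major [fftx, ffty, fftz] layout):

    import numpy as np

    x = np.random.rand(4, 4, 8).astype(np.float32)
    fq = np.fft.rfftn(x)           # complex, shape (4, 4, 8 // 2 + 1); NumPy's forward FFT is also unnormalized
    real, imag = fq.real, fq.imag  # the two buffers Split_Complex fills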
@@ -0,0 +1,26 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_FFT_3D_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_FFT_3D_IMPL_H_

#include <cufft.h>
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void FFT3D(int Nfft, T *input_tensor, T *complex_fq, T *output_real, T *output_imag,
           const cufftHandle &FFT_plan_r2c, cudaStream_t stream);

#endif

@@ -0,0 +1,39 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/ifft_3d_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/pme_common.cuh"

template <typename T>
__global__ static void Merge_Complex(const int element_numbers, T *real_part, T *imag_part,
                                     cufftComplex *complex_element) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i < element_numbers) {
    complex_element[i].x = real_part[i];
    complex_element[i].y = imag_part[i];
  }
}

template <typename T>
void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
            const cufftHandle &FFT_plan_c2r, cudaStream_t stream) {
  cufftComplex *COMPLEX_FQ = reinterpret_cast<cufftComplex *>(complex_fq);
  Merge_Complex<T><<<Nfft / 1024 + 1, 1024, 0, stream>>>(Nfft, input_real, input_imag, COMPLEX_FQ);
  cufftExecC2R(FFT_plan_c2r, COMPLEX_FQ, output_tensor);
  return;
}

template void IFFT3D<float>(int Nfft, float *input_real, float *input_imag, float *complex_fq,
                            float *output_tensor, const cufftHandle &FFT_plan_c2r, cudaStream_t stream);
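
The inverse path packs the two real tensors back into cufftComplex and runs C2R. One caveat worth noting: cuFFT's C2R transform is unnormalized, whereas np.fft.irfftn divides by the element count, so a NumPy sketch of this kernel needs an explicit scale (hypothetical sizes):

    import numpy as np

    nx, ny, nz = 4, 4, 8  # real-space dims; the complex z-dim is nz // 2 + 1
    fq = np.fft.rfftn(np.random.rand(nx, ny, nz).astype(np.float32))
    y = np.fft.irfftn(fq, s=(nx, ny, nz)) * (nx * ny * nz)  # matches the cufftExecC2R output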
@@ -0,0 +1,26 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_IFFT_3D_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPONGE_PME_IFFT_3D_IMPL_H_

#include <cufft.h>
#include "runtime/device/gpu/cuda_common.h"

template <typename T>
void IFFT3D(int Nfft, T *input_real, T *input_imag, T *complex_fq, T *output_tensor,
            const cufftHandle &FFT_plan_c2r, cudaStream_t stream);

#endif

@@ -86,6 +86,16 @@ template void TensorScatterUpdate<int, int>(int *input, int *indices, int *updat
                                            const size_t &output_size, const size_t &indices_dim_0,
                                            const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                            cudaStream_t stream);
template void TensorScatterUpdate<bool, int>(bool *input, int *indices, bool *update, bool *output,
                                             const size_t &block_size, const size_t &input_size,
                                             const size_t &output_size, const size_t &indices_dim_0,
                                             const size_t &indices_dim_1, int *indices_stride, int *work_shape,
                                             cudaStream_t stream);
template void TensorScatterUpdate<bool, int64_t>(bool *input, int64_t *indices, bool *update, bool *output,
                                                 const size_t &block_size, const size_t &input_size,
                                                 const size_t &output_size, const size_t &indices_dim_0,
                                                 const size_t &indices_dim_1, int64_t *indices_stride,
                                                 int64_t *work_shape, cudaStream_t stream);
template void TensorScatterUpdate<double, int64_t>(double *input, int64_t *indices, double *update, double *output,
                                                   const size_t &block_size, const size_t &input_size,
                                                   const size_t &output_size, const size_t &indices_dim_0,

@@ -72,3 +72,6 @@ template void CalTile<int>(const size_t output_size, const size_t input_size, co
template void CalTile<int64_t>(const size_t output_size, const size_t input_size, const size_t shape_size,
                               const size_t *input_shape, const size_t *output_shape, const int64_t *input,
                               int64_t *output, cudaStream_t cuda_stream);
template void CalTile<bool>(const size_t output_size, const size_t input_size, const size_t shape_size,
                            const size_t *input_shape, const size_t *output_shape, const bool *input,
                            bool *output, cudaStream_t cuda_stream);

@@ -352,6 +352,14 @@ MS_REG_GPU_KERNEL_ONE(
  Mul, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
  BroadcastOpGpuKernel, int8_t)

// uint32
MS_REG_GPU_KERNEL_ONE(
  Sub, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
  BroadcastOpGpuKernel, uint)
MS_REG_GPU_KERNEL_ONE(
  Mul, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
  BroadcastOpGpuKernel, uint)

// uint8
MS_REG_GPU_KERNEL_ONE(
  DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),

@@ -60,6 +60,15 @@ MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOu
                      FlattenGpuFwdKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      FlattenGpuFwdKernel, int16_t)

// uint32
MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      FlattenGpuFwdKernel, uint)
MS_REG_GPU_KERNEL_ONE(Reshape, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      FlattenGpuFwdKernel, uint)
MS_REG_GPU_KERNEL_ONE(ExpandDims, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      FlattenGpuFwdKernel, uint)

// uint8
MS_REG_GPU_KERNEL_ONE(Flatten, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                      FlattenGpuFwdKernel, uchar)

@@ -0,0 +1,25 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/sponge/pme/fft_3d_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  FFT3D,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  FFT3DGpuKernel, float)
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,86 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_FFT_3D_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_FFT_3D_KERNEL_H_
#include <cuda_runtime_api.h>
#include <cufft.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/fft_3d_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class FFT3DGpuKernel : public GpuKernel {
 public:
  FFT3DGpuKernel() = default;
  ~FFT3DGpuKernel() override = default;

  bool Init(const CNodePtr &kernel_node) override {
    fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx"));
    ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty"));
    fftz = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftz"));
    Nall = fftx * ffty * fftz;
    Nfft = fftx * ffty * (fftz / 2 + 1);

    cufftPlan3d(&FFT_plan_r2c, fftx, ffty, fftz, CUFFT_R2C);

    InitSizeLists();
    return true;
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto input_tensor = GetDeviceAddress<T>(inputs, 0);
    auto complex_fq = GetDeviceAddress<T>(workspace, 0);
    auto output_real = GetDeviceAddress<T>(outputs, 0);
    auto output_imag = GetDeviceAddress<T>(outputs, 1);

    cufftSetStream(FFT_plan_r2c, reinterpret_cast<cudaStream_t>(stream_ptr));
    FFT3D<T>(Nfft, input_tensor, complex_fq, output_real, output_imag, FFT_plan_r2c,
             reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(Nall * sizeof(T));
    workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
    output_size_list_.push_back(Nfft * sizeof(T));
    output_size_list_.push_back(Nfft * sizeof(T));
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  int fftx;
  int ffty;
  int fftz;
  int Nall;
  int Nfft;
  cufftHandle FFT_plan_r2c;
};
} // namespace kernel
} // namespace mindspore
#endif
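
The size bookkeeping above, spelled out with hypothetical attribute values (Python arithmetic only):

    fftx, ffty, fftz = 4, 4, 8
    Nall = fftx * ffty * fftz             # 128 real input floats
    Nfft = fftx * ffty * (fftz // 2 + 1)  # 80 complex bins from the R2C transform
    # workspace holds Nfft cufftComplex values = Nfft * 2 floats;
    # each of the two outputs (real part, imaginary part) holds Nfft floats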
@@ -0,0 +1,25 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/kernel_compiler/gpu/sponge/pme/ifft_3d_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  IFFT3D,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  IFFT3DGpuKernel, float)
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,86 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_IFFT_3D_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPONGE_PME_IFFT_3D_KERNEL_H_
#include <cuda_runtime_api.h>
#include <cufft.h>
#include <vector>
#include <string>
#include <map>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/cuda_common.h"
#include "backend/kernel_compiler/gpu/cuda_impl/sponge/pme/ifft_3d_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class IFFT3DGpuKernel : public GpuKernel {
 public:
  IFFT3DGpuKernel() = default;
  ~IFFT3DGpuKernel() override = default;

  bool Init(const CNodePtr &kernel_node) override {
    fftx = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftx"));
    ffty = static_cast<int>(GetAttr<int64_t>(kernel_node, "ffty"));
    fftz = static_cast<int>(GetAttr<int64_t>(kernel_node, "fftz"));
    Nfft = fftx * ffty * fftz;
    Nall = fftx * ffty * (fftz - 1) * 2;

    cufftPlan3d(&FFT_plan_c2r, fftx, ffty, (fftz - 1) * 2, CUFFT_C2R);

    InitSizeLists();
    return true;
  }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    auto input_real = GetDeviceAddress<T>(inputs, 0);
    auto input_imag = GetDeviceAddress<T>(inputs, 1);
    auto complex_fq = GetDeviceAddress<T>(workspace, 0);
    auto output_tensor = GetDeviceAddress<T>(outputs, 0);

    cufftSetStream(FFT_plan_c2r, reinterpret_cast<cudaStream_t>(stream_ptr));
    IFFT3D<T>(Nfft, input_real, input_imag, complex_fq, output_tensor, FFT_plan_c2r,
              reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(Nfft * sizeof(T));
    input_size_list_.push_back(Nfft * sizeof(T));
    workspace_size_list_.push_back(Nfft * 2 * sizeof(T));
    output_size_list_.push_back(Nall * sizeof(T));
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  int fftx;
  int ffty;
  int fftz;
  int Nall;
  int Nfft;
  cufftHandle FFT_plan_c2r;
};
} // namespace kernel
} // namespace mindspore
#endif
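
Note the dimension convention in this header: unlike the forward kernel, the "fftz" attribute here counts complex bins, so the real-space grid passed to cufftPlan3d is (fftz - 1) * 2. Spelled out with hypothetical values:

    fftx, ffty, fftz = 4, 4, 5            # fftz is the complex z-dimension here
    Nfft = fftx * ffty * fftz             # 80 complex bins in
    Nall = fftx * ffty * (fftz - 1) * 2   # 128 real samples out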
@@ -109,7 +109,7 @@ from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAto
                         Dihedral14LJForceWithDirectCF, Dihedral14LJEnergy, Dihedral14LJCFForceWithAtomEnergy,
                         Dihedral14LJAtomEnergy, Dihedral14CFEnergy, Dihedral14CFAtomEnergy, MDIterationLeapFrog,
                         GetCenterOfGeometry, MDTemperature, NeighborListUpdate, MDIterationLeapFrogLiujian,
-                        CrdToUintCrd, MDIterationSetupRandState, TransferCrd)
+                        CrdToUintCrd, MDIterationSetupRandState, TransferCrd, FFT3D, IFFT3D)

__all__ = [
    'Unique',

@@ -484,6 +484,8 @@ __all__ = [
    "TensorScatterMin",
    "TensorScatterSub",
    "SoftShrink",
    "FFT3D",
    "IFFT3D"
]

__all__.sort()

@@ -3279,3 +3279,71 @@ class TransferCrd(PrimitiveWithInfer):
        validator.check_tensor_dtype_valid('old_crd', old_crd_dtype, [mstype.float32], self.name)
        validator.check_tensor_dtype_valid('box', box_dtype, [mstype.float32], self.name)
        return mstype.float32, mstype.float32, mstype.float32, mstype.int32


class FFT3D(PrimitiveWithInfer):
    """
    Forward FFT with Three-Dimensional Input.

    Inputs:
        - **input_tensor** (Tensor, float32) - [fftx, ffty, fftz]

    Outputs:
        - **output_real** (float32)
        - **output_imag** (float32)

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_tensor'],
            outputs=['output_real', 'output_imag'])

    def infer_shape(self, input_shape):
        self.add_prim_attr('fftx', input_shape[0])
        self.add_prim_attr('ffty', input_shape[1])
        self.add_prim_attr('fftz', input_shape[2])
        return [input_shape[0], input_shape[1], int(input_shape[2]/2)+1],\
               [input_shape[0], input_shape[1], int(input_shape[2]/2)+1]

    def infer_dtype(self, input_dtype):
        validator.check_tensor_dtype_valid('input_tensor', input_dtype, mstype.number_type, self.name)
        return input_dtype, input_dtype


class IFFT3D(PrimitiveWithInfer):
    """
    Inverse FFT with Three-Dimensional Input.

    Inputs:
        - **input_real** (Tensor, float32) - [fftx, ffty, fftz]
        - **input_imag** (Tensor, float32) - [fftx, ffty, fftz]

    Outputs:
        - **output_tensor** (float32)

    Supported Platforms:
        ``GPU``
    """

    @prim_attr_register
    def __init__(self):
        self.init_prim_io_names(
            inputs=['input_real', 'input_imag'],
            outputs=['output_tensor'])

    def infer_shape(self, input_shape1, input_shape2):
        for i in range(len(input_shape1)):
            validator.check_int(input_shape1[i], input_shape2[i], Rel.EQ, "input_shape", self.name)
        self.add_prim_attr('fftx', input_shape1[0])
        self.add_prim_attr('ffty', input_shape1[1])
        self.add_prim_attr('fftz', input_shape1[2])
        return [input_shape1[0], input_shape1[1], (input_shape1[2]-1)*2]

    def infer_dtype(self, input_real_dtype, input_imag_dtype):
        validator.check_tensor_dtype_valid('input_real', input_real_dtype, mstype.number_type, self.name)
        validator.check_tensor_dtype_valid('input_imag', input_imag_dtype, mstype.number_type, self.name)
        return input_real_dtype
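
A minimal usage sketch for the two new primitives (assuming they are re-exported through mindspore.ops, as the import and __all__ changes above suggest; shapes are illustrative):

    import numpy as np
    import mindspore as ms
    import mindspore.ops as ops

    fft3d = ops.FFT3D()
    x = ms.Tensor(np.random.rand(4, 4, 8).astype(np.float32))
    real, imag = fft3d(x)   # each output has shape (4, 4, 8 // 2 + 1) == (4, 4, 5)

    ifft3d = ops.IFFT3D()
    y = ifft3d(real, imag)  # infer_shape reconstructs (4, 4, (5 - 1) * 2) == (4, 4, 8)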