From e3820ad2cb96414252ef3fbb814d78335384328d Mon Sep 17 00:00:00 2001 From: TFBunny Date: Tue, 13 Apr 2021 14:35:47 -0400 Subject: [PATCH] add tensor scatter update support different type inputs --- .../tensor_scatter_update_gpu_kernel.cc | 18 ++++++-- .../gpu/cuda_impl/tensor_scatter_update.cu | 26 ++++++----- .../gpu/cuda_impl/tensor_scatter_update.cuh | 4 +- .../st/ops/gpu/test_tensor_scatter_update.py | 45 +++++++++++++++++++ 4 files changed, 76 insertions(+), 17 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc index 6412c834a85..c8aa5fed9fc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/tensor_scatter_update_gpu_kernel.cc @@ -25,7 +25,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, .AddInputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16), TensorScatterUpdateGpuFwdKernel, half, int) - MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -33,7 +32,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, .AddInputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), TensorScatterUpdateGpuFwdKernel, float, int) - MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, KernelAttr() .AddInputAttr(kNumberTypeFloat64) @@ -41,7 +39,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, .AddInputAttr(kNumberTypeFloat64) .AddOutputAttr(kNumberTypeFloat64), TensorScatterUpdateGpuFwdKernel, double, int) - MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, KernelAttr() .AddInputAttr(kNumberTypeInt8) @@ -49,7 +46,6 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, .AddInputAttr(kNumberTypeInt8) .AddOutputAttr(kNumberTypeInt8), TensorScatterUpdateGpuFwdKernel, char, int) - MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, KernelAttr() .AddInputAttr(kNumberTypeUInt8) @@ -57,5 +53,19 @@ MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, .AddInputAttr(kNumberTypeUInt8) .AddOutputAttr(kNumberTypeUInt8), TensorScatterUpdateGpuFwdKernel, uchar, int) +MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, + KernelAttr() + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeInt32), + TensorScatterUpdateGpuFwdKernel, int, int) +MS_REG_GPU_KERNEL_TWO(TensorScatterUpdate, + KernelAttr() + .AddInputAttr(kNumberTypeFloat64) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeFloat64) + .AddOutputAttr(kNumberTypeFloat64), + TensorScatterUpdateGpuFwdKernel, double, int64_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu index e322e579f99..2f3d6fa627c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cu @@ -48,12 +48,11 @@ __global__ void TensorScatterUpdateKernel(T *input, S *indices, T *update, T *ou template void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, - const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, - S *work_shape, cudaStream_t stream) { - TensorScatterUpdateKernel<<>>(input, indices, update, output, - block_size, input_size, output_size, - indices_dim_0, indices_dim_1, - indices_stride, work_shape); + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, + S *indices_stride, S *work_shape, cudaStream_t stream) { + TensorScatterUpdateKernel<<>>( + input, indices, update, output, block_size, input_size, output_size, indices_dim_0, indices_dim_1, indices_stride, + work_shape); return; } @@ -77,13 +76,18 @@ template void TensorScatterUpdate(char *input, int *indices, char *up const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, cudaStream_t stream); -template void TensorScatterUpdate(int *input, int *indices, int *update, int *output, - const size_t &block_size, const size_t &input_size, - const size_t &output_size, const size_t &indices_dim_0, - const size_t &indices_dim_1, int *indices_stride, int *work_shape, - cudaStream_t stream); template void TensorScatterUpdate(unsigned char *input, int *indices, unsigned char *update, unsigned char *output, const size_t &block_size, const size_t &input_size, const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, int *indices_stride, int *work_shape, cudaStream_t stream); +template void TensorScatterUpdate(int *input, int *indices, int *update, int *output, + const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int *indices_stride, int *work_shape, + cudaStream_t stream); +template void TensorScatterUpdate(double *input, int64_t *indices, double *update, double *output, + const size_t &block_size, const size_t &input_size, + const size_t &output_size, const size_t &indices_dim_0, + const size_t &indices_dim_1, int64_t *indices_stride, + int64_t *work_shape, cudaStream_t stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh index c1dd9976d29..78f29c3e1eb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/tensor_scatter_update.cuh @@ -21,6 +21,6 @@ template void TensorScatterUpdate(T *input, S *indices, T *update, T *output, const size_t &block_size, const size_t &input_size, - const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, S *indices_stride, - S *work_shape, cudaStream_t stream); + const size_t &output_size, const size_t &indices_dim_0, const size_t &indices_dim_1, + S *indices_stride, S *work_shape, cudaStream_t stream); #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_TENSOR_SCATTER_UPDATE_IMPL_CUH diff --git a/tests/st/ops/gpu/test_tensor_scatter_update.py b/tests/st/ops/gpu/test_tensor_scatter_update.py index 333a20f1044..06abf26d4f9 100644 --- a/tests/st/ops/gpu/test_tensor_scatter_update.py +++ b/tests/st/ops/gpu/test_tensor_scatter_update.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import pytest import numpy as np import mindspore.context as context import mindspore.nn as nn @@ -32,6 +33,10 @@ def scatter_net(x, indices, update): scatter = Net() return scatter(Tensor(x), Tensor(indices), Tensor(update)).asnumpy() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard def test_scatter(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") @@ -77,3 +82,43 @@ def test_scatter(): [15, 99, 17, 44, 19], [100, 21, 22, 23, 55]]).astype(np.float32) np.testing.assert_allclose(out, expected, rtol=1e-6) + + arr_input = np.arange(25).reshape(5, 5).astype(np.float64) + arr_indices = np.array([[[0, 0], + [1, 1], + [2, 2], + [3, 3], + [4, 4]], + [[0, 4], + [1, 3], + [2, 2], + [3, 1], + [4, 0]]]).astype(np.int64) + arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.float64) + out = scatter_net(arr_input, arr_indices, arr_update) + expected = np.array([[11, 1, 2, 3, 66], + [5, 22, 7, 77, 9], + [10, 11, 33, 13, 14], + [15, 99, 17, 44, 19], + [100, 21, 22, 23, 55]]).astype(np.float64) + np.testing.assert_allclose(out, expected, rtol=1e-6) + + arr_input = np.arange(25).reshape(5, 5).astype(np.int32) + arr_indices = np.array([[[0, 0], + [1, 1], + [2, 2], + [3, 3], + [4, 4]], + [[0, 4], + [1, 3], + [2, 2], + [3, 1], + [4, 0]]]).astype(np.int32) + arr_update = np.array([[11, 22, 33, 44, 55], [66, 77, 33, 99, 100]]).astype(np.int32) + out = scatter_net(arr_input, arr_indices, arr_update) + expected = np.array([[11, 1, 2, 3, 66], + [5, 22, 7, 77, 9], + [10, 11, 33, 13, 14], + [15, 99, 17, 44, 19], + [100, 21, 22, 23, 55]]).astype(np.int32) + np.testing.assert_allclose(out, expected, rtol=1e-6)