# Patch summary (commentary only; text ahead of the first "diff --git" header is not part of any hunk).
#
# Files touched, as listed by the diff headers below:
#   1. mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc
#      Adds eight MS_REG_GPU_KERNEL_TWO registrations binding the Gather op to GatherV2GpuFwdKernel
#      for Int32, Int16, Int8 and UInt8 data, each paired with Int32 and Int64 index attributes.
#   2. mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu
#      Adds eight "template void GatherV2(...)" explicit instantiations for the same
#      (int, int16_t, int8_t, uint8_t) x (int, int64_t) input/index pointer pairs.
#      NOTE(review): the instantiations appear here without explicit template-argument lists;
#      confirm against the upstream file whether that is intentional or an extraction artifact.
#   3. mindspore/ops/_grad/grad_math_ops.py
#      Registers a bprop getter for P.IndexAdd. The inner bprop returns, in order:
#      dout, zeros_like(indices), and gather(dout, indices, _axis) where _axis = self.axis.
#   4. tests/st/ops/gpu/test_gatherV2_op.py
#      Adds test_gather1_int32 / _int16 / _int8 / _uint8. Each builds x = arange(2*3*4*5) reshaped
#      to (2, 3, 4, 5), uses int32 indices [1, 3, 4], runs GatherNet1 in GRAPH_MODE on GPU, and
#      checks against an expected array holding positions 1, 3, 4 of every last-axis row.
#      NOTE(review): these integer tests compare with a float tolerance (1.0e-6) via
#      "diff < error" and "-diff < error"; for uint8 the subtraction/negation presumably wraps in
#      unsigned arithmetic. np.testing.assert_array_equal (used by the IndexAdd grad tests in this
#      same patch) would state the exact-equality intent directly -- confirm and consider aligning.
#      NOTE(review): the same expected array is repeated in all four tests; a shared helper taking
#      the numpy dtype (like index_add_grad_with_type below) would remove the duplication.
#   5. tests/st/ops/gpu/test_index_add_op.py
#      Adds "import mindspore" and "from mindspore.ops import composite as C", the IndexAddGradNet
#      cell (C.GradOperation with get_all=True, sens_param=True), the helper
#      index_add_grad_with_type(nptype), and seven tests (float64, float32, float16, int32, int16,
#      int8, uint8) that each run the helper in GRAPH_MODE and then PYNATIVE_MODE on GPU.
#
# NOTE(review): line breaks inside the hunks below appear to have been collapsed; the hunk text is
# left untouched so that hunk headers and content stay exactly as received.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc index daad7939e63..7ff7f1d13be 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc @@ -49,6 +49,38 @@ MS_REG_GPU_KERNEL_TWO( KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int64_t) +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + GatherV2GpuFwdKernel, int, int) + +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + GatherV2GpuFwdKernel, int, int64_t) + +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), + GatherV2GpuFwdKernel, int16_t, int) + +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), + GatherV2GpuFwdKernel, int16_t, int64_t) + +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), + GatherV2GpuFwdKernel, int8_t, int) + +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), + GatherV2GpuFwdKernel, int8_t, int64_t) + +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + GatherV2GpuFwdKernel, uint8_t, int) + +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), + GatherV2GpuFwdKernel, uint8_t, 
int64_t) + MS_REG_GPU_KERNEL_TWO(Gather, KernelAttr() .AddInputAttr(kNumberTypeFloat32) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu index a2469dd0a59..c103cde474e 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu @@ -59,3 +59,21 @@ template void GatherV2(double *input, int *indices, double *output, size_t output_dim2, size_t input_dim1, cudaStream_t stream); template void GatherV2(double *input, int64_t *indices, double *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(int *input, int *indices, int *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(int *input, int64_t *indices, int *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(int16_t *input, int *indices, int16_t *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(int16_t *input, int64_t *indices, int16_t *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, + cudaStream_t stream); +template void GatherV2(int8_t *input, int *indices, int8_t *output, size_t output_dim0, size_t output_dim1, + size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(int8_t *input, int64_t *indices, int8_t *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(uint8_t *input, int *indices, uint8_t *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(uint8_t *input, int64_t 
*indices, uint8_t *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, + cudaStream_t stream); diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 9ac29d2f80e..bdac2662745 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -1300,3 +1300,15 @@ def get_bprop_lin_space(self): return zeros_like(start), zeros_like(stop), zeros_like(num) return bprop + + +@bprop_getters.register(P.IndexAdd) +def get_bprop_index_add(self): + """Generate bprop for IndexAdd""" + gather = P.Gather() + _axis = self.axis + + def bprop(input_x, indices, input_y, out, dout): + return dout, zeros_like(indices), gather(dout, indices, _axis) + + return bprop diff --git a/tests/st/ops/gpu/test_gatherV2_op.py b/tests/st/ops/gpu/test_gatherV2_op.py index dc9a758bcb1..73e2ff37133 100644 --- a/tests/st/ops/gpu/test_gatherV2_op.py +++ b/tests/st/ops/gpu/test_gatherV2_op.py @@ -1178,3 +1178,183 @@ def test_gather1_float64(): diff = output.asnumpy() - expect assert np.all(diff < error) assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1_int32(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int32).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]).astype(np.int32) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet1() + output = gather(x, 
indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1_int16(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int16).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]).astype(np.int16) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet1() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1_int8(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.int8).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]).astype(np.int8) + + 
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet1() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1_uint8(): + x = Tensor(np.arange(2 * 3 * 4 * 5, dtype=np.uint8).reshape(2, 3, 4, 5)) + indices = Tensor(np.array([1, 3, 4], dtype='i4')) + expect = np.array([[[[1., 3., 4.], + [6., 8., 9.], + [11., 13., 14.], + [16., 18., 19.]], + + [[21., 23., 24.], + [26., 28., 29.], + [31., 33., 34.], + [36., 38., 39.]], + + [[41., 43., 44.], + [46., 48., 49.], + [51., 53., 54.], + [56., 58., 59.]]], + + [[[61., 63., 64.], + [66., 68., 69.], + [71., 73., 74.], + [76., 78., 79.]], + + [[81., 83., 84.], + [86., 88., 89.], + [91., 93., 94.], + [96., 98., 99.]], + + [[101., 103., 104.], + [106., 108., 109.], + [111., 113., 114.], + [116., 118., 119.]]]]).astype(np.uint8) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet1() + output = gather(x, indices) + error = np.ones(shape=output.asnumpy().shape) * 1.0e-6 + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) diff --git a/tests/st/ops/gpu/test_index_add_op.py b/tests/st/ops/gpu/test_index_add_op.py index fe2343edbfa..e706edc0cd8 100644 --- a/tests/st/ops/gpu/test_index_add_op.py +++ b/tests/st/ops/gpu/test_index_add_op.py @@ -16,10 +16,12 @@ import numpy as np import pytest +import mindspore import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import composite as C class NetIndexAdd(nn.Cell): @@ -257,3 +259,110 @@ def test_index_add_invalid_inputs(): net = NetIndexAdd(1) _ = net(Tensor(x), Tensor(idx), Tensor(y)) assert "out of range" in str(info.value) + + +class 
IndexAddGradNet(nn.Cell): + def __init__(self, network): + super(IndexAddGradNet, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=True) + self.network = network + + def construct(self, x, idx, y, dout): + out = self.grad(self.network)(x, idx, y, dout) + return out + + +def index_add_grad_with_type(nptype): + net = NetIndexAdd(1) + grad_net = IndexAddGradNet(net) + x = Tensor(np.arange(15).reshape(5, 3).astype(nptype)) + y = Tensor(np.arange(5).reshape(5, 1).astype(nptype)) + dout = Tensor(np.array([[63., 64., 65.], + [66., 67., 68.], + [69., 70., 71.], + [72., 73., 74.], + [75., 76., 77.]]).astype(nptype)) + index = Tensor(np.array([1]), dtype=mindspore.int32) + xgrad, _, ygrad = grad_net(x, index, y, dout) + expect_xgrad = np.array([[63., 64., 65.], + [66., 67., 68.], + [69., 70., 71.], + [72., 73., 74.], + [75., 76., 77.]]).astype(nptype) + expect_ygrad = np.array([[64.], + [67.], + [70.], + [73.], + [76.]]).astype(nptype) + np.testing.assert_array_equal(xgrad.asnumpy(), expect_xgrad) + np.testing.assert_array_equal(ygrad.asnumpy(), expect_ygrad) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_index_add_grad_float64(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + index_add_grad_with_type(np.float64) + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + index_add_grad_with_type(np.float64) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_index_add_grad_float32(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + index_add_grad_with_type(np.float32) + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + index_add_grad_with_type(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_index_add_grad_float16(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + 
index_add_grad_with_type(np.float16) + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + index_add_grad_with_type(np.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_index_add_grad_int32(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + index_add_grad_with_type(np.int32) + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + index_add_grad_with_type(np.int32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_index_add_grad_int16(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + index_add_grad_with_type(np.int16) + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + index_add_grad_with_type(np.int16) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_index_add_grad_int8(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + index_add_grad_with_type(np.int8) + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + index_add_grad_with_type(np.int8) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_index_add_grad_uint8(): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + index_add_grad_with_type(np.uint8) + context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU') + index_add_grad_with_type(np.uint8)