diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc index 1131d8847ac..30692388673 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.cc @@ -66,6 +66,13 @@ MS_REG_GPU_KERNEL_TWO( MS_REG_GPU_KERNEL_TWO( Gather, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), GatherV2GpuFwdKernel, uint8_t, int64_t) +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + GatherV2GpuFwdKernel, bool, int) +MS_REG_GPU_KERNEL_TWO( + Gather, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + GatherV2GpuFwdKernel, bool, int64_t) +// dynamic shape MS_REG_GPU_KERNEL_TWO(Gather, KernelAttr() .AddInputAttr(kNumberTypeFloat32) @@ -94,6 +101,21 @@ MS_REG_GPU_KERNEL_TWO(Gather, .AddInputAttr(kNumberTypeInt64) .AddOutputAttr(kNumberTypeFloat16), GatherV2GpuFwdKernel, half, int64_t) +MS_REG_GPU_KERNEL_TWO(Gather, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt32) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeBool), + GatherV2GpuFwdKernel, bool, int) +MS_REG_GPU_KERNEL_TWO(Gather, + KernelAttr() + .AddInputAttr(kNumberTypeBool) + .AddInputAttr(kNumberTypeInt64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeBool), + GatherV2GpuFwdKernel, bool, int64_t) +// dynamic shape ends MS_REG_GPU_KERNEL_TWO( SparseGatherV2, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu index c103cde474e..640cff2165b 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/gatherv2.cu @@ -77,3 +77,8 @@ template void GatherV2(uint8_t *input, int *indices, uint8_t *outp template void GatherV2(uint8_t *input, int64_t *indices, uint8_t *output, size_t output_dim0, size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(bool *input, int *indices, bool *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, cudaStream_t stream); +template void GatherV2(bool *input, int64_t *indices, bool *output, size_t output_dim0, + size_t output_dim1, size_t output_dim2, size_t input_dim1, + cudaStream_t stream); diff --git a/tests/st/ops/gpu/test_gatherV2_op.py b/tests/st/ops/gpu/test_gatherV2_op.py index 73e2ff37133..4498e789f4e 100644 --- a/tests/st/ops/gpu/test_gatherV2_op.py +++ b/tests/st/ops/gpu/test_gatherV2_op.py @@ -914,16 +914,16 @@ class GatherNet2(nn.Cell): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_gather2(): - x = Tensor(np.array([[4., 5., 4., 1., 5.,], - [4., 9., 5., 6., 4.,], - [9., 8., 4., 3., 6.,], - [0., 4., 2., 2., 8.,], - [1., 8., 6., 2., 8.,], - [8., 1., 9., 7., 3.,], - [7., 9., 2., 5., 7.,], - [9., 8., 6., 8., 5.,], - [3., 7., 2., 7., 4.,], - [4., 2., 8., 2., 9.,]] + x = Tensor(np.array([[4., 5., 4., 1., 5.], + [4., 9., 5., 6., 4.], + [9., 8., 4., 3., 6.], + [0., 4., 2., 2., 8.], + [1., 8., 6., 2., 8.], + [8., 1., 9., 7., 3.], + [7., 9., 2., 5., 7.], + [9., 8., 6., 8., 5.], + [3., 7., 2., 7., 4.], + [4., 2., 8., 2., 9.]] ).astype(np.float32)) indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64)) @@ -949,6 +949,7 @@ class GatherNetDynamic(nn.Cell): self.to_dyn_1 = dyn_a self.to_dyn_2 = dyn_b self.axis = axis + def construct(self, x, indices): # testing selective inputs being dynamic if self.to_dyn_1: @@ -967,16 +968,16 @@ def test_gatherV2_dyn_ab(): """ context.set_context(mode=context.GRAPH_MODE, device_target="GPU") gather = GatherNetDynamic() - x = Tensor(np.array([[4., 5., 4., 1., 5.,], - [4., 9., 5., 6., 4.,], - [9., 8., 4., 3., 6.,], - [0., 4., 2., 2., 8.,], - [1., 8., 6., 2., 8.,], - [8., 1., 9., 7., 3.,], - [7., 9., 2., 5., 7.,], - [9., 8., 6., 8., 5.,], - [3., 7., 2., 7., 4.,], - [4., 2., 8., 2., 9.,]] + x = Tensor(np.array([[4., 5., 4., 1., 5.], + [4., 9., 5., 6., 4.], + [9., 8., 4., 3., 6.], + [0., 4., 2., 2., 8.], + [1., 8., 6., 2., 8.], + [8., 1., 9., 7., 3.], + [7., 9., 2., 5., 7.], + [9., 8., 6., 8., 5.], + [3., 7., 2., 7., 4.], + [4., 2., 8., 2., 9.]] ).astype(np.float32)) indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) expect = np.array([[[0., 0., 0., 0., 0.], @@ -999,16 +1000,16 @@ def test_gatherV2_dyn_a(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") gather = GatherNetDynamic(-1, True, False) # test 1 - x = Tensor(np.array([[4., 5., 4., 1., 5.,], - [4., 9., 5., 6., 4.,], - [9., 8., 4., 3., 6.,], - [0., 4., 2., 2., 8.,], - [1., 8., 6., 2., 8.,], - [8., 1., 9., 7., 3.,], - [7., 9., 2., 5., 7.,], - [9., 8., 6., 8., 5.,], - [3., 7., 2., 7., 4.,], - [4., 2., 8., 2., 9.,]] + x = Tensor(np.array([[4., 5., 4., 1., 5.], + [4., 9., 5., 6., 4.], + [9., 8., 4., 3., 6.], + [0., 4., 2., 2., 8.], + [1., 8., 6., 2., 8.], + [8., 1., 9., 7., 3.], + [7., 9., 2., 5., 7.], + [9., 8., 6., 8., 5.], + [3., 7., 2., 7., 4.], + [4., 2., 8., 2., 9.]] ).astype(np.float32)) indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int64)) expect = np.array([[[0., 5., 0.]], @@ -1075,16 +1076,16 @@ def test_gatherV2_dyn_b(): context.set_context(mode=context.GRAPH_MODE, device_target="GPU") gather = GatherNetDynamic(-1, False, True) # test 1 - x = Tensor(np.array([[4., 5., 4., 1., 5.,], - [4., 9., 5., 6., 4.,], - [9., 8., 4., 3., 6.,], - [0., 4., 2., 2., 8.,], - [1., 8., 6., 2., 8.,], - [8., 1., 9., 7., 3.,], - [7., 9., 2., 5., 7.,], - [9., 8., 6., 8., 5.,], - [3., 7., 2., 7., 4.,], - [4., 2., 8., 2., 9.,]] + x = Tensor(np.array([[4., 5., 4., 1., 5.], + [4., 9., 5., 6., 4.], + [9., 8., 4., 3., 6.], + [0., 4., 2., 2., 8.], + [1., 8., 6., 2., 8.], + [8., 1., 9., 7., 3.], + [7., 9., 2., 5., 7.], + [9., 8., 6., 8., 5.], + [3., 7., 2., 7., 4.], + [4., 2., 8., 2., 9.]] ).astype(np.float32)) indices = Tensor(np.array([[4000, 1, 300000]]).astype(np.int32)) expect = np.array([[[0., 5., 0.]], @@ -1135,6 +1136,7 @@ def test_gatherV2_dyn_b(): assert np.all(diff < error) assert np.all(-diff < error) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -1358,3 +1360,17 @@ def test_gather1_uint8(): diff = output.asnumpy() - expect assert np.all(diff < error) assert np.all(-diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_gather1_bool(): + x = Tensor(np.array([[0, 1, 1, 0], [1, 0, 0, 0], [1, 0, 1, 0]], dtype=np.bool)) + indices = Tensor(np.array(([1, 2]), dtype='i4')) + expect = np.array([[1, 1], [0, 0], [0, 1]]).astype(np.bool) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + gather = GatherNet1() + output = gather(x, indices) + assert np.all(expect == output.asnumpy())