diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc
index e3c123e0069..38e01dcf143 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc
@@ -102,14 +102,24 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
                                       const std::vector<AddressPtr> &outputs) {
   CHECK_KERNEL_INPUTS_NUM(inputs.size(), kScatterNdUpdateInputsNum, kernel_name_);
   CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kScatterNdUpdateOutputsNum, kernel_name_);
-  if (dtype_ == kNumberTypeFloat16) {
-    LaunchKernel<float16>(inputs, outputs);
-  } else if (dtype_ == kNumberTypeFloat32) {
-    LaunchKernel<float>(inputs, outputs);
-  } else if (dtype_ == kNumberTypeInt32) {
-    LaunchKernel<int>(inputs, outputs);
-  } else {
-    MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
+  switch (dtype_) {
+    case kNumberTypeFloat16:
+      LaunchKernel<float16>(inputs, outputs);
+      break;
+    case kNumberTypeFloat32:
+      LaunchKernel<float>(inputs, outputs);
+      break;
+    case kNumberTypeFloat64:
+      LaunchKernel<double>(inputs, outputs);
+      break;
+    case kNumberTypeInt32:
+      LaunchKernel<int>(inputs, outputs);
+      break;
+    case kNumberTypeInt64:
+      LaunchKernel<int64_t>(inputs, outputs);
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
   }
   return true;
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h
index eb5c1a2b792..f333657ca9c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.h
@@ -73,6 +73,22 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
                     .AddOutputAttr(kNumberTypeFloat32),
                   ScatterNdUpdateCPUKernel);
 
+MS_REG_CPU_KERNEL(ScatterNdUpdate,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat64)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddInputAttr(kNumberTypeFloat64)
+                    .AddOutputAttr(kNumberTypeFloat64),
+                  ScatterNdUpdateCPUKernel);
+
+MS_REG_CPU_KERNEL(TensorScatterUpdate,
+                  KernelAttr()
+                    .AddInputAttr(kNumberTypeFloat64)
+                    .AddInputAttr(kNumberTypeInt32)
+                    .AddInputAttr(kNumberTypeFloat64)
+                    .AddOutputAttr(kNumberTypeFloat64),
+                  ScatterNdUpdateCPUKernel);
+
 MS_REG_CPU_KERNEL(ScatterNdUpdate,
                   KernelAttr()
                     .AddInputAttr(kNumberTypeInt32)
@@ -91,18 +107,18 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
 
 MS_REG_CPU_KERNEL(ScatterNdUpdate,
                   KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat64)
+                    .AddInputAttr(kNumberTypeInt64)
                     .AddInputAttr(kNumberTypeInt32)
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddOutputAttr(kNumberTypeFloat64),
-                  ScatterNdUpdateCPUKernel);
+                    .AddInputAttr(kNumberTypeInt64)
+                    .AddOutputAttr(kNumberTypeInt64),
+                  ScatterNdUpdateCPUKernel);
 
 MS_REG_CPU_KERNEL(TensorScatterUpdate,
                   KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat64)
+                    .AddInputAttr(kNumberTypeInt64)
                     .AddInputAttr(kNumberTypeInt32)
-                    .AddInputAttr(kNumberTypeFloat64)
-                    .AddOutputAttr(kNumberTypeFloat64),
+                    .AddInputAttr(kNumberTypeInt64)
+                    .AddOutputAttr(kNumberTypeInt64),
                   ScatterNdUpdateCPUKernel);
 } // namespace kernel
 } // namespace mindspore
diff --git a/tests/st/ops/cpu/test_scatter_nd_update_op.py b/tests/st/ops/cpu/test_scatter_nd_update_op.py
index 588c40c1986..f96a2a6ad7f 100644
--- a/tests/st/ops/cpu/test_scatter_nd_update_op.py
+++ b/tests/st/ops/cpu/test_scatter_nd_update_op.py
@@ -29,58 +29,80 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_op1():
+@pytest.mark.parametrize('dtype', [np.float32, np.float64])
+def test_op1(dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for updating float values
+    Expectation: the result match scipy
+    """
     class ScatterNdUpdate(nn.Cell):
         def __init__(self):
             super(ScatterNdUpdate, self).__init__()
             self.scatter_nd_update = P.ScatterNdUpdate()
-            self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
+            self.x = Parameter(
+                Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=dtype)), name="x")
 
         def construct(self, indices, update):
             return self.scatter_nd_update(self.x, indices, update)
 
     indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
-    update = Tensor(np.array([1.0, 2.2]), mstype.float32)
+    update = Tensor(np.array([1.0, 2.2], dtype=dtype))
 
     scatter_nd_update = ScatterNdUpdate()
     scatter_nd_update(indices, update)
     print("x:\n", scatter_nd_update.x.data.asnumpy())
     expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
-    assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, np.float))
+    assert np.allclose(scatter_nd_update.x.data.asnumpy(),
+                       np.array(expect, dtype=dtype))
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_op2():
+@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
+def test_op2(dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for updating int values
+    Expectation: the result match scipy
+    """
     class ScatterNdUpdate(nn.Cell):
         def __init__(self):
             super(ScatterNdUpdate, self).__init__()
             self.scatter_nd_update = P.ScatterNdUpdate()
-            self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
+            self.x = Parameter(
+                Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=dtype)), name="x")
 
         def construct(self, indices, update):
             return self.scatter_nd_update(self.x, indices, update)
 
     indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
-    update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
+    update = Tensor(np.array([9, 10, 11, 12], dtype=dtype))
 
     scatter_nd_update = ScatterNdUpdate()
     scatter_nd_update(indices, update)
     print("x:\n", scatter_nd_update.x.data.asnumpy())
     expect = [1, 11, 3, 10, 9, 6, 7, 12]
-    assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
+    assert np.allclose(scatter_nd_update.x.data.asnumpy(),
+                       np.array(expect, dtype=dtype))
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_op3():
+@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
+def test_op3(dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for updating int values
+    Expectation: the result match scipy
+    """
     class ScatterNdUpdate(nn.Cell):
         def __init__(self):
             super(ScatterNdUpdate, self).__init__()
             self.scatter_nd_update = P.ScatterNdUpdate()
-            self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
+            self.x = Parameter(Tensor(np.zeros((4, 4, 4)).astype(dtype)), name="x")
 
         def construct(self, indices, update):
             return self.scatter_nd_update(self.x, indices, update)
@@ -89,7 +111,7 @@ def test_op3():
     update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
                                [7, 7, 7, 7], [8, 8, 8, 8]],
                               [[5, 5, 5, 5], [6, 6, 6, 6],
-                               [7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
+                               [7, 7, 7, 7], [8, 8, 8, 8]]], dtype=dtype))
 
     scatter_nd_update = ScatterNdUpdate()
     scatter_nd_update(indices, update)
@@ -98,28 +120,34 @@ def test_op3():
               [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
               [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
               [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
-    assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
+    assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=dtype))
 
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_op4():
+@pytest.mark.parametrize('dtype', [np.float32, np.float64])
+def test_op4(dtype):
+    """
+    Feature: ALL TO ALL
+    Description: test cases for updating single float value
+    Expectation: the result match scipy
+    """
     class ScatterNdUpdate(nn.Cell):
         def __init__(self):
             super(ScatterNdUpdate, self).__init__()
             self.scatter_nd_update = P.ScatterNdUpdate()
-            self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
+            self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=dtype)), name="x")
 
         def construct(self, indices, update):
             return self.scatter_nd_update(self.x, indices, update)
 
     indices = Tensor(np.array([[0, 1]]), mstype.int32)
-    update = Tensor(np.array([1.0]), mstype.float32)
+    update = Tensor(np.array([1.0], dtype=dtype))
 
     scatter_nd_update = ScatterNdUpdate()
     out = scatter_nd_update(indices, update)
     print("x:\n", out)
     assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
     expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
-    assert np.allclose(out.asnumpy(), np.array(expect, np.float))
+    assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))