diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc index a50fa924691..18609f63d12 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/scatter_nd_update_cpu_kernel.cc @@ -85,11 +85,12 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) { out_stride *= shape[i + 1]; out_strides_.push_back(out_stride); } + reverse(out_strides_.begin(), out_strides_.end()); dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); } bool ScatterNdUpdateCPUKernel::Launch(const std::vector &inputs, - const std::vector & /*workspace*/, + const std::vector &, const std::vector &outputs) { if (dtype_ == kNumberTypeFloat16) { LaunchKernel(inputs, outputs); diff --git a/tests/st/ops/cpu/test_scatter_nd_update_op.py b/tests/st/ops/cpu/test_scatter_nd_update_op.py index 313640a556a..4158092abd9 100644 --- a/tests/st/ops/cpu/test_scatter_nd_update_op.py +++ b/tests/st/ops/cpu/test_scatter_nd_update_op.py @@ -26,46 +26,25 @@ from mindspore.ops import operations as P context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=False) -class ScatterNdUpdate1(nn.Cell): - def __init__(self): - super(ScatterNdUpdate1, self).__init__() - self.scatter_nd_update = P.ScatterNdUpdate() - self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x") - - def construct(self, indices, update): - return self.scatter_nd_update(self.x, indices, update) - - -class ScatterNdUpdate2(nn.Cell): - def __init__(self): - super(ScatterNdUpdate2, self).__init__() - self.scatter_nd_update = P.ScatterNdUpdate() - self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x") - - def construct(self, indices, update): - return self.scatter_nd_update(self.x, indices, update) - - -class ScatterNdUpdate3(nn.Cell): - def 
__init__(self): - super(ScatterNdUpdate3, self).__init__() - self.scatter_nd_update = P.ScatterNdUpdate() - self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x") - - def construct(self, indices, update): - return self.scatter_nd_update(self.x, indices, update) - - @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_op1(): + class ScatterNdUpdate(nn.Cell): + def __init__(self): + super(ScatterNdUpdate, self).__init__() + self.scatter_nd_update = P.ScatterNdUpdate() + self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x") + + def construct(self, indices, update): + return self.scatter_nd_update(self.x, indices, update) + indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32) update = Tensor(np.array([1.0, 2.2]), mstype.float32) - scatter_nd_update = ScatterNdUpdate1() + scatter_nd_update = ScatterNdUpdate() scatter_nd_update(indices, update) - print("x:\n", scatter_nd_update.x.data) + print("x:\n", scatter_nd_update.x.data.asnumpy()) expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]] assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, np.float)) @@ -74,12 +53,21 @@ def test_op1(): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_op2(): + class ScatterNdUpdate(nn.Cell): + def __init__(self): + super(ScatterNdUpdate, self).__init__() + self.scatter_nd_update = P.ScatterNdUpdate() + self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x") + + def construct(self, indices, update): + return self.scatter_nd_update(self.x, indices, update) + indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32) update = Tensor(np.array([9, 10, 11, 12]), mstype.float32) - scatter_nd_update = ScatterNdUpdate2() + scatter_nd_update = ScatterNdUpdate() scatter_nd_update(indices, update) - print("x:\n", scatter_nd_update.x.data) + print("x:\n", scatter_nd_update.x.data.asnumpy()) expect = [1, 11, 3, 10, 9, 6, 7, 12] assert 
np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float)) @@ -88,17 +76,50 @@ def test_op2(): @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_op3(): + class ScatterNdUpdate(nn.Cell): + def __init__(self): + super(ScatterNdUpdate, self).__init__() + self.scatter_nd_update = P.ScatterNdUpdate() + self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x") + + def construct(self, indices, update): + return self.scatter_nd_update(self.x, indices, update) + indices = Tensor(np.array([[0], [2]]), mstype.int32) update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32) - scatter_nd_update = ScatterNdUpdate3() + scatter_nd_update = ScatterNdUpdate() scatter_nd_update(indices, update) - print("x:\n", scatter_nd_update.x.data) + print("x:\n", scatter_nd_update.x.data.asnumpy()) expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_op4(): + class ScatterNdUpdate(nn.Cell): + def __init__(self): + super(ScatterNdUpdate, self).__init__() + self.scatter_nd_update = P.ScatterNdUpdate() + self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x") + + def construct(self, indices, update): + return self.scatter_nd_update(self.x, indices, update) + + indices = Tensor(np.array([[0, 1]]), mstype.int32) + update = Tensor(np.array([1.0]), mstype.float32) + + scatter_nd_update = ScatterNdUpdate() + out = scatter_nd_update(indices, update) + print("x:\n", out) + assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy()) + expect = 
[[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]] + assert np.allclose(out.asnumpy(), np.array(expect, dtype=float))