fix cpu kernel scatter_nd_update

This commit is contained in:
huanghui 2021-05-07 19:53:31 +08:00
parent ac5af72836
commit afe2bb0db8
2 changed files with 59 additions and 37 deletions

View File

@ -85,11 +85,12 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
out_stride *= shape[i + 1];
out_strides_.push_back(out_stride);
}
reverse(out_strides_.begin(), out_strides_.end());
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
}
bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs);

View File

@ -26,46 +26,25 @@ from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=False)
class ScatterNdUpdate1(nn.Cell):
def __init__(self):
super(ScatterNdUpdate1, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
class ScatterNdUpdate2(nn.Cell):
def __init__(self):
super(ScatterNdUpdate2, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
class ScatterNdUpdate3(nn.Cell):
def __init__(self):
super(ScatterNdUpdate3, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op1():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
update = Tensor(np.array([1.0, 2.2]), mstype.float32)
scatter_nd_update = ScatterNdUpdate1()
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data)
print("x:\n", scatter_nd_update.x.data.asnumpy())
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, np.float))
@ -74,12 +53,21 @@ def test_op1():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op2():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
scatter_nd_update = ScatterNdUpdate2()
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data)
print("x:\n", scatter_nd_update.x.data.asnumpy())
expect = [1, 11, 3, 10, 9, 6, 7, 12]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
@ -88,17 +76,50 @@ def test_op2():
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op3():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0], [2]]), mstype.int32)
update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
scatter_nd_update = ScatterNdUpdate3()
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data)
print("x:\n", scatter_nd_update.x.data.asnumpy())
expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op4():
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0, 1]]), mstype.int32)
update = Tensor(np.array([1.0]), mstype.float32)
scatter_nd_update = ScatterNdUpdate()
out = scatter_nd_update(indices, update)
print("x:\n", out)
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
assert np.allclose(out.asnumpy(), np.array(expect, np.float))