forked from mindspore-Ecosystem/mindspore
fix cpu kernel scatter_nd_update
This commit is contained in:
parent
ac5af72836
commit
afe2bb0db8
|
@ -85,11 +85,12 @@ void ScatterNdUpdateCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|||
out_stride *= shape[i + 1];
|
||||
out_strides_.push_back(out_stride);
|
||||
}
|
||||
reverse(out_strides_.begin(), out_strides_.end());
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
}
|
||||
|
||||
bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
|
|
|
@ -26,46 +26,25 @@ from mindspore.ops import operations as P
|
|||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU', save_graphs=False)
|
||||
|
||||
|
||||
class ScatterNdUpdate1(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate1, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
|
||||
class ScatterNdUpdate2(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate2, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
|
||||
class ScatterNdUpdate3(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate3, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op1():
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
|
||||
update = Tensor(np.array([1.0, 2.2]), mstype.float32)
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate1()
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, np.float))
|
||||
|
||||
|
@ -74,12 +53,21 @@ def test_op1():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op2():
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
|
||||
update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate2()
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
expect = [1, 11, 3, 10, 9, 6, 7, 12]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
|
||||
|
||||
|
@ -88,17 +76,50 @@ def test_op2():
|
|||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op3():
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[0], [2]]), mstype.int32)
|
||||
update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate3()
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
expect = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op4():
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[0, 1]]), mstype.int32)
|
||||
update = Tensor(np.array([1.0]), mstype.float32)
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
out = scatter_nd_update(indices, update)
|
||||
print("x:\n", out)
|
||||
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
|
||||
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
|
||||
assert np.allclose(out.asnumpy(), np.array(expect, np.float))
|
||||
|
|
Loading…
Reference in New Issue