add scatter nd update int64 and float64 support

This commit is contained in:
zhujingxuan 2021-11-01 14:41:59 +08:00
parent d39ad14ce5
commit ed24ad8f94
3 changed files with 85 additions and 31 deletions

View File

@ -102,13 +102,23 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kScatterNdUpdateInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kScatterNdUpdateOutputsNum, kernel_name_);
if (dtype_ == kNumberTypeFloat16) {
switch (dtype_) {
case kNumberTypeFloat16:
LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
break;
case kNumberTypeFloat32:
LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeInt32) {
break;
case kNumberTypeFloat64:
LaunchKernel<double>(inputs, outputs);
break;
case kNumberTypeInt32:
LaunchKernel<int>(inputs, outputs);
} else {
break;
case kNumberTypeInt64:
LaunchKernel<int64_t>(inputs, outputs);
break;
default:
MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
}
return true;

View File

@ -73,6 +73,22 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
.AddOutputAttr(kNumberTypeFloat32),
ScatterNdUpdateCPUKernel);
MS_REG_CPU_KERNEL(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
ScatterNdUpdateCPUKernel);
MS_REG_CPU_KERNEL(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
ScatterNdUpdateCPUKernel);
MS_REG_CPU_KERNEL(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
@ -91,18 +107,18 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
MS_REG_CPU_KERNEL(ScatterNdUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
ScatterNdUpdateCPUKernel);
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
ScatterNdUpdateCPUKernel)
MS_REG_CPU_KERNEL(TensorScatterUpdate,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
ScatterNdUpdateCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -29,58 +29,80 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op1():
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_op1(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating float values
Expectation: the result match scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
self.x = Parameter(
Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=dtype)), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
update = Tensor(np.array([1.0, 2.2]), mstype.float32)
update = Tensor(np.array([1.0, 2.2], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, np.float))
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
np.array(expect, dtype=dtype))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op2():
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
def test_op2(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating int values
Expectation: the result match scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
self.x = Parameter(
Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=dtype)), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
update = Tensor(np.array([9, 10, 11, 12], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
print("x:\n", scatter_nd_update.x.data.asnumpy())
expect = [1, 11, 3, 10, 9, 6, 7, 12]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
np.array(expect, dtype=dtype))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op3():
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
def test_op3(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating int values
Expectation: the result match scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
self.x = Parameter(Tensor(np.zeros((4, 4, 4)).astype(dtype)), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
@ -89,7 +111,7 @@ def test_op3():
update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]],
[[5, 5, 5, 5], [6, 6, 6, 6],
[7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
scatter_nd_update(indices, update)
@ -98,28 +120,34 @@ def test_op3():
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=dtype))
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op4():
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_op4(dtype):
"""
Feature: ALL TO ALL
Description: test cases for updating single float value
Expectation: the result match scipy
"""
class ScatterNdUpdate(nn.Cell):
def __init__(self):
super(ScatterNdUpdate, self).__init__()
self.scatter_nd_update = P.ScatterNdUpdate()
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=dtype)), name="x")
def construct(self, indices, update):
return self.scatter_nd_update(self.x, indices, update)
indices = Tensor(np.array([[0, 1]]), mstype.int32)
update = Tensor(np.array([1.0]), mstype.float32)
update = Tensor(np.array([1.0], dtype=dtype))
scatter_nd_update = ScatterNdUpdate()
out = scatter_nd_update(indices, update)
print("x:\n", out)
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
assert np.allclose(out.asnumpy(), np.array(expect, np.float))
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))