forked from mindspore-Ecosystem/mindspore
add scatter nd update int64 and float64 support
This commit is contained in:
parent
d39ad14ce5
commit
ed24ad8f94
|
@ -102,14 +102,24 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kScatterNdUpdateInputsNum, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kScatterNdUpdateOutputsNum, kernel_name_);
|
||||
if (dtype_ == kNumberTypeFloat16) {
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeInt32) {
|
||||
LaunchKernel<int>(inputs, outputs);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
|
||||
switch (dtype_) {
|
||||
case kNumberTypeFloat16:
|
||||
LaunchKernel<float16>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeFloat32:
|
||||
LaunchKernel<float>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeFloat64:
|
||||
LaunchKernel<double>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeInt32:
|
||||
LaunchKernel<int>(inputs, outputs);
|
||||
break;
|
||||
case kNumberTypeInt64:
|
||||
LaunchKernel<int64_t>(inputs, outputs);
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -73,6 +73,22 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
|
|||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
ScatterNdUpdateCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(ScatterNdUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
ScatterNdUpdateCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
ScatterNdUpdateCPUKernel);
|
||||
|
||||
MS_REG_CPU_KERNEL(ScatterNdUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
|
@ -91,18 +107,18 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate,
|
|||
|
||||
MS_REG_CPU_KERNEL(ScatterNdUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
ScatterNdUpdateCPUKernel);
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
ScatterNdUpdateCPUKernel)
|
||||
|
||||
MS_REG_CPU_KERNEL(TensorScatterUpdate,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddInputAttr(kNumberTypeInt32)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
.AddInputAttr(kNumberTypeInt64)
|
||||
.AddOutputAttr(kNumberTypeInt64),
|
||||
ScatterNdUpdateCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,58 +29,80 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
|||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op1():
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_op1(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating float values
|
||||
Expectation: the result match scipy
|
||||
"""
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
|
||||
self.x = Parameter(
|
||||
Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=dtype)), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32)
|
||||
update = Tensor(np.array([1.0, 2.2]), mstype.float32)
|
||||
update = Tensor(np.array([1.0, 2.2], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
expect = [[1.0, 0.3, 3.6], [0.4, 2.2, -3.2]]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, np.float))
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
|
||||
np.array(expect, dtype=dtype))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op2():
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
||||
def test_op2(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating int values
|
||||
Expectation: the result match scipy
|
||||
"""
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32), name="x")
|
||||
self.x = Parameter(
|
||||
Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=dtype)), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[4], [3], [1], [7]]), mstype.int32)
|
||||
update = Tensor(np.array([9, 10, 11, 12]), mstype.float32)
|
||||
update = Tensor(np.array([9, 10, 11, 12], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
print("x:\n", scatter_nd_update.x.data.asnumpy())
|
||||
expect = [1, 11, 3, 10, 9, 6, 7, 12]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(),
|
||||
np.array(expect, dtype=dtype))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op3():
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64, np.int32, np.int64])
|
||||
def test_op3(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating int values
|
||||
Expectation: the result match scipy
|
||||
"""
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.zeros((4, 4, 4)), mstype.float32), name="x")
|
||||
self.x = Parameter(Tensor(np.zeros((4, 4, 4)).astype(dtype)), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
@ -89,7 +111,7 @@ def test_op3():
|
|||
update = Tensor(np.array([[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6],
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]]]), mstype.float32)
|
||||
[7, 7, 7, 7], [8, 8, 8, 8]]], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
scatter_nd_update(indices, update)
|
||||
|
@ -98,28 +120,34 @@ def test_op3():
|
|||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
|
||||
[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=float))
|
||||
assert np.allclose(scatter_nd_update.x.data.asnumpy(), np.array(expect, dtype=dtype))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_op4():
|
||||
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
|
||||
def test_op4(dtype):
|
||||
"""
|
||||
Feature: ALL TO ALL
|
||||
Description: test cases for updating single float value
|
||||
Expectation: the result match scipy
|
||||
"""
|
||||
class ScatterNdUpdate(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ScatterNdUpdate, self).__init__()
|
||||
self.scatter_nd_update = P.ScatterNdUpdate()
|
||||
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32), name="x")
|
||||
self.x = Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]], dtype=dtype)), name="x")
|
||||
|
||||
def construct(self, indices, update):
|
||||
return self.scatter_nd_update(self.x, indices, update)
|
||||
|
||||
indices = Tensor(np.array([[0, 1]]), mstype.int32)
|
||||
update = Tensor(np.array([1.0]), mstype.float32)
|
||||
update = Tensor(np.array([1.0], dtype=dtype))
|
||||
|
||||
scatter_nd_update = ScatterNdUpdate()
|
||||
out = scatter_nd_update(indices, update)
|
||||
print("x:\n", out)
|
||||
assert np.allclose(out.asnumpy(), scatter_nd_update.x.data.asnumpy())
|
||||
expect = [[-0.1, 1.0, 3.6], [0.4, 0.5, -3.2]]
|
||||
assert np.allclose(out.asnumpy(), np.array(expect, np.float))
|
||||
assert np.allclose(out.asnumpy(), np.array(expect, dtype=dtype))
|
||||
|
|
Loading…
Reference in New Issue