!10006 support int32 dtype of Mul cpu op

From: @wuxuejian
Reviewed-by: @kisnwang,@c_34
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2020-12-16 16:22:06 +08:00 committed by Gitee
commit f75280f6c2
2 changed files with 54 additions and 0 deletions

View File

@ -105,6 +105,9 @@ MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL(
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -84,3 +84,54 @@ def test_mul():
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_mul_int32():
x0 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32))
y0 = Tensor(np.random.uniform(-2, 2, (1, 1, 1, 1)).astype(np.int32))
x1 = Tensor(np.random.uniform(-2, 2, (1, 3, 1, 4)).astype(np.int32))
y1 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32))
x2 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32))
y2 = Tensor(2, mstype.int32)
x3 = Tensor(2, mstype.int32)
y3 = Tensor(2, mstype.int32)
x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.int32))
y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.int32))
mul = Net()
out = mul(x0, y0).asnumpy()
exp = x0.asnumpy() * y0.asnumpy()
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
out = mul(x1, y1).asnumpy()
exp = x1.asnumpy() * y1.asnumpy()
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
out = mul(x2, y2).asnumpy()
exp = x2.asnumpy() * y2.asnumpy()
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
out = mul(x3, y3).asnumpy()
exp = x3.asnumpy() * y3.asnumpy()
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape
out = mul(x4, y4).asnumpy()
exp = x4.asnumpy() * y4.asnumpy()
diff = np.abs(out - exp)
err = np.ones(shape=exp.shape) * 1.0e-5
assert np.all(diff < err)
assert out.shape == exp.shape