!10006 support int32 dtype of Mul cpu op
From: @wuxuejian Reviewed-by: @kisnwang,@c_34 Signed-off-by: @c_34
This commit is contained in:
commit
f75280f6c2
|
@ -105,6 +105,9 @@ MS_REG_CPU_KERNEL(
|
|||
MS_REG_CPU_KERNEL(
|
||||
AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
ArithmeticCPUKernel);
|
||||
MS_REG_CPU_KERNEL(
|
||||
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ArithmeticCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -84,3 +84,54 @@ def test_mul():
|
|||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
assert out.shape == exp.shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_mul_int32():
|
||||
x0 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32))
|
||||
y0 = Tensor(np.random.uniform(-2, 2, (1, 1, 1, 1)).astype(np.int32))
|
||||
x1 = Tensor(np.random.uniform(-2, 2, (1, 3, 1, 4)).astype(np.int32))
|
||||
y1 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32))
|
||||
x2 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32))
|
||||
y2 = Tensor(2, mstype.int32)
|
||||
x3 = Tensor(2, mstype.int32)
|
||||
y3 = Tensor(2, mstype.int32)
|
||||
x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.int32))
|
||||
y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.int32))
|
||||
mul = Net()
|
||||
out = mul(x0, y0).asnumpy()
|
||||
exp = x0.asnumpy() * y0.asnumpy()
|
||||
diff = np.abs(out - exp)
|
||||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
assert out.shape == exp.shape
|
||||
|
||||
out = mul(x1, y1).asnumpy()
|
||||
exp = x1.asnumpy() * y1.asnumpy()
|
||||
diff = np.abs(out - exp)
|
||||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
assert out.shape == exp.shape
|
||||
|
||||
out = mul(x2, y2).asnumpy()
|
||||
exp = x2.asnumpy() * y2.asnumpy()
|
||||
diff = np.abs(out - exp)
|
||||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
assert out.shape == exp.shape
|
||||
|
||||
out = mul(x3, y3).asnumpy()
|
||||
exp = x3.asnumpy() * y3.asnumpy()
|
||||
diff = np.abs(out - exp)
|
||||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
assert out.shape == exp.shape
|
||||
|
||||
out = mul(x4, y4).asnumpy()
|
||||
exp = x4.asnumpy() * y4.asnumpy()
|
||||
diff = np.abs(out - exp)
|
||||
err = np.ones(shape=exp.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
assert out.shape == exp.shape
|
||||
|
|
Loading…
Reference in New Issue