From cc9fef2baa84ce5aafd724177c8ad8b3d46cd8fe Mon Sep 17 00:00:00 2001 From: wuxuejian Date: Tue, 15 Dec 2020 19:38:28 +0800 Subject: [PATCH] support int32 dtype of Mul cpu op --- .../cpu/arithmetic_cpu_kernel.h | 3 ++ tests/st/ops/cpu/test_mul_op.py | 51 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h index 2f5ce298a00..daf17c2eab9 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h @@ -105,6 +105,9 @@ MS_REG_CPU_KERNEL( MS_REG_CPU_KERNEL( AssignAdd, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), ArithmeticCPUKernel); +MS_REG_CPU_KERNEL( + Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ArithmeticCPUKernel); } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/cpu/test_mul_op.py b/tests/st/ops/cpu/test_mul_op.py index 91e9a0b5fce..907273a1f00 100644 --- a/tests/st/ops/cpu/test_mul_op.py +++ b/tests/st/ops/cpu/test_mul_op.py @@ -84,3 +84,54 @@ def test_mul(): err = np.ones(shape=exp.shape) * 1.0e-5 assert np.all(diff < err) assert out.shape == exp.shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_mul_int32(): + x0 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32)) + y0 = Tensor(np.random.uniform(-2, 2, (1, 1, 1, 1)).astype(np.int32)) + x1 = Tensor(np.random.uniform(-2, 2, (1, 3, 1, 4)).astype(np.int32)) + y1 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32)) + x2 = Tensor(np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.int32)) + y2 = Tensor(2, mstype.int32) + x3 = Tensor(2, mstype.int32) + y3 = Tensor(2, mstype.int32) + x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.int32)) + y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.int32)) + mul = Net() + out = mul(x0, y0).asnumpy() + exp = x0.asnumpy() * y0.asnumpy() + diff = np.abs(out - exp) + err = np.ones(shape=exp.shape) * 1.0e-5 + assert np.all(diff < err) + assert out.shape == exp.shape + + out = mul(x1, y1).asnumpy() + exp = x1.asnumpy() * y1.asnumpy() + diff = np.abs(out - exp) + err = np.ones(shape=exp.shape) * 1.0e-5 + assert np.all(diff < err) + assert out.shape == exp.shape + + out = mul(x2, y2).asnumpy() + exp = x2.asnumpy() * y2.asnumpy() + diff = np.abs(out - exp) + err = np.ones(shape=exp.shape) * 1.0e-5 + assert np.all(diff < err) + assert out.shape == exp.shape + + out = mul(x3, y3).asnumpy() + exp = x3.asnumpy() * y3.asnumpy() + diff = np.abs(out - exp) + err = np.ones(shape=exp.shape) * 1.0e-5 + assert np.all(diff < err) + assert out.shape == exp.shape + + out = mul(x4, y4).asnumpy() + exp = x4.asnumpy() * y4.asnumpy() + diff = np.abs(out - exp) + err = np.ones(shape=exp.shape) * 1.0e-5 + assert np.all(diff < err) + assert out.shape == exp.shape