fix bug of supported data type for LessEqual on CPU

This commit is contained in:
shibeiji 2021-01-07 17:12:45 +08:00
parent 9f9c132440
commit 984f31e901
2 changed files with 128 additions and 27 deletions

View File

@ -247,41 +247,16 @@ MS_REG_CPU_KERNEL(
GreaterEqual,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
LessEqual,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -32,7 +32,7 @@ class Net(nn.Cell):
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_net():
def test_net_fp32():
x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
@ -81,3 +81,129 @@ def test_net():
expect = x4_np <= y4_np
assert np.all(out == expect)
assert out.shape == expect.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_net_fp16():
x0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float16)
x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float16)
y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float16)
x3_np = np.random.randint(1, 5, 1).astype(np.float16)
y3_np = np.random.randint(1, 5, 1).astype(np.float16)
x4_np = np.array(768).astype(np.float16)
y4_np = np.array(3072.5).astype(np.float16)
x0 = Tensor(x0_np)
y0 = Tensor(y0_np)
x1 = Tensor(x1_np)
y1 = Tensor(y1_np)
x2 = Tensor(x2_np)
y2 = Tensor(y2_np)
x3 = Tensor(x3_np)
y3 = Tensor(y3_np)
x4 = Tensor(x4_np)
y4 = Tensor(y4_np)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net()
out = net(x0, y0).asnumpy()
expect = x0_np <= y0_np
assert np.all(out == expect)
assert out.shape == expect.shape
out = net(x1, y1).asnumpy()
expect = x1_np <= y1_np
assert np.all(out == expect)
assert out.shape == expect.shape
out = net(x2, y2).asnumpy()
expect = x2_np <= y2_np
assert np.all(out == expect)
assert out.shape == expect.shape
out = net(x3, y3).asnumpy()
expect = x3_np <= y3_np
assert np.all(out == expect)
assert out.shape == expect.shape
out = net(x4, y4).asnumpy()
expect = x4_np <= y4_np
assert np.all(out == expect)
assert out.shape == expect.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_net_int32():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int32)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int32)
x1 = Tensor(x1_np)
y1 = Tensor(y1_np)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net()
out = net(x1, y1).asnumpy()
expect = x1_np <= y1_np
assert np.all(out == expect)
assert out.shape == expect.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_net_int64():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int64)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int64)
x1 = Tensor(x1_np)
y1 = Tensor(y1_np)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net()
out = net(x1, y1).asnumpy()
expect = x1_np <= y1_np
assert np.all(out == expect)
assert out.shape == expect.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_net_float64():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float64)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float64)
x1 = Tensor(x1_np)
y1 = Tensor(y1_np)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net()
out = net(x1, y1).asnumpy()
expect = x1_np <= y1_np
assert np.all(out == expect)
assert out.shape == expect.shape
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_net_int16():
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.int16)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.int16)
x1 = Tensor(x1_np)
y1 = Tensor(y1_np)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
net = Net()
out = net(x1, y1).asnumpy()
expect = x1_np <= y1_np
assert np.all(out == expect)
assert out.shape == expect.shape