equal op gpu kernel

This commit is contained in:
jonwe 2020-12-10 13:53:10 -05:00
parent 4ce11a930b
commit 54e5025c1a
2 changed files with 30 additions and 1 deletions

View File

@ -161,7 +161,6 @@ MS_REG_GPU_KERNEL_ONE(
BroadcastOpGpuKernel, int)
// int64
// int32
MS_REG_GPU_KERNEL_ONE(
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int64_t)
@ -203,10 +202,16 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
BroadcastOpGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int8_t)
// uint8
MS_REG_GPU_KERNEL_ONE(
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
BroadcastOpGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, uint8_t)
} // namespace kernel
} // namespace mindspore

View File

@ -92,6 +92,21 @@ def test_equal():
y5_np = np.array([True, False, False]).astype(bool)
y5 = Tensor(y5_np)
expect5 = np.equal(x5_np, y5_np)
x6_np = np.array([0, 1, 4]).astype(np.int8)
x6 = Tensor(x4_np)
y6_np = np.array([0, 1, 3]).astype(np.int8)
y6 = Tensor(y4_np)
expect6 = np.equal(x6_np, y6_np)
x7_np = np.array([0, 1, 4]).astype(np.int64)
x7 = Tensor(x4_np)
y7_np = np.array([0, 1, 3]).astype(np.int64)
y7 = Tensor(y4_np)
expect7 = np.equal(x7_np, y7_np)
x8_np = np.array([0, 1, 4]).astype(np.float16)
x8 = Tensor(x4_np)
y8_np = np.array([0, 1, 3]).astype(np.float16)
y8 = Tensor(y4_np)
expect8 = np.equal(x8_np, y8_np)
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
equal = NetEqual()
@ -136,6 +151,15 @@ def test_equal():
output5 = equal(x5, y5)
assert np.all(output5.asnumpy() == expect5)
assert output5.shape == expect5.shape
output6 = equal(x6, y6)
assert np.all(output6.asnumpy() == expect6)
assert output6.shape == expect6.shape
output7 = equal(x7, y7)
assert np.all(output7.asnumpy() == expect7)
assert output7.shape == expect7.shape
output8 = equal(x8, y8)
assert np.all(output8.asnumpy() == expect8)
assert output8.shape == expect8.shape
@pytest.mark.level0