forked from mindspore-Ecosystem/mindspore
equal op gpu kernel
This commit is contained in:
parent
4ce11a930b
commit
54e5025c1a
|
@ -161,7 +161,6 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
BroadcastOpGpuKernel, int)
|
||||
|
||||
// int64
|
||||
// int32
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Greater, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int64_t)
|
||||
|
@ -203,10 +202,16 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
MS_REG_GPU_KERNEL_ONE(
|
||||
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
BroadcastOpGpuKernel, int8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int8_t)
|
||||
|
||||
// uint8
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
DivNoNan, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
BroadcastOpGpuKernel, uint8_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Equal, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, uint8_t)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -92,6 +92,21 @@ def test_equal():
|
|||
y5_np = np.array([True, False, False]).astype(bool)
|
||||
y5 = Tensor(y5_np)
|
||||
expect5 = np.equal(x5_np, y5_np)
|
||||
x6_np = np.array([0, 1, 4]).astype(np.int8)
|
||||
x6 = Tensor(x4_np)
|
||||
y6_np = np.array([0, 1, 3]).astype(np.int8)
|
||||
y6 = Tensor(y4_np)
|
||||
expect6 = np.equal(x6_np, y6_np)
|
||||
x7_np = np.array([0, 1, 4]).astype(np.int64)
|
||||
x7 = Tensor(x4_np)
|
||||
y7_np = np.array([0, 1, 3]).astype(np.int64)
|
||||
y7 = Tensor(y4_np)
|
||||
expect7 = np.equal(x7_np, y7_np)
|
||||
x8_np = np.array([0, 1, 4]).astype(np.float16)
|
||||
x8 = Tensor(x4_np)
|
||||
y8_np = np.array([0, 1, 3]).astype(np.float16)
|
||||
y8 = Tensor(y4_np)
|
||||
expect8 = np.equal(x8_np, y8_np)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
equal = NetEqual()
|
||||
|
@ -136,6 +151,15 @@ def test_equal():
|
|||
output5 = equal(x5, y5)
|
||||
assert np.all(output5.asnumpy() == expect5)
|
||||
assert output5.shape == expect5.shape
|
||||
output6 = equal(x6, y6)
|
||||
assert np.all(output6.asnumpy() == expect6)
|
||||
assert output6.shape == expect6.shape
|
||||
output7 = equal(x7, y7)
|
||||
assert np.all(output7.asnumpy() == expect7)
|
||||
assert output7.shape == expect7.shape
|
||||
output8 = equal(x8, y8)
|
||||
assert np.all(output8.asnumpy() == expect8)
|
||||
assert output8.shape == expect8.shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue