!5754 adding int32 support for Greater gpu kernel
Merge pull request !5754 from Peilin/add-type-support-for-greater
This commit is contained in:
commit
87d0d0bf8d
|
@ -109,6 +109,9 @@ MS_REG_GPU_KERNEL_ONE(
|
|||
BroadcastOpGpuKernel, half)
|
||||
|
||||
// int32
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Greater, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
|
||||
BroadcastOpGpuKernel, int)
|
||||
|
|
|
@ -43,6 +43,9 @@ def test_nobroadcast():
|
|||
output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = x1_np > x2_np
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
output_ms = P.Greater()(Tensor(x1_np_int32), Tensor(x2_np_int32))
|
||||
output_np = x1_np_int32 > x2_np_int32
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = x1_np < x2_np
|
||||
|
@ -132,6 +135,9 @@ def test_broadcast():
|
|||
output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = x1_np > x2_np
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
output_ms = P.Greater()(Tensor(x1_np_int32), Tensor(x2_np_int32))
|
||||
output_np = x1_np_int32 > x2_np_int32
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = x1_np < x2_np
|
||||
|
@ -175,6 +181,9 @@ def test_broadcast_diff_dims():
|
|||
output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = np.maximum(x1_np, x2_np)
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
output_ms = P.Greater()(Tensor(x1_np_int32), Tensor(x2_np_int32))
|
||||
output_np = x1_np_int32 > x2_np_int32
|
||||
assert np.allclose(output_ms.asnumpy(), output_np)
|
||||
|
||||
output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
|
||||
output_np = x1_np > x2_np
|
||||
|
|
Loading…
Reference in New Issue