forked from mindspore-Ecosystem/mindspore
make gpu equal op support int32
This commit is contained in:
parent
6d7e352524
commit
7479fb24a0
|
@ -22,6 +22,7 @@ equal_op_info = AkgGpuRegOp("Equal") \
|
|||
.output(0, "output") \
|
||||
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
|
|
|
@ -60,6 +60,11 @@ def test_equal():
|
|||
y1_np = np.array([0, 1, -3]).astype(np.float32)
|
||||
y1 = Tensor(y1_np)
|
||||
expect1 = np.equal(x1_np, y1_np)
|
||||
x2_np = np.array([0, 1, 3]).astype(np.int32)
|
||||
x2 = Tensor(x2_np)
|
||||
y2_np = np.array([0, 1, -3]).astype(np.int32)
|
||||
y2 = Tensor(y2_np)
|
||||
expect2 = np.equal(x2_np, y2_np)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
equal = NetEqual()
|
||||
|
@ -69,6 +74,9 @@ def test_equal():
|
|||
output1 = equal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = equal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
equal = NetEqual()
|
||||
|
@ -78,6 +86,9 @@ def test_equal():
|
|||
output1 = equal(x1, y1)
|
||||
assert np.all(output1.asnumpy() == expect1)
|
||||
assert output1.shape == expect1.shape
|
||||
output2 = equal(x2, y2)
|
||||
assert np.all(output2.asnumpy() == expect2)
|
||||
assert output2.shape == expect2.shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
Loading…
Reference in New Issue