add test
This commit is contained in:
xiongkun 2023-01-09 09:59:08 +08:00
parent fdc60712df
commit b8b3018125
1 changed files with 9 additions and 15 deletions

View File

@ -68,19 +68,6 @@ def test_reluv2(dtype, mode):
assert np.allclose(dx.asnumpy(), expect_dx)
class ReluForwardNet(nn.Cell):
"""ReluForwardNet"""
def __init__(self):
"""init"""
super(ReluForwardNet, self).__init__()
self.relu = P.ReLU()
def construct(self, x):
"""construct"""
y = self.relu(x)
return y
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -97,13 +84,20 @@ def test_reluv2_uint(dtype, mode):
x = Tensor(np.array([[[[1, 1, 10],
[1, 1, 1],
[10, 1, 1]]]]).astype(dtype))
dy = Tensor(np.array([[[[1, 0, 3],
[0, 1, 0],
[2, 1, 1]]]]).astype(dtype))
expect_y = np.array([[[[1, 1, 10],
[1, 1, 1],
[10, 1, 1.]]]]).astype(dtype)
net = ReluForwardNet()
y = net(Tensor(x))
expect_dx = np.array([[[[1, 0, 3],
[0, 1, 0],
[2, 1, 1]]]]).astype(dtype)
net = ReluNet()
y, dx = net(Tensor(x), Tensor(dy))
assert np.allclose(y.asnumpy(), expect_y)
assert np.allclose(dx.asnumpy(), expect_dx)
class AddReluNet(nn.Cell):