!48413 修复GPU算子Hswish的st用例

Merge pull request !48413 from hedongdong/hswish
This commit is contained in:
i-robot 2023-02-06 01:46:09 +00:00 committed by Gitee
commit ea725736a6
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 8 additions and 7 deletions

View File

@ -47,28 +47,28 @@ def expect_hswish_forward_result(x):
def expect_hswish_backward_result(x, dout): def expect_hswish_backward_result(x, dout):
return np.where(x <= -3, 0, np.where(x >= 3, 1, x / 3 + 0.5)) * dout return np.where(x <= -3, 0, np.where(x >= 3, 1, (x * 2 + 3) / 6)) * dout
def judge_result_correct(result, expect): def judge_result_correct(result, expect, loss):
assert result.dtype == expect.dtype assert result.dtype == expect.dtype
assert result.shape == expect.shape assert result.shape == expect.shape
assert np.allclose(result, expect) assert np.allclose(result, expect, loss, loss)
def generate_test_cases(np_type, mode): def generate_test_cases(np_type, mode, loss):
context.set_context(mode=mode, device_target="GPU") context.set_context(mode=mode, device_target="GPU")
x = np.array([-1, -2, 0, 4, 5]).astype(np_type) x = np.array([-1, -2, 0, 4, 5]).astype(np_type)
net = Net() net = Net()
output = net(Tensor(x)) output = net(Tensor(x))
expect = expect_hswish_forward_result(x) expect = expect_hswish_forward_result(x)
judge_result_correct(output.asnumpy(), expect) judge_result_correct(output.asnumpy(), expect, loss)
sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(np_type) sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(np_type)
backward_net = Grad(Net()) backward_net = Grad(Net())
output = backward_net(Tensor(x), Tensor(sens)) output = backward_net(Tensor(x), Tensor(sens))
expect = expect_hswish_backward_result(x, sens) expect = expect_hswish_backward_result(x, sens)
judge_result_correct(output[0].asnumpy(), expect) judge_result_correct(output[0].asnumpy(), expect, loss)
@pytest.mark.level1 @pytest.mark.level1
@ -79,4 +79,5 @@ def test_hswish_forward_and_backward():
dtypes = (np.float32, np.float16) dtypes = (np.float32, np.float16)
for mode in modes: for mode in modes:
for dtype in dtypes: for dtype in dtypes:
generate_test_cases(dtype, mode) loss = 1e-4 if (dtype == np.float32) else 1e-3
generate_test_cases(dtype, mode, loss)