forked from mindspore-Ecosystem/mindspore
!48413 修复GPU算子Hswish的st用例
Merge pull request !48413 from hedongdong/hswish
This commit is contained in:
commit
ea725736a6
|
@ -47,28 +47,28 @@ def expect_hswish_forward_result(x):
|
|||
|
||||
|
||||
def expect_hswish_backward_result(x, dout):
|
||||
return np.where(x <= -3, 0, np.where(x >= 3, 1, x / 3 + 0.5)) * dout
|
||||
return np.where(x <= -3, 0, np.where(x >= 3, 1, (x * 2 + 3) / 6)) * dout
|
||||
|
||||
|
||||
def judge_result_correct(result, expect):
|
||||
def judge_result_correct(result, expect, loss):
|
||||
assert result.dtype == expect.dtype
|
||||
assert result.shape == expect.shape
|
||||
assert np.allclose(result, expect)
|
||||
assert np.allclose(result, expect, loss, loss)
|
||||
|
||||
|
||||
def generate_test_cases(np_type, mode):
|
||||
def generate_test_cases(np_type, mode, loss):
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
x = np.array([-1, -2, 0, 4, 5]).astype(np_type)
|
||||
net = Net()
|
||||
output = net(Tensor(x))
|
||||
expect = expect_hswish_forward_result(x)
|
||||
judge_result_correct(output.asnumpy(), expect)
|
||||
judge_result_correct(output.asnumpy(), expect, loss)
|
||||
|
||||
sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(np_type)
|
||||
backward_net = Grad(Net())
|
||||
output = backward_net(Tensor(x), Tensor(sens))
|
||||
expect = expect_hswish_backward_result(x, sens)
|
||||
judge_result_correct(output[0].asnumpy(), expect)
|
||||
judge_result_correct(output[0].asnumpy(), expect, loss)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -79,4 +79,5 @@ def test_hswish_forward_and_backward():
|
|||
dtypes = (np.float32, np.float16)
|
||||
for mode in modes:
|
||||
for dtype in dtypes:
|
||||
generate_test_cases(dtype, mode)
|
||||
loss = 1e-4 if (dtype == np.float32) else 1e-3
|
||||
generate_test_cases(dtype, mode, loss)
|
||||
|
|
Loading…
Reference in New Issue