forked from mindspore-Ecosystem/mindspore
!48413 修复GPU算子Hswish的st用例
Merge pull request !48413 from hedongdong/hswish
This commit is contained in:
commit
ea725736a6
|
@ -47,28 +47,28 @@ def expect_hswish_forward_result(x):
|
||||||
|
|
||||||
|
|
||||||
def expect_hswish_backward_result(x, dout):
|
def expect_hswish_backward_result(x, dout):
|
||||||
return np.where(x <= -3, 0, np.where(x >= 3, 1, x / 3 + 0.5)) * dout
|
return np.where(x <= -3, 0, np.where(x >= 3, 1, (x * 2 + 3) / 6)) * dout
|
||||||
|
|
||||||
|
|
||||||
def judge_result_correct(result, expect):
|
def judge_result_correct(result, expect, loss):
|
||||||
assert result.dtype == expect.dtype
|
assert result.dtype == expect.dtype
|
||||||
assert result.shape == expect.shape
|
assert result.shape == expect.shape
|
||||||
assert np.allclose(result, expect)
|
assert np.allclose(result, expect, loss, loss)
|
||||||
|
|
||||||
|
|
||||||
def generate_test_cases(np_type, mode):
|
def generate_test_cases(np_type, mode, loss):
|
||||||
context.set_context(mode=mode, device_target="GPU")
|
context.set_context(mode=mode, device_target="GPU")
|
||||||
x = np.array([-1, -2, 0, 4, 5]).astype(np_type)
|
x = np.array([-1, -2, 0, 4, 5]).astype(np_type)
|
||||||
net = Net()
|
net = Net()
|
||||||
output = net(Tensor(x))
|
output = net(Tensor(x))
|
||||||
expect = expect_hswish_forward_result(x)
|
expect = expect_hswish_forward_result(x)
|
||||||
judge_result_correct(output.asnumpy(), expect)
|
judge_result_correct(output.asnumpy(), expect, loss)
|
||||||
|
|
||||||
sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(np_type)
|
sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(np_type)
|
||||||
backward_net = Grad(Net())
|
backward_net = Grad(Net())
|
||||||
output = backward_net(Tensor(x), Tensor(sens))
|
output = backward_net(Tensor(x), Tensor(sens))
|
||||||
expect = expect_hswish_backward_result(x, sens)
|
expect = expect_hswish_backward_result(x, sens)
|
||||||
judge_result_correct(output[0].asnumpy(), expect)
|
judge_result_correct(output[0].asnumpy(), expect, loss)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level1
|
@pytest.mark.level1
|
||||||
|
@ -79,4 +79,5 @@ def test_hswish_forward_and_backward():
|
||||||
dtypes = (np.float32, np.float16)
|
dtypes = (np.float32, np.float16)
|
||||||
for mode in modes:
|
for mode in modes:
|
||||||
for dtype in dtypes:
|
for dtype in dtypes:
|
||||||
generate_test_cases(dtype, mode)
|
loss = 1e-4 if (dtype == np.float32) else 1e-3
|
||||||
|
generate_test_cases(dtype, mode, loss)
|
||||||
|
|
Loading…
Reference in New Issue