forked from mindspore-Ecosystem/mindspore
[control flow]update st testcases for while bp
This commit is contained in:
parent
ce8d2f87f0
commit
971c2d3c6d
|
@ -104,6 +104,10 @@ def test_while_with_const_param_grad():
|
|||
assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
|
||||
assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_variable_grad():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -166,7 +170,10 @@ def test_while_with_param_forward():
|
|||
expect = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.int32)
|
||||
assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_endless_case():
|
||||
"""endless case when optimization"""
|
||||
class MyWhileNet(nn.Cell):
|
||||
|
@ -235,6 +242,10 @@ def test_while_with_param_grad():
|
|||
expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
|
||||
assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_param_forward_with_const_branch():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -266,7 +277,10 @@ def test_while_with_param_forward_with_const_branch():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_opt_endless():
|
||||
"""endless during optimization case"""
|
||||
class MyWhileNet(nn.Cell):
|
||||
|
@ -307,6 +321,12 @@ def test_while_opt_endless():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not supported yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_no_while_call():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -336,7 +356,10 @@ def test_no_while_call():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_param_grad_with_const_branch():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -377,6 +400,11 @@ def test_while_with_param_grad_with_const_branch():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.skip(reason="not supported yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_while_with_param_grad_with_const_branch():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -420,6 +448,10 @@ def test_for_while_with_param_grad_with_const_branch():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_while_with_param_grad_basic():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -460,6 +492,10 @@ def test_for_while_with_param_grad_basic():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_while_with_param_grad_normal():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -500,6 +536,10 @@ def test_for_while_with_param_grad_normal():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_param_basic_grad():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -537,6 +577,10 @@ def test_while_with_param_basic_grad():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_param_basic_grad_mul():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -574,6 +618,10 @@ def test_while_with_param_basic_grad_mul():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_param_basic_grad_two():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -613,6 +661,10 @@ def test_while_with_param_basic_grad_two():
|
|||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_param_basic_grad_three():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -654,6 +706,10 @@ def test_while_with_param_basic_grad_three():
|
|||
assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_if_with_param_grad():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -694,6 +750,11 @@ def test_while_if_with_param_grad():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.skip(reason="not supported yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_while_with_param_grad_not_enter_while():
|
||||
class MyWhileNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -730,6 +791,10 @@ def test_while_with_param_grad_not_enter_while():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_param_if_by_if_forward():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -762,7 +827,10 @@ def test_with_param_if_by_if_forward():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_param_if_by_if_grad_inputs():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -801,6 +869,10 @@ def test_with_param_if_by_if_grad_inputs():
|
|||
assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
|
||||
assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_param_if_by_if_grad_parameter():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -838,6 +910,10 @@ def test_with_param_if_by_if_grad_parameter():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_with_param_if_by_if_grad_param_excute_null():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -873,6 +949,10 @@ def test_with_param_if_by_if_grad_param_excute_null():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_return_inside_grad():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -910,6 +990,10 @@ def test_if_by_if_return_inside_grad():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_forward():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -948,7 +1032,10 @@ def test_if_by_if_forward():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_forward_control_tuple_switch():
|
||||
"""tuple_get from switch op will generate new switch inside to eliminate tuple_get"""
|
||||
class Branch3Net(nn.Cell):
|
||||
|
@ -1012,9 +1099,10 @@ def test_if_by_if_forward_control_tuple_switch():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_forward_control_inside_net():
|
||||
class Branch3Net(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -1077,8 +1165,10 @@ def test_if_by_if_forward_control_inside_net():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_forward_use_namespace():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -1117,7 +1207,10 @@ def test_if_by_if_forward_use_namespace():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_forward_use_global_op():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -1160,7 +1253,10 @@ def test_if_by_if_forward_use_global_op():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_with_if_by_if_forward():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -1190,8 +1286,10 @@ def test_for_with_if_by_if_forward():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_for_with_if_by_if_forward_namespace():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -1224,7 +1322,10 @@ def test_for_with_if_by_if_forward_namespace():
|
|||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_forward_const_branch_inner():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
@ -1267,9 +1368,10 @@ def test_if_by_if_forward_const_branch_inner():
|
|||
pynative_output = net(idx, end, x)
|
||||
assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_if_by_if_forward_all_const_branch():
|
||||
class MyIfByIfNet(nn.Cell):
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue