forked from mindspore-Ecosystem/mindspore
fix handle builtin func in while body
This commit is contained in:
parent
4ea499991c
commit
b0daae6137
|
@ -1184,7 +1184,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
|
|||
if (namespace_info.size() == namespace_info_size) {
|
||||
auto syntax_support = namespace_info[flag_index].cast<int32_t>();
|
||||
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[symbol_index].cast<std::string>());
|
||||
if (syntax_support == SYNTAX_UNSUPPORTED_NAMESPACE && name_id == symbol->name()) {
|
||||
if (syntax_support != SYNTAX_SUPPORTED && name_id == symbol->name()) {
|
||||
call_cnode->set_interpret(true);
|
||||
call_cnode = HandleInterpret(block, call_cnode, node);
|
||||
}
|
||||
|
|
|
@ -88,24 +88,6 @@ def test_single_while_3():
|
|||
assert res == 7
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support graph fallback feature yet')
|
||||
def test_single_while_5():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_while():
|
||||
i = 0
|
||||
while i <= 3:
|
||||
i += int(1)
|
||||
return i
|
||||
|
||||
res = control_flow_while()
|
||||
assert res == 3
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
|
@ -147,3 +147,38 @@ def test_single_while_builtin_function_max_numpy():
|
|||
res_x, res_y = control_flow_while()
|
||||
assert res_x == -5
|
||||
assert res_y == 3
|
||||
|
||||
|
||||
def test_single_while_builtin_function_first_in_while_body():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_while():
|
||||
i = 0
|
||||
while i <= 3:
|
||||
i += int(1)
|
||||
return i
|
||||
|
||||
res = control_flow_while()
|
||||
assert res == 4
|
||||
|
||||
|
||||
def test_single_while_print_in_while_body():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test fallback with control flow.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def control_flow_while():
|
||||
i = 0
|
||||
while i <= 3:
|
||||
i += 1
|
||||
print(i)
|
||||
return i
|
||||
|
||||
res = control_flow_while()
|
||||
assert res == 4
|
||||
|
|
|
@ -377,6 +377,7 @@ def test_return_const_value_with_side_effect_op():
|
|||
Description: Test side effect with returned const value.
|
||||
Expectation: Throw exception.
|
||||
"""
|
||||
|
||||
class Demo(nn.Cell):
|
||||
def construct(self, x):
|
||||
print('print here...')
|
||||
|
@ -386,7 +387,5 @@ def test_return_const_value_with_side_effect_op():
|
|||
|
||||
x = [[1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
net = Demo()
|
||||
with pytest.raises(RuntimeError) as info:
|
||||
output = net(x)
|
||||
print(output)
|
||||
assert "Side Effect Invalid" in str(info.value)
|
||||
output = net(x)
|
||||
assert output == (5, 9, 7, 8)
|
||||
|
|
Loading…
Reference in New Issue