fix handle builtin func in while body

This commit is contained in:
huanghui 2022-06-02 15:03:14 +08:00
parent 4ea499991c
commit b0daae6137
4 changed files with 39 additions and 23 deletions

View File

@ -1184,7 +1184,7 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
if (namespace_info.size() == namespace_info_size) {
auto syntax_support = namespace_info[flag_index].cast<int32_t>();
SymbolPtr symbol = std::make_shared<Symbol>(namespace_info[symbol_index].cast<std::string>());
if (syntax_support == SYNTAX_UNSUPPORTED_NAMESPACE && name_id == symbol->name()) {
if (syntax_support != SYNTAX_SUPPORTED && name_id == symbol->name()) {
call_cnode->set_interpret(true);
call_cnode = HandleInterpret(block, call_cnode, node);
}

View File

@ -88,24 +88,6 @@ def test_single_while_3():
assert res == 7
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_single_while_5():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_while():
i = 0
while i <= 3:
i += int(1)
return i
res = control_flow_while()
assert res == 3
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training

View File

@ -147,3 +147,38 @@ def test_single_while_builtin_function_max_numpy():
res_x, res_y = control_flow_while()
assert res_x == -5
assert res_y == 3
def test_single_while_builtin_function_first_in_while_body():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_while():
i = 0
while i <= 3:
i += int(1)
return i
res = control_flow_while()
assert res == 4
def test_single_while_print_in_while_body():
"""
Feature: JIT Fallback
Description: Test fallback with control flow.
Expectation: No exception.
"""
@ms_function
def control_flow_while():
i = 0
while i <= 3:
i += 1
print(i)
return i
res = control_flow_while()
assert res == 4

View File

@ -377,6 +377,7 @@ def test_return_const_value_with_side_effect_op():
Description: Test side effect with returned const value.
Expectation: Throw exception.
"""
class Demo(nn.Cell):
def construct(self, x):
print('print here...')
@ -386,7 +387,5 @@ def test_return_const_value_with_side_effect_op():
x = [[1, 2, 3, 4], [5, 6, 7, 8]]
net = Demo()
with pytest.raises(RuntimeError) as info:
output = net(x)
print(output)
assert "Side Effect Invalid" in str(info.value)
output = net(x)
assert output == (5, 9, 7, 8)