forked from mindspore-Ecosystem/mindspore
Support 'break', 'continue' and 'pass'
To handle 'break' and 'continue' statement, a loop context is pushed to a stack before we parse the loop body, and pop it after body parsed. When a 'break', 'continue' statement is encountered, we retrieve current loop contex from the stack, and let the current block jump to the end block or header block; For 'break' statement, we added an extra 'end_block' follow the 'after_block', because 'after_block' is called from a ContionalJump in 'header_block', it can not be set as jump target from other place. to support 'break', we let loop body jump to the 'end_block' at the 'break' point. and 'after_block' maybe a good place to handle loop 'else' clause in the future. Handle 'pass' is simple, just bypass it when doing parse.
This commit is contained in:
parent
635acb6c27
commit
33fa90efc9
|
@ -193,6 +193,14 @@ class TraceForAfter : public TraceInfo {
|
|||
TraceInfoPtr clone() override { return std::make_shared<TraceForAfter>(*shared_from_base<TraceForAfter>()); }
|
||||
};
|
||||
|
||||
class TraceLoopEnd : public TraceInfo {
|
||||
public:
|
||||
explicit TraceLoopEnd(const DebugInfoPtr &info) : TraceInfo(info, "loop_end", "↓↓") {}
|
||||
MS_DECLARE_PARENT(TraceLoopEnd, TraceInfo);
|
||||
~TraceLoopEnd() override = default;
|
||||
TraceInfoPtr clone() override { return std::make_shared<TraceLoopEnd>(*shared_from_base<TraceLoopEnd>()); }
|
||||
};
|
||||
|
||||
class TraceEquiv : public TraceInfo {
|
||||
public:
|
||||
explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {}
|
||||
|
|
|
@ -89,6 +89,9 @@ void Parser::BuildMethodMap() {
|
|||
stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
|
||||
stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
|
||||
stmt_method_map_["Global"] = &Parser::ParseGlobal;
|
||||
stmt_method_map_["Break"] = &Parser::ParseBreak;
|
||||
stmt_method_map_["Continue"] = &Parser::ParseContinue;
|
||||
stmt_method_map_["Pass"] = &Parser::ParsePass;
|
||||
expr_method_map_["NoneType"] = &Parser::ParseNone;
|
||||
expr_method_map_["BinOp"] = &Parser::ParseBinOp;
|
||||
expr_method_map_["Name"] = &Parser::ParseName;
|
||||
|
@ -270,6 +273,8 @@ FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::ob
|
|||
// insert appropriate depended items for the function block if it has a return node
|
||||
if (fn_block->func_graph()->get_return() != nullptr) {
|
||||
fn_block->InsertDependItemsBeforeReturn();
|
||||
// Skip statements after 'return' (or 'break', 'continue').
|
||||
break;
|
||||
}
|
||||
}
|
||||
return fn_block;
|
||||
|
@ -966,13 +971,24 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
|
|||
body_block->Mature();
|
||||
header_block->ConditionalJump(condition_node, body_block, after_block);
|
||||
|
||||
// Parse loop body statements with loop context.
|
||||
LoopContext loop_context{&loops_, header_block, nullptr};
|
||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
|
||||
if (after_body->func_graph()->get_return() == nullptr) {
|
||||
after_body->Jump(header_block, nullptr);
|
||||
}
|
||||
|
||||
header_block->Mature();
|
||||
after_block->Mature();
|
||||
auto &end_block = loop_context.EndBlock();
|
||||
if (end_block) {
|
||||
// end_block exists if we encounter 'break' in loop body.
|
||||
after_block->Jump(end_block, nullptr);
|
||||
end_block->Mature();
|
||||
return end_block;
|
||||
}
|
||||
// No 'break', no end_block.
|
||||
return after_block;
|
||||
}
|
||||
|
||||
|
@ -1049,13 +1065,24 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
|
|||
body_block->Mature();
|
||||
header_block->ConditionalJump(cond_apply, body_block, after_block);
|
||||
|
||||
// Parse loop body statements with loop context.
|
||||
LoopContext loop_context{&loops_, header_block, iter2_app};
|
||||
py::object body_node = python_adapter::GetPyObjAttr(node, "body");
|
||||
FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
|
||||
if (after_body_block->func_graph()->get_return() == nullptr) {
|
||||
after_body_block->Jump(header_block, iter2_app);
|
||||
}
|
||||
|
||||
header_block->Mature();
|
||||
after_block->Mature();
|
||||
auto &end_block = loop_context.EndBlock();
|
||||
if (end_block) {
|
||||
// end_block exists if we encounter 'break' in loop body.
|
||||
after_block->Jump(end_block, nullptr);
|
||||
end_block->Mature();
|
||||
return end_block;
|
||||
}
|
||||
// No 'break', no end_block.
|
||||
return after_block;
|
||||
}
|
||||
AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
|
||||
|
@ -1222,6 +1249,52 @@ FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::ob
|
|||
return block;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
|
||||
if (loops_.empty()) {
|
||||
// Report error if loop context not set for the 'break' statement.
|
||||
py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
if (location.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "List size should not be less than 2.";
|
||||
}
|
||||
auto filename = location[0].cast<std::string>();
|
||||
auto line_no = location[1].cast<int>();
|
||||
MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no;
|
||||
}
|
||||
// Get current loop.
|
||||
Loop &loop = loops_.top();
|
||||
if (loop.end == nullptr) {
|
||||
// Create end_block if it is not existed.
|
||||
TraceManager::DebugTrace(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
|
||||
loop.end = MakeFunctionBlock(*this);
|
||||
TraceManager::EndTrace();
|
||||
}
|
||||
// Jump to the end_block.
|
||||
block->Jump(loop.end, nullptr);
|
||||
return block;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
|
||||
if (loops_.empty()) {
|
||||
// Report error if loop context not set for the 'continue' statement.
|
||||
py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
|
||||
if (location.size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "List size should not be less than 2.";
|
||||
}
|
||||
auto filename = location[0].cast<std::string>();
|
||||
auto line_no = location[1].cast<int>();
|
||||
MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no;
|
||||
}
|
||||
// Jump to the header of the loop with iterator called.
|
||||
Loop &loop = loops_.top();
|
||||
block->Jump(loop.header, loop.iterator);
|
||||
return block;
|
||||
}
|
||||
|
||||
FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
|
||||
// We just bypass 'pass' statement.
|
||||
return block;
|
||||
}
|
||||
|
||||
void Parser::RemoveUnnecessaryPhis() {
|
||||
// merge all removable phis to one map;
|
||||
std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include <string>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <stack>
|
||||
#include <memory>
|
||||
#include "utils/misc.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -50,6 +51,33 @@ enum ParseStatusCode : int {
|
|||
class AstNodeType;
|
||||
class ParseAst;
|
||||
|
||||
// Save loop info for 'continue' and 'break' statements.
|
||||
struct Loop {
|
||||
// Loop header block.
|
||||
FunctionBlockPtr header;
|
||||
// Loop iterator node, used in 'for loop'.
|
||||
AnfNodePtr iterator;
|
||||
// Loop end block.
|
||||
FunctionBlockPtr end;
|
||||
|
||||
Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end)
|
||||
: header(header), iterator(iterator), end(end) {}
|
||||
~Loop() = default;
|
||||
};
|
||||
|
||||
// Loop context for loop stack management.
|
||||
class LoopContext {
|
||||
public:
|
||||
LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) {
|
||||
loops_->emplace(header, iterator, nullptr);
|
||||
}
|
||||
~LoopContext() { loops_->pop(); }
|
||||
const FunctionBlockPtr &EndBlock() const { return loops_->top().end; }
|
||||
|
||||
private:
|
||||
std::stack<Loop> *loops_;
|
||||
};
|
||||
|
||||
// Parser to parse python function
|
||||
class Parser {
|
||||
public:
|
||||
|
@ -86,6 +114,12 @@ class Parser {
|
|||
FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process assign statement
|
||||
FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process break statement
|
||||
FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process continue statement
|
||||
FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process pass statement
|
||||
FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process the expr and slice node method list
|
||||
AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a variable name
|
||||
|
@ -216,6 +250,8 @@ class Parser {
|
|||
std::map<std::string, pStmtFunc> stmt_method_map_;
|
||||
// define the function map to parse ast expression
|
||||
std::map<std::string, pExprFunc> expr_method_map_;
|
||||
// Save current loops to support 'continue', 'break' statement.
|
||||
std::stack<Loop> loops_;
|
||||
};
|
||||
|
||||
// AST node type define code to ast
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_cont_break """
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import Tensor, Model, context
|
||||
|
||||
def run_test(netclass, count, dev):
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=dev)
|
||||
net = netclass()
|
||||
model = Model(net)
|
||||
for _ in range(count):
|
||||
input_np = np.random.randn(2, 3).astype(np.float32)
|
||||
input_ms = Tensor(input_np)
|
||||
output_np = net.construct(input_np) # run python
|
||||
output_ms = model.predict(input_ms) # run graph
|
||||
np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3)
|
||||
|
||||
class for_loop_with_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(8):
|
||||
if i > 5:
|
||||
x *= 3
|
||||
break
|
||||
x = x * 2
|
||||
pass
|
||||
return x
|
||||
|
||||
class for_loop_with_continue(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(8):
|
||||
if i > 5:
|
||||
x *= 3
|
||||
continue
|
||||
x = x * 2
|
||||
return x
|
||||
|
||||
class for_loop_with_cont_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(8):
|
||||
if i < 3:
|
||||
i *= 2
|
||||
continue
|
||||
if i > 5:
|
||||
x *= 3
|
||||
break
|
||||
x *= 2
|
||||
x = x * 2
|
||||
pass
|
||||
return x
|
||||
|
||||
class for_nested_loop_with_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(3):
|
||||
for j in range(5):
|
||||
if j > 3:
|
||||
x *= 2
|
||||
break
|
||||
x = x * 1.5
|
||||
return x
|
||||
|
||||
class while_with_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
x *= 2
|
||||
break
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
class while_with_continue(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
x *= 2
|
||||
i += 1
|
||||
continue
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
class while_for_nested(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
for j in range(3):
|
||||
if j > 1:
|
||||
break
|
||||
x *= 2
|
||||
i += 1
|
||||
continue
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
class pass_branch(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
pass
|
||||
else:
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_cont_break():
|
||||
count = 20
|
||||
dev = 'CPU'
|
||||
run_test(for_loop_with_break, count, dev)
|
||||
run_test(for_loop_with_continue, count, dev)
|
||||
run_test(for_loop_with_cont_break, count, dev)
|
||||
run_test(for_nested_loop_with_break, count, dev)
|
||||
run_test(while_with_break, count, dev)
|
||||
run_test(while_with_continue, count, dev)
|
||||
run_test(while_for_nested, count, dev)
|
||||
run_test(pass_branch, count, dev)
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_cont_break """
|
||||
import numpy as np
|
||||
from mindspore.nn import Cell
|
||||
from mindspore import Tensor, Model, context
|
||||
from ...ut_filter import non_graph_engine
|
||||
|
||||
def run_test(netclass, count):
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = netclass()
|
||||
model = Model(net)
|
||||
for _ in range(count):
|
||||
input_np = np.random.randn(2, 3).astype(np.float32)
|
||||
input_ms = Tensor(input_np)
|
||||
output_np = net.construct(input_np) # run python
|
||||
output_ms = model.predict(input_ms) # run graph
|
||||
assert np.shape(output_np) == np.shape(output_ms.asnumpy())
|
||||
# Disable equal assert because UT in CI use fake backend.
|
||||
# np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3)
|
||||
|
||||
class for_loop_with_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(8):
|
||||
if i > 5:
|
||||
x *= 3
|
||||
break
|
||||
x = x * 2
|
||||
pass
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_for_loop_with_break():
|
||||
run_test(for_loop_with_break, 10)
|
||||
|
||||
class for_loop_with_continue(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(8):
|
||||
if i > 5:
|
||||
x *= 3
|
||||
continue
|
||||
x = x * 2
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_for_loop_with_continue():
|
||||
run_test(for_loop_with_continue, 10)
|
||||
|
||||
class for_loop_with_cont_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(8):
|
||||
if i < 3:
|
||||
i *= 2
|
||||
continue
|
||||
if i > 5:
|
||||
x *= 3
|
||||
break
|
||||
x *= 2
|
||||
x = x * 2
|
||||
pass
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_for_loop_with_cont_break():
|
||||
run_test(for_loop_with_cont_break, 10)
|
||||
|
||||
class for_nested_loop_with_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
for i in range(3):
|
||||
for j in range(5):
|
||||
if j > 3:
|
||||
x *= 2
|
||||
break
|
||||
x = x * 1.5
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_for_nested_loop_with_break():
|
||||
run_test(for_nested_loop_with_break, 10)
|
||||
|
||||
class while_with_break(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
x *= 2
|
||||
break
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_while_with_break():
|
||||
run_test(while_with_break, 10)
|
||||
|
||||
class while_with_continue(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
x *= 2
|
||||
i += 1
|
||||
continue
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_while_with_continue():
|
||||
run_test(while_with_continue, 10)
|
||||
|
||||
class while_for_nested(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
for j in range(3):
|
||||
if j > 1:
|
||||
break
|
||||
x *= 2
|
||||
i += 1
|
||||
continue
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_while_for_nested():
|
||||
run_test(while_for_nested, 10)
|
||||
|
||||
class pass_branch(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def construct(self, x):
|
||||
i = 0
|
||||
while i < 5:
|
||||
if i > 3:
|
||||
pass
|
||||
else:
|
||||
x = x * 1.5
|
||||
i += 1
|
||||
return x
|
||||
|
||||
@non_graph_engine
|
||||
def test_pass_branch():
|
||||
run_test(pass_branch, 10)
|
Loading…
Reference in New Issue