diff --git a/mindspore/ccsrc/debug/trace_info.h b/mindspore/ccsrc/debug/trace_info.h index 85eae0e9580..7d7b7d44b39 100644 --- a/mindspore/ccsrc/debug/trace_info.h +++ b/mindspore/ccsrc/debug/trace_info.h @@ -193,6 +193,14 @@ class TraceForAfter : public TraceInfo { TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } }; +class TraceLoopEnd : public TraceInfo { + public: + explicit TraceLoopEnd(const DebugInfoPtr &info) : TraceInfo(info, "loop_end", "↓↓") {} + MS_DECLARE_PARENT(TraceLoopEnd, TraceInfo); + ~TraceLoopEnd() override = default; + TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } +}; + class TraceEquiv : public TraceInfo { public: explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc index 82d254d2fc9..c815383efc0 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/parse/parse.cc @@ -89,6 +89,9 @@ void Parser::BuildMethodMap() { stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef; stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign; stmt_method_map_["Global"] = &Parser::ParseGlobal; + stmt_method_map_["Break"] = &Parser::ParseBreak; + stmt_method_map_["Continue"] = &Parser::ParseContinue; + stmt_method_map_["Pass"] = &Parser::ParsePass; expr_method_map_["NoneType"] = &Parser::ParseNone; expr_method_map_["BinOp"] = &Parser::ParseBinOp; expr_method_map_["Name"] = &Parser::ParseName; @@ -270,6 +273,8 @@ FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::ob // insert appropriate depended items for the function block if it has a return node if (fn_block->func_graph()->get_return() != nullptr) { fn_block->InsertDependItemsBeforeReturn(); + // Skip statements after 'return' (or 'break', 'continue'). + break; } } return fn_block; @@ -966,13 +971,24 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj body_block->Mature(); header_block->ConditionalJump(condition_node, body_block, after_block); + // Parse loop body statements with loop context. + LoopContext loop_context{&loops_, header_block, nullptr}; py::object body_node = python_adapter::GetPyObjAttr(node, "body"); FunctionBlockPtr after_body = ParseStatements(body_block, body_node); if (after_body->func_graph()->get_return() == nullptr) { after_body->Jump(header_block, nullptr); } + header_block->Mature(); after_block->Mature(); + auto &end_block = loop_context.EndBlock(); + if (end_block) { + // end_block exists if we encounter 'break' in loop body. + after_block->Jump(end_block, nullptr); + end_block->Mature(); + return end_block; + } + // No 'break', no end_block. return after_block; } @@ -1049,13 +1065,24 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec body_block->Mature(); header_block->ConditionalJump(cond_apply, body_block, after_block); + // Parse loop body statements with loop context. + LoopContext loop_context{&loops_, header_block, iter2_app}; py::object body_node = python_adapter::GetPyObjAttr(node, "body"); FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node); if (after_body_block->func_graph()->get_return() == nullptr) { after_body_block->Jump(header_block, iter2_app); } + header_block->Mature(); after_block->Mature(); + auto &end_block = loop_context.EndBlock(); + if (end_block) { + // end_block exists if we encounter 'break' in loop body. + after_block->Jump(end_block, nullptr); + end_block->Mature(); + return end_block; + } + // No 'break', no end_block. return after_block; } AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) { @@ -1222,6 +1249,52 @@ FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::ob return block; } +FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) { + if (loops_.empty()) { + // Report error if loop context not set for the 'break' statement. + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + MS_LOG(EXCEPTION) << "Unexpected 'break' at " << filename << ":" << line_no; + } + // Get current loop. + Loop &loop = loops_.top(); + if (loop.end == nullptr) { + // Create end_block if it is not existed. + TraceManager::DebugTrace(std::make_shared(block->func_graph()->debug_info())); + loop.end = MakeFunctionBlock(*this); + TraceManager::EndTrace(); + } + // Jump to the end_block. + block->Jump(loop.end, nullptr); + return block; +} + +FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) { + if (loops_.empty()) { + // Report error if loop context not set for the 'continue' statement. + py::list location = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); + if (location.size() < 2) { + MS_LOG(EXCEPTION) << "List size should not be less than 2."; + } + auto filename = location[0].cast(); + auto line_no = location[1].cast(); + MS_LOG(EXCEPTION) << "Unexpected 'continue' at " << filename << ":" << line_no; + } + // Jump to the header of the loop with iterator called. + Loop &loop = loops_.top(); + block->Jump(loop.header, loop.iterator); + return block; +} + +FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) { + // We just bypass 'pass' statement. + return block; +} + void Parser::RemoveUnnecessaryPhis() { // merge all removable phis to one map; std::unordered_map removable_phis; diff --git a/mindspore/ccsrc/pipeline/parse/parse.h b/mindspore/ccsrc/pipeline/parse/parse.h index be6b09600c5..969effbd18a 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.h +++ b/mindspore/ccsrc/pipeline/parse/parse.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "utils/misc.h" #include "ir/anf.h" @@ -50,6 +51,33 @@ enum ParseStatusCode : int { class AstNodeType; class ParseAst; +// Save loop info for 'continue' and 'break' statements. +struct Loop { + // Loop header block. + FunctionBlockPtr header; + // Loop iterator node, used in 'for loop'. + AnfNodePtr iterator; + // Loop end block. + FunctionBlockPtr end; + + Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) + : header(header), iterator(iterator), end(end) {} + ~Loop() = default; +}; + +// Loop context for loop stack management. +class LoopContext { + public: + LoopContext(std::stack *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) { + loops_->emplace(header, iterator, nullptr); + } + ~LoopContext() { loops_->pop(); } + const FunctionBlockPtr &EndBlock() const { return loops_->top().end; } + + private: + std::stack *loops_; +}; + // Parser to parse python function class Parser { public: @@ -86,6 +114,12 @@ class Parser { FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node); // process assign statement FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node); + // process break statement + FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node); + // process continue statement + FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node); + // process pass statement + FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node); // process the expr and slice node method list AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node); // process a variable name @@ -216,6 +250,8 @@ class Parser { std::map stmt_method_map_; // define the function map to parse ast expression std::map expr_method_map_; + // Save current loops to support 'continue', 'break' statement. + std::stack loops_; }; // AST node type define code to ast diff --git a/tests/st/control/test_cont_break.py b/tests/st/control/test_cont_break.py new file mode 100644 index 00000000000..124ee3efa62 --- /dev/null +++ b/tests/st/control/test_cont_break.py @@ -0,0 +1,162 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_cont_break """ +import pytest +import numpy as np +from mindspore.nn import Cell +from mindspore import Tensor, Model, context + +def run_test(netclass, count, dev): + context.set_context(mode=context.GRAPH_MODE, device_target=dev) + net = netclass() + model = Model(net) + for _ in range(count): + input_np = np.random.randn(2, 3).astype(np.float32) + input_ms = Tensor(input_np) + output_np = net.construct(input_np) # run python + output_ms = model.predict(input_ms) # run graph + np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3) + +class for_loop_with_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(8): + if i > 5: + x *= 3 + break + x = x * 2 + pass + return x + +class for_loop_with_continue(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(8): + if i > 5: + x *= 3 + continue + x = x * 2 + return x + +class for_loop_with_cont_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(8): + if i < 3: + i *= 2 + continue + if i > 5: + x *= 3 + break + x *= 2 + x = x * 2 + pass + return x + +class for_nested_loop_with_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(3): + for j in range(5): + if j > 3: + x *= 2 + break + x = x * 1.5 + return x + +class while_with_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + x *= 2 + break + x = x * 1.5 + i += 1 + return x + +class while_with_continue(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + x *= 2 + i += 1 + continue + x = x * 1.5 + i += 1 + return x + +class while_for_nested(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + for j in range(3): + if j > 1: + break + x *= 2 + i += 1 + continue + x = x * 1.5 + i += 1 + return x + +class pass_branch(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + pass + else: + x = x * 1.5 + i += 1 + return x + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cont_break(): + count = 20 + dev = 'CPU' + run_test(for_loop_with_break, count, dev) + run_test(for_loop_with_continue, count, dev) + run_test(for_loop_with_cont_break, count, dev) + run_test(for_nested_loop_with_break, count, dev) + run_test(while_with_break, count, dev) + run_test(while_with_continue, count, dev) + run_test(while_for_nested, count, dev) + run_test(pass_branch, count, dev) + diff --git a/tests/ut/python/pipeline/parse/test_cont_break.py b/tests/ut/python/pipeline/parse/test_cont_break.py new file mode 100644 index 00000000000..d556981a7b8 --- /dev/null +++ b/tests/ut/python/pipeline/parse/test_cont_break.py @@ -0,0 +1,180 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_cont_break """ +import numpy as np +from mindspore.nn import Cell +from mindspore import Tensor, Model, context +from ...ut_filter import non_graph_engine + +def run_test(netclass, count): + context.set_context(mode=context.GRAPH_MODE) + net = netclass() + model = Model(net) + for _ in range(count): + input_np = np.random.randn(2, 3).astype(np.float32) + input_ms = Tensor(input_np) + output_np = net.construct(input_np) # run python + output_ms = model.predict(input_ms) # run graph + assert np.shape(output_np) == np.shape(output_ms.asnumpy()) + # Disable equal assert because UT in CI use fake backend. + # np.testing.assert_array_almost_equal(output_np, output_ms.asnumpy(), decimal=3) + +class for_loop_with_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(8): + if i > 5: + x *= 3 + break + x = x * 2 + pass + return x + +@non_graph_engine +def test_for_loop_with_break(): + run_test(for_loop_with_break, 10) + +class for_loop_with_continue(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(8): + if i > 5: + x *= 3 + continue + x = x * 2 + return x + +@non_graph_engine +def test_for_loop_with_continue(): + run_test(for_loop_with_continue, 10) + +class for_loop_with_cont_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(8): + if i < 3: + i *= 2 + continue + if i > 5: + x *= 3 + break + x *= 2 + x = x * 2 + pass + return x + +@non_graph_engine +def test_for_loop_with_cont_break(): + run_test(for_loop_with_cont_break, 10) + +class for_nested_loop_with_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + for i in range(3): + for j in range(5): + if j > 3: + x *= 2 + break + x = x * 1.5 + return x + +@non_graph_engine +def test_for_nested_loop_with_break(): + run_test(for_nested_loop_with_break, 10) + +class while_with_break(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + x *= 2 + break + x = x * 1.5 + i += 1 + return x + +@non_graph_engine +def test_while_with_break(): + run_test(while_with_break, 10) + +class while_with_continue(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + x *= 2 + i += 1 + continue + x = x * 1.5 + i += 1 + return x + +@non_graph_engine +def test_while_with_continue(): + run_test(while_with_continue, 10) + +class while_for_nested(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + for j in range(3): + if j > 1: + break + x *= 2 + i += 1 + continue + x = x * 1.5 + i += 1 + return x + +@non_graph_engine +def test_while_for_nested(): + run_test(while_for_nested, 10) + +class pass_branch(Cell): + def __init__(self): + super().__init__() + + def construct(self, x): + i = 0 + while i < 5: + if i > 3: + pass + else: + x = x * 1.5 + i += 1 + return x + +@non_graph_engine +def test_pass_branch(): + run_test(pass_branch, 10)