support if break else break situation

add trace guard
This commit is contained in:
chenfei 2021-09-06 19:47:13 +08:00
parent 790bfeb292
commit f6fe46e469
3 changed files with 43 additions and 10 deletions

View File

@ -96,6 +96,7 @@ void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr
// Read variable from predecessors
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) {
MS_LOG(DEBUG) << "Read begin, var: " << var << ", block id: " << func_graph_->debug_info()->debug_id();
// Get var node if it is found
auto found = vars_.find(var);
if (found != vars_.end()) {
@ -169,6 +170,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
// Make a resolve node for symbol string
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
MS_LOG(DEBUG) << "value: " << value;
if (value.compare(0, strlen("self"), "self") == 0) {
auto start = value.find_first_of('.') + 1;
if (start >= value.size()) {
@ -248,12 +250,11 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) {
AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) {
AnfNodePtr arg_node = nullptr;
MS_LOG(DEBUG) << "Prev_blocks size: " << prev_blocks_.size();
for (auto &prev : prev_blocks_) {
MS_EXCEPTION_IF_NULL(prev);
AnfNodePtr temp_node = prev->ReadVariable(var);
MS_EXCEPTION_IF_NULL(temp_node);
MS_LOG(DEBUG) << "graph " << (prev->func_graph_ ? prev->func_graph_->ToString() : "FG(Null)") << " phi "
<< (phi ? phi->ToString() : "null") << " for var " << var << " is " << temp_node->DebugString();
if (temp_node != phi) {
if (arg_node == nullptr) {
arg_node = temp_node;
@ -362,7 +363,14 @@ CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) {
// Perform a jump from this block to target block
void FunctionBlock::Jump(const FunctionBlockPtr &target_block, const std::vector<AnfNodePtr> &args) {
MS_LOG(DEBUG) << "Jump from " << func_graph_->debug_info()->debug_id() << " to "
<< target_block->func_graph()->debug_info()->debug_id();
MS_EXCEPTION_IF_NULL(target_block);
if (is_dead_block_) {
MS_LOG(DEBUG) << "Dead code block should not jump to other block! Block id:"
<< func_graph_->debug_info()->debug_id();
return;
}
if (func_graph_->get_return() != nullptr) {
MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: "
<< trace::GetDebugInfo(func_graph_->get_return()->debug_info());
@ -497,5 +505,7 @@ void FunctionBlock::AttachIsolatedNodesBeforeReturn() {
<< ", state: " << state->DebugString(2);
func_graph_->set_output(depend_node, true);
}
void FunctionBlock::SetAsDeadBlock() { is_dead_block_ = true; }
} // namespace parse
} // namespace mindspore

View File

@ -45,7 +45,7 @@ using FunctionBlockPtr = std::shared_ptr<FunctionBlock>;
class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
public:
explicit FunctionBlock(const Parser &parser);
virtual ~FunctionBlock() {}
virtual ~FunctionBlock() = default;
FuncGraphPtr func_graph() { return func_graph_; }
void WriteVariable(const std::string &var_name, const AnfNodePtr &node);
@ -74,6 +74,9 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
void FindIsolatedNodes();
void AddIsolatedNode(const AnfNodePtr &target);
void AttachIsolatedNodesBeforeReturn();
const std::vector<FunctionBlock *> &prev_blocks() const { return prev_blocks_; }
bool is_dead_block() const { return is_dead_block_; }
void SetAsDeadBlock();
private:
// Block graph
@ -116,6 +119,16 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
// Isolated nodes.
OrderedSet<AnfNodePtr> isolated_nodes_;
// If a block can never be executed, it's prev blocks will be empty, so this block is a dead block.
// while x > 5:
// x = x - 2
// if x > 7 :
// break
// else :
// break
// x = x - 1 #This after block is a dead block
bool is_dead_block_{false};
};
} // namespace parse

View File

@ -164,7 +164,8 @@ void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &as
}
py::str desc =
python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ". "
<< "Func graph id: " << func_graph->debug_info()->debug_id();
}
}
@ -323,11 +324,17 @@ FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::objec
size_t count = py::len(node_list);
MS_LOG(DEBUG) << "The nodes count is " << count;
for (size_t i = 0; i < count; ++i) {
MS_LOG(DEBUG) << "Start parse statement[" << i << "]: " << py::str(node_list[i]);
auto node = node_list[i];
block = ParseStatement(block, node);
MS_EXCEPTION_IF_NULL(block);
// Insert appropriate depended items for the function block if it has a return node
if (block->func_graph()->get_return() != nullptr) {
if (block->func_graph()->get_return() != nullptr || block->is_dead_block()) {
// If break is not the last expr.
if (i != count - 1) {
TraceGuard trace_guard(GetLocation(node_list[i + 1]));
MS_LOG(EXCEPTION) << "Dead code exist, please remove it.";
}
// Skip statements after 'return' (or 'break', 'continue').
break;
}
@ -359,7 +366,7 @@ FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py:
}
AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast expr";
MS_LOG(DEBUG) << "Process ast expr.";
TraceGuard trace_guard(GetLocation(node));
auto node_type = ast_->GetNodeType(node);
// Check the node type
@ -1043,7 +1050,6 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
MS_LOG(DEBUG) << "Process ast AugAssign";
MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(ast_);
py::object target_obj = python_adapter::GetPyObjAttr(node, "target");
py::object op_obj = python_adapter::GetPyObjAttr(node, "op");
py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
@ -1066,6 +1072,7 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py:
}
CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
WriteAssignVars(block, target_obj, augassign_app);
return block;
}
// Process global declaration such as 'global x';
@ -1119,6 +1126,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
// If the return_ is set, it has its own continuation block
if (true_end->func_graph()->get_return() == nullptr) {
MS_LOG(DEBUG) << "true end jump to after.";
true_end->Jump(after_block, {});
}
@ -1128,10 +1136,14 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
// If the return_ is set, it has its own continuation block
if (false_end->func_graph()->get_return() == nullptr) {
MS_LOG(DEBUG) << "false_end jump to after.";
false_end->Jump(after_block, {});
}
block->ConditionalJump(bool_node, true_block, false_block);
if (after_block->prev_blocks().empty()) {
after_block->SetAsDeadBlock();
}
after_block->Mature();
return after_block;
}
@ -1139,7 +1151,6 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object
FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast While";
MS_EXCEPTION_IF_NULL(block);
MS_LOG(INFO) << "Parse while statement";
FunctionBlockPtr header_block = nullptr;
FunctionBlockPtr body_block = nullptr;
FunctionBlockPtr after_block = nullptr;
@ -1173,12 +1184,11 @@ FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::obj
if (after_body->func_graph()->get_return() == nullptr) {
after_body->Jump(header_block, {});
}
header_block->Mature();
after_block->Mature();
auto &end_block = loop_context.EndBlock();
// end_block exists if we encounter 'break' in loop body.
if (end_block) {
// end_block exists if we encounter 'break' in loop body.
after_block->Jump(end_block, {});
end_block->Mature();
return end_block;