!7431 fix control flow when subgraphs merging root graph

Merge pull request !7431 from weiyang/master
This commit is contained in:
mindspore-ci-bot 2020-10-19 11:10:23 +08:00 committed by Gitee
commit b5bfea4e3a
3 changed files with 12 additions and 13 deletions

View File

@ -771,16 +771,13 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
}
std::vector<CNodePtr> execution_order;
uint32_t child_order_index = 0;
auto recurse_child_graph = [&](uint32_t index, uint32_t label_index, const CNodePtr &node) {
if (!CheckLabelIndex(index, label_index, node)) {
KernelGraphPtr cur_child_graph;
if (!CheckLabelIndex(index, label_index, node, &cur_child_graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
if (child_order_index >= graph->child_graph_order().size()) {
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
}
auto child_graph = graph->child_graph_order()[child_order_index++];
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph.lock()), memo);
MS_EXCEPTION_IF_NULL(cur_child_graph);
auto child_execution_order = RecurseGraph(NOT_NULL(cur_child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
};
@ -809,18 +806,19 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
return execution_order;
}
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) {
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label,
KernelGraphPtr *cur_child_graph) {
auto child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cur_label, kAttrChildGraph);
// check index and child order size
if (child_graphs.size() <= IntToSize(index)) {
MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size "
<< child_graphs.size() << " goto index " << index;
}
auto child_graph = child_graphs[index];
MS_EXCEPTION_IF_NULL(child_graph);
*cur_child_graph = child_graphs[index];
MS_EXCEPTION_IF_NULL(*cur_child_graph);
// get start_label_set_index of child graph
auto start_label_set = child_graph->get_start_label();
auto start_label_set = (*cur_child_graph)->get_start_label();
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex);
if (label_index != start_label_set_index) {
MS_EXCEPTION_IF_NULL(cur_label);

View File

@ -74,7 +74,8 @@ class AscendControlParser {
static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
// root graph order
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode,
KernelGraphPtr *cur_child_graph);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
const NotNull<std::set<KernelGraphPtr> *> memo);
static void AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, const std::vector<AnfNodePtr> orig_inputs);

View File

@ -240,7 +240,7 @@ class KernelGraph : public FuncGraph {
// valid inputs
std::vector<bool> valid_inputs_;
// child graph execute order in root graph
// child graph execute order in parent graph
std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;
// input_tensors of control parameter