forked from mindspore-Ecosystem/mindspore
!7431 fix control flow when subgraphs merging root graph
Merge pull request !7431 from weiyang/master
This commit is contained in:
commit
b5bfea4e3a
|
@ -771,16 +771,13 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|||
}
|
||||
|
||||
std::vector<CNodePtr> execution_order;
|
||||
uint32_t child_order_index = 0;
|
||||
auto recurse_child_graph = [&](uint32_t index, uint32_t label_index, const CNodePtr &node) {
|
||||
if (!CheckLabelIndex(index, label_index, node)) {
|
||||
KernelGraphPtr cur_child_graph;
|
||||
if (!CheckLabelIndex(index, label_index, node, &cur_child_graph)) {
|
||||
MS_LOG(EXCEPTION) << "Check label index fail";
|
||||
}
|
||||
if (child_order_index >= graph->child_graph_order().size()) {
|
||||
MS_LOG(EXCEPTION) << "Index out of range:" << graph->child_graph_order().size();
|
||||
}
|
||||
auto child_graph = graph->child_graph_order()[child_order_index++];
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph.lock()), memo);
|
||||
MS_EXCEPTION_IF_NULL(cur_child_graph);
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(cur_child_graph), memo);
|
||||
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
||||
};
|
||||
|
||||
|
@ -809,18 +806,19 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr>
|
|||
return execution_order;
|
||||
}
|
||||
|
||||
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label) {
|
||||
bool AscendControlParser::CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cur_label,
|
||||
KernelGraphPtr *cur_child_graph) {
|
||||
auto child_graphs = AnfAlgo::GetNodeAttr<std::vector<KernelGraphPtr>>(cur_label, kAttrChildGraph);
|
||||
// check index and child order size
|
||||
if (child_graphs.size() <= IntToSize(index)) {
|
||||
MS_LOG(EXCEPTION) << "Child graph index is wrong, current node " << cur_label->ToString() << " child graph size "
|
||||
<< child_graphs.size() << " goto index " << index;
|
||||
}
|
||||
auto child_graph = child_graphs[index];
|
||||
MS_EXCEPTION_IF_NULL(child_graph);
|
||||
*cur_child_graph = child_graphs[index];
|
||||
MS_EXCEPTION_IF_NULL(*cur_child_graph);
|
||||
|
||||
// get start_label_set_index of child graph
|
||||
auto start_label_set = child_graph->get_start_label();
|
||||
auto start_label_set = (*cur_child_graph)->get_start_label();
|
||||
uint32_t start_label_set_index = AnfAlgo::GetNodeAttr<uint32_t>(start_label_set, kAttrLabelIndex);
|
||||
if (label_index != start_label_set_index) {
|
||||
MS_EXCEPTION_IF_NULL(cur_label);
|
||||
|
|
|
@ -74,7 +74,8 @@ class AscendControlParser {
|
|||
static void AttachChildGraphToReturnNode(NotNull<KernelGraphPtr> graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
// root graph order
|
||||
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode);
|
||||
static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode,
|
||||
KernelGraphPtr *cur_child_graph);
|
||||
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
static void AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, const std::vector<AnfNodePtr> orig_inputs);
|
||||
|
|
|
@ -240,7 +240,7 @@ class KernelGraph : public FuncGraph {
|
|||
// valid inputs
|
||||
std::vector<bool> valid_inputs_;
|
||||
|
||||
// child graph execute order in root graph
|
||||
// child graph execute order in parent graph
|
||||
std::vector<std::weak_ptr<KernelGraph>> child_graph_order_;
|
||||
|
||||
// input_tensors of control parameter
|
||||
|
|
Loading…
Reference in New Issue