diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index ce421d03c75..4abbdab7c40 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -745,17 +745,19 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector &args) { return final_graph_id_; } -void AscendSession::GetSummaryNodes(KernelGraph *graph) { - MS_LOG(DEBUG) << "Update summary Start"; +void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph, + std::map> *summary) { MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(summary); // if final graph have no child graph auto graph_order_iter = graph_execute_orders_.find(graph->graph_id()); if (graph_order_iter == graph_execute_orders_.end()) { SessionBasic::GetSummaryNodes(graph); + auto summary_nodes = graph->summary_nodes(); + (*summary).insert(summary_nodes.begin(), summary_nodes.end()); return; } // for every child graph, find summary nodes - auto summary = graph->summary_nodes(); auto graph_order = GetGraphOrder(graph->graph_id()); for (size_t i = 0; i < graph_order.size(); i++) { auto child_graph = GetGraph(graph_order[i]); @@ -764,8 +766,19 @@ void AscendSession::GetSummaryNodes(KernelGraph *graph) { } SessionBasic::GetSummaryNodes(child_graph.get()); auto child_graph_summary = child_graph->summary_nodes(); - summary.insert(child_graph_summary.begin(), child_graph_summary.end()); + (*summary).insert(child_graph_summary.begin(), child_graph_summary.end()); + RecurseGetSummaryNodes(child_graph.get(), summary); } + graph->set_summary_nodes(*summary); +} + +void AscendSession::GetSummaryNodes(KernelGraph *graph) { + MS_LOG(DEBUG) << "Update summary Start"; + MS_EXCEPTION_IF_NULL(graph); + auto summary_nodes = graph->summary_nodes(); + std::map> summary; + summary.insert(summary_nodes.begin(), summary_nodes.end()); + RecurseGetSummaryNodes(graph, &summary); graph->set_summary_nodes(summary); MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); } diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 55eb4546334..e035f84c9db 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -67,7 +67,8 @@ class AscendSession : public SessionBasic { void SetActive(GraphId, GraphId) override; // compile child graph when session have multiple child graphs void CompileChildGraph(const KernelGraphPtr &child_graph); - void GetSummaryNodes(KernelGraph *graph) override; + void RecurseGetSummaryNodes(KernelGraph *graph, std::map> *summary); + void GetSummaryNodes(KernelGraph *graph); private: void InitRuntimeResource();