From dfaa82a463219b9ef45772f3940d7683515971c5 Mon Sep 17 00:00:00 2001 From: maning202007 Date: Thu, 26 May 2022 21:54:14 +0800 Subject: [PATCH] Fix the summary issue for not record summary data in sub graphs --- .../backend/common/session/session_basic.cc | 24 +++++++++++++++++++ .../backend/common/session/session_basic.h | 3 +++ .../runtime/graph_scheduler/graph_compiler.cc | 3 +++ 3 files changed, 30 insertions(+) diff --git a/mindspore/ccsrc/backend/common/session/session_basic.cc b/mindspore/ccsrc/backend/common/session/session_basic.cc index d79894f7dce..57a0f7fb001 100644 --- a/mindspore/ccsrc/backend/common/session/session_basic.cc +++ b/mindspore/ccsrc/backend/common/session/session_basic.cc @@ -2085,6 +2085,17 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { summary_callback_ = callback; } +void SessionBasic::SetSummaryNodesForAllGraphs(KernelGraph *graph, std::vector all_graphs) { + MS_LOG(DEBUG) << "Set summary nodes for all graphs start."; + MS_EXCEPTION_IF_NULL(graph); + auto summary_nodes = graph->summary_nodes(); + std::map> summary; + summary.insert(summary_nodes.begin(), summary_nodes.end()); + RecurseSetSummaryNodes(graph, all_graphs, &summary); + graph->set_summary_nodes(summary); + MS_LOG(INFO) << "The total summary nodes is: " << summary.size(); +} + void SessionBasic::SetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); @@ -2117,6 +2128,19 @@ void SessionBasic::SetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); } +void SessionBasic::RecurseSetSummaryNodes(KernelGraph *graph, std::vector all_graphs, + std::map> *summary) { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(summary); + for (auto &child_graph : all_graphs) { + MS_EXCEPTION_IF_NULL(child_graph); + SetSummaryNodes(child_graph.get()); + auto child_graph_summary = child_graph->summary_nodes(); + summary->insert(child_graph_summary.begin(), child_graph_summary.end()); + } + graph->set_summary_nodes(*summary); +} + void SessionBasic::Summary(KernelGraph *graph) { if (summary_callback_ == nullptr) { return; diff --git a/mindspore/ccsrc/backend/common/session/session_basic.h b/mindspore/ccsrc/backend/common/session/session_basic.h index f70e0277416..09620f048a3 100644 --- a/mindspore/ccsrc/backend/common/session/session_basic.h +++ b/mindspore/ccsrc/backend/common/session/session_basic.h @@ -274,6 +274,9 @@ class BACKEND_EXPORT SessionBasic : public std::enable_shared_from_this &cnode_refcount) {} #ifndef ENABLE_SECURITY virtual void SetSummaryNodes(KernelGraph *graph); + void SetSummaryNodesForAllGraphs(KernelGraph *graph, std::vector all_graphs); + void RecurseSetSummaryNodes(KernelGraph *graph, std::vector all_graphs, + std::map> *summary); #endif void LoadInputs(const GraphId &graph_id, const std::vector &inputs_const) { diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc index 5885a0ed6d6..615ca5ba39e 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc @@ -492,6 +492,9 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func auto graph_id = CompileGraphImpl(root_graph, device_context); + // Set summary nodes for all graphs. + session_->SetSummaryNodesForAllGraphs(root_graph.get(), all_graphs); + // dump all graphs. // for ascend mindRT. session_->DumpGraphs(all_graphs);