!35018 Fix the summary issue for not record summary data in sub graphs

Merge pull request !35018 from maning202007/fix_summary_for_not_setsummarynodes
This commit is contained in:
i-robot 2022-06-01 01:54:01 +00:00 committed by Gitee
commit a5c2eed92e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 30 additions and 0 deletions

View File

@ -2085,6 +2085,17 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
summary_callback_ = callback; summary_callback_ = callback;
} }
void SessionBasic::SetSummaryNodesForAllGraphs(KernelGraph *graph, std::vector<KernelGraphPtr> all_graphs) {
MS_LOG(DEBUG) << "Set summary nodes for all graphs start.";
MS_EXCEPTION_IF_NULL(graph);
auto summary_nodes = graph->summary_nodes();
std::map<std::string, std::pair<AnfNodePtr, int>> summary;
summary.insert(summary_nodes.begin(), summary_nodes.end());
RecurseSetSummaryNodes(graph, all_graphs, &summary);
graph->set_summary_nodes(summary);
MS_LOG(INFO) << "The total summary nodes is: " << summary.size();
}
void SessionBasic::SetSummaryNodes(KernelGraph *graph) { void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary Start"; MS_LOG(DEBUG) << "Update summary Start";
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
@ -2117,6 +2128,19 @@ void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
MS_LOG(DEBUG) << "Update summary end size: " << summary.size(); MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
} }
void SessionBasic::RecurseSetSummaryNodes(KernelGraph *graph, std::vector<KernelGraphPtr> all_graphs,
std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(summary);
for (auto &child_graph : all_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
SetSummaryNodes(child_graph.get());
auto child_graph_summary = child_graph->summary_nodes();
summary->insert(child_graph_summary.begin(), child_graph_summary.end());
}
graph->set_summary_nodes(*summary);
}
void SessionBasic::Summary(KernelGraph *graph) { void SessionBasic::Summary(KernelGraph *graph) {
if (summary_callback_ == nullptr) { if (summary_callback_ == nullptr) {
return; return;

View File

@ -274,6 +274,9 @@ class BACKEND_EXPORT SessionBasic : public std::enable_shared_from_this<SessionB
const std::map<KernelWithIndex, size_t> &cnode_refcount) {} const std::map<KernelWithIndex, size_t> &cnode_refcount) {}
#ifndef ENABLE_SECURITY #ifndef ENABLE_SECURITY
virtual void SetSummaryNodes(KernelGraph *graph); virtual void SetSummaryNodes(KernelGraph *graph);
void SetSummaryNodesForAllGraphs(KernelGraph *graph, std::vector<KernelGraphPtr> all_graphs);
void RecurseSetSummaryNodes(KernelGraph *graph, std::vector<KernelGraphPtr> all_graphs,
std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
#endif #endif
void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) { void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) {

View File

@ -580,6 +580,9 @@ GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func
auto graph_id = CompileGraphImpl(root_graph, device_context); auto graph_id = CompileGraphImpl(root_graph, device_context);
// Set summary nodes for all graphs.
session_->SetSummaryNodesForAllGraphs(root_graph.get(), all_graphs);
// dump all graphs. // dump all graphs.
// for ascend mindRT. // for ascend mindRT.
session_->DumpGraphs(all_graphs); session_->DumpGraphs(all_graphs);