forked from mindspore-Ecosystem/mindspore
handle summary to adapt new control sink
This commit is contained in:
parent
e32d539b5f
commit
edbe02dfec
|
@ -745,17 +745,19 @@ GraphId AscendSession::SetFinalGraphInput(const std::vector<AnfNodePtr> &args) {
|
||||||
return final_graph_id_;
|
return final_graph_id_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AscendSession::GetSummaryNodes(KernelGraph *graph) {
|
void AscendSession::RecurseGetSummaryNodes(KernelGraph *graph,
|
||||||
MS_LOG(DEBUG) << "Update summary Start";
|
std::map<std::string, std::pair<AnfNodePtr, int>> *summary) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(summary);
|
||||||
// if final graph have no child graph
|
// if final graph have no child graph
|
||||||
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
|
auto graph_order_iter = graph_execute_orders_.find(graph->graph_id());
|
||||||
if (graph_order_iter == graph_execute_orders_.end()) {
|
if (graph_order_iter == graph_execute_orders_.end()) {
|
||||||
SessionBasic::GetSummaryNodes(graph);
|
SessionBasic::GetSummaryNodes(graph);
|
||||||
|
auto summary_nodes = graph->summary_nodes();
|
||||||
|
(*summary).insert(summary_nodes.begin(), summary_nodes.end());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// for every child graph, find summary nodes
|
// for every child graph, find summary nodes
|
||||||
auto summary = graph->summary_nodes();
|
|
||||||
auto graph_order = GetGraphOrder(graph->graph_id());
|
auto graph_order = GetGraphOrder(graph->graph_id());
|
||||||
for (size_t i = 0; i < graph_order.size(); i++) {
|
for (size_t i = 0; i < graph_order.size(); i++) {
|
||||||
auto child_graph = GetGraph(graph_order[i]);
|
auto child_graph = GetGraph(graph_order[i]);
|
||||||
|
@ -764,8 +766,19 @@ void AscendSession::GetSummaryNodes(KernelGraph *graph) {
|
||||||
}
|
}
|
||||||
SessionBasic::GetSummaryNodes(child_graph.get());
|
SessionBasic::GetSummaryNodes(child_graph.get());
|
||||||
auto child_graph_summary = child_graph->summary_nodes();
|
auto child_graph_summary = child_graph->summary_nodes();
|
||||||
summary.insert(child_graph_summary.begin(), child_graph_summary.end());
|
(*summary).insert(child_graph_summary.begin(), child_graph_summary.end());
|
||||||
|
RecurseGetSummaryNodes(child_graph.get(), summary);
|
||||||
}
|
}
|
||||||
|
graph->set_summary_nodes(*summary);
|
||||||
|
}
|
||||||
|
|
||||||
|
void AscendSession::GetSummaryNodes(KernelGraph *graph) {
|
||||||
|
MS_LOG(DEBUG) << "Update summary Start";
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
auto summary_nodes = graph->summary_nodes();
|
||||||
|
std::map<std::string, std::pair<AnfNodePtr, int>> summary;
|
||||||
|
summary.insert(summary_nodes.begin(), summary_nodes.end());
|
||||||
|
RecurseGetSummaryNodes(graph, &summary);
|
||||||
graph->set_summary_nodes(summary);
|
graph->set_summary_nodes(summary);
|
||||||
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
|
MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,8 @@ class AscendSession : public SessionBasic {
|
||||||
void SetActive(GraphId, GraphId) override;
|
void SetActive(GraphId, GraphId) override;
|
||||||
// compile child graph when session have multiple child graphs
|
// compile child graph when session have multiple child graphs
|
||||||
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
void CompileChildGraph(const KernelGraphPtr &child_graph);
|
||||||
void GetSummaryNodes(KernelGraph *graph) override;
|
void RecurseGetSummaryNodes(KernelGraph *graph, std::map<std::string, std::pair<AnfNodePtr, int>> *summary);
|
||||||
|
void GetSummaryNodes(KernelGraph *graph);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitRuntimeResource();
|
void InitRuntimeResource();
|
||||||
|
|
Loading…
Reference in New Issue