forked from mindspore-Ecosystem/mindspore
!2446 Fix BackendCommonOptimization order
Merge pull request !2446 from zhoufeng/xiu-ba-ge
This commit is contained in:
commit
fa2166970a
|
@ -29,6 +29,7 @@
|
||||||
#include "device/ascend/ascend_kernel_runtime.h"
|
#include "device/ascend/ascend_kernel_runtime.h"
|
||||||
#include "device/ascend/ascend_device_address.h"
|
#include "device/ascend/ascend_device_address.h"
|
||||||
#include "pre_activate/ascend/ascend_backend_optimization.h"
|
#include "pre_activate/ascend/ascend_backend_optimization.h"
|
||||||
|
#include "pre_activate/common/common_backend_optimization.h"
|
||||||
#include "device/kernel_adjust.h"
|
#include "device/kernel_adjust.h"
|
||||||
#include "device/ascend/ascend_stream_assign.h"
|
#include "device/ascend/ascend_stream_assign.h"
|
||||||
#include "device/ascend/ascend_label_assign.h"
|
#include "device/ascend/ascend_label_assign.h"
|
||||||
|
@ -283,36 +284,38 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
|
||||||
|
|
||||||
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||||
MS_LOG(INFO) << "start";
|
MS_LOG(INFO) << "start";
|
||||||
auto graph = ConstructKernelGraph(func_graph);
|
std::vector<KernelGraphPtr> all_graphs;
|
||||||
|
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
|
||||||
|
BackendOptimization(all_graphs);
|
||||||
// split switch
|
// split switch
|
||||||
SplitGraphs(NOT_NULL(graph));
|
SplitGraphs(NOT_NULL(root_graph));
|
||||||
// insert goto labels and label_sets
|
// insert goto labels and label_sets
|
||||||
LinkChildGraphs(NOT_NULL(graph));
|
LinkChildGraphs(NOT_NULL(root_graph));
|
||||||
// resource initialize
|
// resource initialize
|
||||||
InitRuntimeResource();
|
InitRuntimeResource();
|
||||||
// assign label
|
// assign label
|
||||||
AssignLabel(NOT_NULL(graph));
|
AssignLabel(NOT_NULL(root_graph));
|
||||||
// recurse compile child graph
|
// recurse compile child root_graph
|
||||||
std::set<KernelGraphPtr> memo;
|
std::set<KernelGraphPtr> memo;
|
||||||
RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
|
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||||
// root graph valiate,include genearte execute order and so on
|
// root root_graph valiate,include genearte execute order and so on
|
||||||
RootGraphExecutorValidate(NOT_NULL(graph));
|
RootGraphExecutorValidate(NOT_NULL(root_graph));
|
||||||
// adjust kernel
|
// adjust kernel
|
||||||
AdjustKernel(graph);
|
AdjustKernel(root_graph);
|
||||||
// assign stream
|
// assign stream
|
||||||
AssignStream(graph);
|
AssignStream(root_graph);
|
||||||
// insert profiling point
|
// insert profiling point
|
||||||
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
|
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
|
||||||
// build kernel
|
// build kernel
|
||||||
BuildKernel(graph);
|
BuildKernel(root_graph);
|
||||||
// alloc mem
|
// alloc mem
|
||||||
MemoryAlloc(graph.get());
|
MemoryAlloc(root_graph.get());
|
||||||
// task generate
|
// task generate
|
||||||
GenerateTaskInfo(graph);
|
GenerateTaskInfo(root_graph);
|
||||||
// load task into device
|
// load task into device
|
||||||
LoadTask(graph);
|
LoadTask(root_graph);
|
||||||
// return the graph id to backend
|
// return the root_graph id to backend
|
||||||
auto graph_id = graph->graph_id();
|
auto graph_id = root_graph->graph_id();
|
||||||
return graph_id;
|
return graph_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1569,6 +1572,14 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
||||||
return call_node_inputs;
|
return call_node_inputs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
|
||||||
|
MS_LOG(INFO) << "Start BackendCommonOptimization";
|
||||||
|
for (auto &graph : all_graphs) {
|
||||||
|
opt::BackendCommonOptimization(graph);
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "End.";
|
||||||
|
}
|
||||||
|
|
||||||
void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
|
void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
|
||||||
std::set<KernelGraphPtr> memo;
|
std::set<KernelGraphPtr> memo;
|
||||||
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
|
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
|
||||||
|
|
|
@ -102,6 +102,7 @@ class AscendSession : public SessionBasic {
|
||||||
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
|
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
|
||||||
// split graphs with recurse from root graph
|
// split graphs with recurse from root graph
|
||||||
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
|
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
|
||||||
|
void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
|
||||||
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
|
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
|
||||||
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
|
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
|
||||||
std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
||||||
|
|
|
@ -579,8 +579,10 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
||||||
return graph;
|
return graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
|
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||||
|
std::vector<KernelGraphPtr> *all_out_graph) {
|
||||||
MS_EXCEPTION_IF_NULL(func_graph);
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(all_out_graph);
|
||||||
auto node_list = TopoSort(func_graph->get_return());
|
auto node_list = TopoSort(func_graph->get_return());
|
||||||
auto graph = NewKernelGraph();
|
auto graph = NewKernelGraph();
|
||||||
front_backend_graph_map_[func_graph] = graph;
|
front_backend_graph_map_[func_graph] = graph;
|
||||||
|
@ -607,7 +609,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
||||||
if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
|
if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
|
||||||
is_trace_back = true;
|
is_trace_back = true;
|
||||||
} else {
|
} else {
|
||||||
(void)ConstructKernelGraph(child_graph);
|
(void)ConstructKernelGraph(child_graph, all_out_graph);
|
||||||
}
|
}
|
||||||
(void)CreateValueNodeKernelGraph(node, graph.get());
|
(void)CreateValueNodeKernelGraph(node, graph.get());
|
||||||
}
|
}
|
||||||
|
@ -634,7 +636,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
||||||
if (ExistSummaryNode(graph.get())) {
|
if (ExistSummaryNode(graph.get())) {
|
||||||
graph->set_summary_node_exist(true);
|
graph->set_summary_node_exist(true);
|
||||||
}
|
}
|
||||||
opt::BackendCommonOptimization(graph);
|
all_out_graph->push_back(graph);
|
||||||
return graph;
|
return graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -75,7 +75,8 @@ class SessionBasic {
|
||||||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||||
|
|
||||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
|
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
|
||||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
|
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||||
|
std::vector<KernelGraphPtr> *all_out_graph);
|
||||||
|
|
||||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
|
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
|
||||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||||
|
|
Loading…
Reference in New Issue