forked from OSSInnovation/mindspore
!2456 Fix BackendCommonOptimization order
Merge pull request !2456 from zhoufeng/xiu-ba-ge
This commit is contained in:
commit
eef762e58a
|
@ -29,6 +29,7 @@
|
|||
#include "device/ascend/ascend_kernel_runtime.h"
|
||||
#include "device/ascend/ascend_device_address.h"
|
||||
#include "pre_activate/ascend/ascend_backend_optimization.h"
|
||||
#include "pre_activate/common/common_backend_optimization.h"
|
||||
#include "device/kernel_adjust.h"
|
||||
#include "device/ascend/ascend_stream_assign.h"
|
||||
#include "device/ascend/ascend_label_assign.h"
|
||||
|
@ -283,36 +284,38 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL
|
|||
|
||||
GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||
MS_LOG(INFO) << "start";
|
||||
auto graph = ConstructKernelGraph(func_graph);
|
||||
std::vector<KernelGraphPtr> all_graphs;
|
||||
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
|
||||
BackendOptimization(all_graphs);
|
||||
// split switch
|
||||
SplitGraphs(NOT_NULL(graph));
|
||||
SplitGraphs(NOT_NULL(root_graph));
|
||||
// insert goto labels and label_sets
|
||||
LinkChildGraphs(NOT_NULL(graph));
|
||||
LinkChildGraphs(NOT_NULL(root_graph));
|
||||
// resource initialize
|
||||
InitRuntimeResource();
|
||||
// assign label
|
||||
AssignLabel(NOT_NULL(graph));
|
||||
// recurse compile child graph
|
||||
AssignLabel(NOT_NULL(root_graph));
|
||||
// recurse compile child root_graph
|
||||
std::set<KernelGraphPtr> memo;
|
||||
RecurseCompileGraph(NOT_NULL(graph), NOT_NULL(&memo));
|
||||
// root graph valiate,include genearte execute order and so on
|
||||
RootGraphExecutorValidate(NOT_NULL(graph));
|
||||
RecurseCompileGraph(NOT_NULL(root_graph), NOT_NULL(&memo));
|
||||
// root root_graph valiate,include genearte execute order and so on
|
||||
RootGraphExecutorValidate(NOT_NULL(root_graph));
|
||||
// adjust kernel
|
||||
AdjustKernel(graph);
|
||||
AdjustKernel(root_graph);
|
||||
// assign stream
|
||||
AssignStream(graph);
|
||||
AssignStream(root_graph);
|
||||
// insert profiling point
|
||||
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
|
||||
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
|
||||
// build kernel
|
||||
BuildKernel(graph);
|
||||
BuildKernel(root_graph);
|
||||
// alloc mem
|
||||
MemoryAlloc(graph.get());
|
||||
MemoryAlloc(root_graph.get());
|
||||
// task generate
|
||||
GenerateTaskInfo(graph);
|
||||
GenerateTaskInfo(root_graph);
|
||||
// load task into device
|
||||
LoadTask(graph);
|
||||
// return the graph id to backend
|
||||
auto graph_id = graph->graph_id();
|
||||
LoadTask(root_graph);
|
||||
// return the root_graph id to backend
|
||||
auto graph_id = root_graph->graph_id();
|
||||
return graph_id;
|
||||
}
|
||||
|
||||
|
@ -1569,6 +1572,14 @@ std::vector<AnfNodePtr> AscendSession::ConstructSplitedGraph(const KernelGraphPt
|
|||
return call_node_inputs;
|
||||
}
|
||||
|
||||
void AscendSession::BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs) {
|
||||
MS_LOG(INFO) << "Start BackendCommonOptimization";
|
||||
for (auto &graph : all_graphs) {
|
||||
opt::BackendCommonOptimization(graph);
|
||||
}
|
||||
MS_LOG(INFO) << "End.";
|
||||
}
|
||||
|
||||
void AscendSession::SplitGraphs(NotNull<KernelGraphPtr> root_graph) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
// if root graph output is a call node ,the root graph is condition graph of 'if' sentence
|
||||
|
|
|
@ -102,6 +102,7 @@ class AscendSession : public SessionBasic {
|
|||
void SplitGraph(NotNull<KernelGraphPtr> graph, const std::set<PrimitivePtr> &cut_prims);
|
||||
// split graphs with recurse from root graph
|
||||
void SplitGraphs(NotNull<KernelGraphPtr> root_graph);
|
||||
void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
|
||||
void LinkChildGraphs(NotNull<KernelGraphPtr> graph);
|
||||
void RootGraphExecutorValidate(NotNull<KernelGraphPtr> graph);
|
||||
std::vector<AnfNodePtr> ConstructSplitedGraph(const KernelGraphPtr &new_kernel_graph,
|
||||
|
|
|
@ -579,8 +579,10 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con
|
|||
return graph;
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
|
||||
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||
std::vector<KernelGraphPtr> *all_out_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(all_out_graph);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
auto graph = NewKernelGraph();
|
||||
front_backend_graph_map_[func_graph] = graph;
|
||||
|
@ -607,7 +609,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
|
||||
is_trace_back = true;
|
||||
} else {
|
||||
(void)ConstructKernelGraph(child_graph);
|
||||
(void)ConstructKernelGraph(child_graph, all_out_graph);
|
||||
}
|
||||
(void)CreateValueNodeKernelGraph(node, graph.get());
|
||||
}
|
||||
|
@ -634,7 +636,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP
|
|||
if (ExistSummaryNode(graph.get())) {
|
||||
graph->set_summary_node_exist(true);
|
||||
}
|
||||
opt::BackendCommonOptimization(graph);
|
||||
all_out_graph->push_back(graph);
|
||||
return graph;
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,8 @@ class SessionBasic {
|
|||
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
|
||||
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs);
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph);
|
||||
std::shared_ptr<KernelGraph> ConstructKernelGraph(const FuncGraphPtr &func_graph,
|
||||
std::vector<KernelGraphPtr> *all_out_graph);
|
||||
|
||||
CNodePtr CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph, bool *from_other_graph,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
|
|
Loading…
Reference in New Issue