forked from mindspore-Ecosystem/mindspore
!25436 unified runtime optimize the graph output
Merge pull request !25436 from limingqi107/new_actor_runtime
This commit is contained in:
commit
c0de00df28
|
@ -1174,29 +1174,36 @@ FuncGraphPtr KernelGraph::GetFuncGraph() {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output,
|
||||
const AnfNodePtr &front_node) {
|
||||
if ((backend_graph_output == nullptr) || (front_node == nullptr)) {
|
||||
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNodePtr> &backend_outputs,
|
||||
const std::vector<AnfNodePtr> &front_outputs) {
|
||||
MS_LOG(INFO) << "Get graph backend output nodes.";
|
||||
std::vector<KernelWithIndex> backend_output_nodes;
|
||||
for (auto &backend_output : backend_outputs) {
|
||||
auto temp_backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_output);
|
||||
backend_output_nodes.insert(backend_output_nodes.end(), temp_backend_outputs.begin(), temp_backend_outputs.end());
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Get graph front output nodes.";
|
||||
std::vector<KernelWithIndex> front_output_nodes;
|
||||
for (auto &front_output : front_outputs) {
|
||||
auto temp_front_outputs = AnfAlgo::GetAllOutputWithIndex(front_output);
|
||||
front_output_nodes.insert(front_output_nodes.end(), temp_front_outputs.begin(), temp_front_outputs.end());
|
||||
}
|
||||
|
||||
if (backend_output_nodes.size() != front_output_nodes.size()) {
|
||||
MS_LOG(WARNING) << "The size(" << backend_output_nodes.size() << ") of backend outputs: "
|
||||
<< " is not equal to the size(" << front_output_nodes.size() << ") of front outputs.";
|
||||
return;
|
||||
}
|
||||
|
||||
auto backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_graph_output);
|
||||
auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node);
|
||||
if (backend_outputs.size() != front_outputs.size()) {
|
||||
MS_LOG(INFO) << "The size(" << backend_outputs.size()
|
||||
<< ") of backend output: " << backend_graph_output->DebugString() << " is not equal to the size("
|
||||
<< front_outputs.size() << ") of front output: " << front_node->DebugString();
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < backend_outputs.size(); ++i) {
|
||||
auto backend_output = backend_outputs[i];
|
||||
auto front_output = front_outputs[i];
|
||||
graph_output_to_front_node_map_[backend_output] = front_output;
|
||||
MS_LOG(INFO) << "Backend output: " << backend_output.first->fullname_with_scope()
|
||||
<< " with index: " << backend_output.second
|
||||
<< " map to front node: " << front_output.first->fullname_with_scope()
|
||||
<< " with index: " << front_output.second;
|
||||
for (size_t i = 0; i < backend_output_nodes.size(); ++i) {
|
||||
auto backend_output_node = backend_output_nodes[i];
|
||||
auto front_output_node = front_output_nodes[i];
|
||||
graph_output_to_front_node_map_[backend_output_node] = front_output_node;
|
||||
MS_LOG(INFO) << "Backend output: " << backend_output_node.first->fullname_with_scope()
|
||||
<< " with index: " << backend_output_node.second
|
||||
<< " map to front node: " << front_output_node.first->fullname_with_scope()
|
||||
<< " with index: " << front_output_node.second;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1209,45 +1216,6 @@ AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput(
|
|||
return AnfWithOutIndex();
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateGraphOutputMap(const std::vector<AnfWithOutIndex> &old_outputs,
|
||||
const std::vector<AnfWithOutIndex> &new_outputs) {
|
||||
MS_LOG(INFO) << "The size of old outputs: " << old_outputs.size()
|
||||
<< ", the size of new outputs: " << new_outputs.size();
|
||||
if (old_outputs.size() != new_outputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of old outputs is not equal to the size of new outputs.";
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < old_outputs.size(); ++i) {
|
||||
auto old_output = old_outputs[i];
|
||||
auto new_output = new_outputs[i];
|
||||
if (old_output == new_output) {
|
||||
continue;
|
||||
}
|
||||
// Update the graph output map.
|
||||
if (graph_output_to_front_node_map_.count(old_output) > 0) {
|
||||
MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " with index "
|
||||
<< old_output.second << " to " << new_output.first->fullname_with_scope() << " with index "
|
||||
<< new_output.second;
|
||||
graph_output_to_front_node_map_[new_output] = graph_output_to_front_node_map_[old_output];
|
||||
(void)graph_output_to_front_node_map_.erase(old_output);
|
||||
}
|
||||
|
||||
if (old_output.first == new_output.first) {
|
||||
continue;
|
||||
}
|
||||
// Update the front backend node map.
|
||||
if ((backend_front_anf_map_.count(old_output.first) > 0) && old_output.first->isa<CNode>() &&
|
||||
new_output.first->isa<CNode>()) {
|
||||
MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " to "
|
||||
<< new_output.first->fullname_with_scope();
|
||||
auto front_node = backend_front_anf_map_[old_output.first];
|
||||
front_backend_anf_map_[front_node] = new_output.first;
|
||||
backend_front_anf_map_[new_output.first] = front_node;
|
||||
(void)backend_front_anf_map_.erase(old_output.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
|
||||
auto iter = front_to_internal_outputs_map_.find(front_node);
|
||||
if (iter != front_to_internal_outputs_map_.end()) {
|
||||
|
|
|
@ -251,11 +251,9 @@ class KernelGraph : public FuncGraph {
|
|||
FuncGraphPtr GetFuncGraph();
|
||||
// Cache the backend graph output nodes and their corresponding front nodes (with output index) in
|
||||
// graph_output_to_front_node_map_.
|
||||
void CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output, const AnfNodePtr &front_node);
|
||||
void CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNodePtr> &backend_outputs,
|
||||
const std::vector<AnfNodePtr> &front_outputs);
|
||||
AnfWithOutIndex GetFrontNodeWithIndexByGraphOutput(const AnfWithOutIndex &backend_graph_output_with_index) const;
|
||||
// Update the maps related to the backend graph output nodes after the backend output nodes are modified.
|
||||
void UpdateGraphOutputMap(const std::vector<AnfWithOutIndex> &old_outputs,
|
||||
const std::vector<AnfWithOutIndex> &new_outputs);
|
||||
|
||||
// Returns the epoch value currently stored in current_epoch_.
uint32_t current_epoch() const { return current_epoch_; }
|
||||
// Stores `epoch` into current_epoch_, replacing the previous value.
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }
|
||||
|
|
|
@ -318,15 +318,14 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt
|
|||
session_->SetInputNodeUsage(graph, manager);
|
||||
graph->SetOptimizerFlag();
|
||||
|
||||
// Cache the backend graph output nodes to front nodes with output index.
|
||||
for (auto &output : outputs) {
|
||||
auto backend_node = graph->GetBackendAnfByFrontAnf(output);
|
||||
if (backend_node != nullptr) {
|
||||
graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output);
|
||||
}
|
||||
}
|
||||
auto graph_id = CompileGraphImpl(graph, device_context);
|
||||
|
||||
return CompileGraphImpl(graph, device_context);
|
||||
// Cache the backend graph output nodes to front nodes with output index.
|
||||
auto backend_node = graph->output();
|
||||
MS_EXCEPTION_IF_NULL(backend_node);
|
||||
graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
|
||||
|
||||
return graph_id;
|
||||
}
|
||||
|
||||
GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const DeviceContext *device_context) {
|
||||
|
@ -347,14 +346,16 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
|
|||
// The graph common optimization.
|
||||
opt::BackendCommonOptimization(root_graph);
|
||||
|
||||
auto graph_id = CompileGraphImpl(root_graph, device_context);
|
||||
|
||||
// Cache the backend graph output nodes to front nodes with output index.
|
||||
auto output = func_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
auto backend_node = root_graph->GetBackendAnfByFrontAnf(output);
|
||||
auto backend_node = root_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(backend_node);
|
||||
root_graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output);
|
||||
root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
|
||||
|
||||
return CompileGraphImpl(root_graph, device_context);
|
||||
return graph_id;
|
||||
}
|
||||
|
||||
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
|
||||
|
@ -376,9 +377,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
|||
graph->set_is_executing_sink(is_executing_sink);
|
||||
graph->set_is_loop_count_sink(is_loop_count_sink);
|
||||
|
||||
MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id();
|
||||
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
|
||||
// Execute optimization pass.
|
||||
device_context->OptimizeGraph(graph);
|
||||
|
||||
|
@ -389,11 +387,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
|||
// Adjust kernel graph before run graph.
|
||||
device_context->PreprocessBeforeRunGraph(graph);
|
||||
|
||||
MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id();
|
||||
auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||
// Update the output map of kernel graph by modified output nodes.
|
||||
graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
|
||||
|
||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||
// Create device address for all anf nodes of graph.
|
||||
CreateDeviceAddress(graph, device_context);
|
||||
|
|
Loading…
Reference in New Issue