!25436 Unified runtime: optimize the graph output

Merge pull request !25436 from limingqi107/new_actor_runtime
This commit is contained in:
i-robot 2021-10-26 11:54:36 +00:00 committed by Gitee
commit c0de00df28
3 changed files with 41 additions and 82 deletions

View File

@ -1174,29 +1174,36 @@ FuncGraphPtr KernelGraph::GetFuncGraph() {
return nullptr;
}
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output,
const AnfNodePtr &front_node) {
if ((backend_graph_output == nullptr) || (front_node == nullptr)) {
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNodePtr> &backend_outputs,
const std::vector<AnfNodePtr> &front_outputs) {
MS_LOG(INFO) << "Get graph backend output nodes.";
std::vector<KernelWithIndex> backend_output_nodes;
for (auto &backend_output : backend_outputs) {
auto temp_backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_output);
backend_output_nodes.insert(backend_output_nodes.end(), temp_backend_outputs.begin(), temp_backend_outputs.end());
}
MS_LOG(INFO) << "Get graph front output nodes.";
std::vector<KernelWithIndex> front_output_nodes;
for (auto &front_output : front_outputs) {
auto temp_front_outputs = AnfAlgo::GetAllOutputWithIndex(front_output);
front_output_nodes.insert(front_output_nodes.end(), temp_front_outputs.begin(), temp_front_outputs.end());
}
if (backend_output_nodes.size() != front_output_nodes.size()) {
MS_LOG(WARNING) << "The size(" << backend_output_nodes.size() << ") of backend outputs: "
<< " is not equal to the size(" << front_output_nodes.size() << ") of front outputs.";
return;
}
auto backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_graph_output);
auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node);
if (backend_outputs.size() != front_outputs.size()) {
MS_LOG(INFO) << "The size(" << backend_outputs.size()
<< ") of backend output: " << backend_graph_output->DebugString() << " is not equal to the size("
<< front_outputs.size() << ") of front output: " << front_node->DebugString();
return;
}
for (size_t i = 0; i < backend_outputs.size(); ++i) {
auto backend_output = backend_outputs[i];
auto front_output = front_outputs[i];
graph_output_to_front_node_map_[backend_output] = front_output;
MS_LOG(INFO) << "Backend output: " << backend_output.first->fullname_with_scope()
<< " with index: " << backend_output.second
<< " map to front node: " << front_output.first->fullname_with_scope()
<< " with index: " << front_output.second;
for (size_t i = 0; i < backend_output_nodes.size(); ++i) {
auto backend_output_node = backend_output_nodes[i];
auto front_output_node = front_output_nodes[i];
graph_output_to_front_node_map_[backend_output_node] = front_output_node;
MS_LOG(INFO) << "Backend output: " << backend_output_node.first->fullname_with_scope()
<< " with index: " << backend_output_node.second
<< " map to front node: " << front_output_node.first->fullname_with_scope()
<< " with index: " << front_output_node.second;
}
}
@ -1209,45 +1216,6 @@ AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput(
return AnfWithOutIndex();
}
// Re-key the graph output bookkeeping after the backend output nodes have been replaced.
//
// old_outputs and new_outputs are matched position by position. A size mismatch is reported through
// MS_LOG(EXCEPTION). For every position whose (node, index) pair differs:
//   1. an entry keyed by the old pair in graph_output_to_front_node_map_ is moved to the new pair;
//   2. if the node itself differs, both nodes are CNodes and the old node is present in
//      backend_front_anf_map_, the front<->backend node maps are re-pointed to the new node.
void KernelGraph::UpdateGraphOutputMap(const std::vector<AnfWithOutIndex> &old_outputs,
                                       const std::vector<AnfWithOutIndex> &new_outputs) {
  const size_t old_size = old_outputs.size();
  const size_t new_size = new_outputs.size();
  MS_LOG(INFO) << "The size of old outputs: " << old_size << ", the size of new outputs: " << new_size;
  if (old_size != new_size) {
    MS_LOG(EXCEPTION) << "The size of old outputs is not equal to the size of new outputs.";
  }

  for (size_t pos = 0; pos < old_size; ++pos) {
    const auto &previous = old_outputs[pos];
    const auto &current = new_outputs[pos];
    // Nothing to do when this output slot is unchanged.
    if (previous == current) {
      continue;
    }

    // Step 1: move the graph output map entry from the old (node, index) key to the new one.
    const auto output_iter = graph_output_to_front_node_map_.find(previous);
    if (output_iter != graph_output_to_front_node_map_.end()) {
      MS_LOG(INFO) << "Replace backend output node " << previous.first->fullname_with_scope() << " with index "
                   << previous.second << " to " << current.first->fullname_with_scope() << " with index "
                   << current.second;
      // Copy the mapped value before touching the map, then erase by key rather than by iterator.
      const auto mapped_front = output_iter->second;
      graph_output_to_front_node_map_[current] = mapped_front;
      (void)graph_output_to_front_node_map_.erase(previous);
    }

    // Step 2: only the output index changed, so the node-level maps are still valid.
    const auto &previous_node = previous.first;
    const auto &current_node = current.first;
    if (previous_node == current_node) {
      continue;
    }

    // The node-level maps are only updated for a tracked old node when both nodes are CNodes.
    const auto backend_iter = backend_front_anf_map_.find(previous_node);
    if (backend_iter == backend_front_anf_map_.end()) {
      continue;
    }
    if (!previous_node->isa<CNode>() || !current_node->isa<CNode>()) {
      continue;
    }

    MS_LOG(INFO) << "Replace backend output node " << previous_node->fullname_with_scope() << " to "
                 << current_node->fullname_with_scope();
    const auto front_node = backend_iter->second;
    front_backend_anf_map_[front_node] = current_node;
    backend_front_anf_map_[current_node] = front_node;
    (void)backend_front_anf_map_.erase(previous_node);
  }
}
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
auto iter = front_to_internal_outputs_map_.find(front_node);
if (iter != front_to_internal_outputs_map_.end()) {

View File

@ -251,11 +251,9 @@ class KernelGraph : public FuncGraph {
FuncGraphPtr GetFuncGraph();
// Cache the mapping from the backend graph output nodes to their corresponding front nodes (each with its
// output index) in graph_output_to_front_node_map_.
void CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output, const AnfNodePtr &front_node);
void CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNodePtr> &backend_outputs,
const std::vector<AnfNodePtr> &front_outputs);
AnfWithOutIndex GetFrontNodeWithIndexByGraphOutput(const AnfWithOutIndex &backend_graph_output_with_index) const;
// Update the maps related to the backend graph output nodes after the backend output nodes have been modified.
void UpdateGraphOutputMap(const std::vector<AnfWithOutIndex> &old_outputs,
const std::vector<AnfWithOutIndex> &new_outputs);
uint32_t current_epoch() const { return current_epoch_; }
void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; }

View File

@ -318,15 +318,14 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt
session_->SetInputNodeUsage(graph, manager);
graph->SetOptimizerFlag();
// Cache the backend graph output nodes to front nodes with output index.
for (auto &output : outputs) {
auto backend_node = graph->GetBackendAnfByFrontAnf(output);
if (backend_node != nullptr) {
graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output);
}
}
auto graph_id = CompileGraphImpl(graph, device_context);
return CompileGraphImpl(graph, device_context);
// Cache the backend graph output nodes to front nodes with output index.
auto backend_node = graph->output();
MS_EXCEPTION_IF_NULL(backend_node);
graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
return graph_id;
}
GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const DeviceContext *device_context) {
@ -347,14 +346,16 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device
// The graph common optimization.
opt::BackendCommonOptimization(root_graph);
auto graph_id = CompileGraphImpl(root_graph, device_context);
// Cache the backend graph output nodes to front nodes with output index.
auto output = func_graph->output();
MS_EXCEPTION_IF_NULL(output);
auto backend_node = root_graph->GetBackendAnfByFrontAnf(output);
auto backend_node = root_graph->output();
MS_EXCEPTION_IF_NULL(backend_node);
root_graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output);
root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
return CompileGraphImpl(root_graph, device_context);
return graph_id;
}
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
@ -376,9 +377,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
graph->set_is_executing_sink(is_executing_sink);
graph->set_is_loop_count_sink(is_loop_count_sink);
MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id();
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
// Execute optimization pass.
device_context->OptimizeGraph(graph);
@ -389,11 +387,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
// Adjust kernel graph before run graph.
device_context->PreprocessBeforeRunGraph(graph);
MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id();
auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
// Update the output map of kernel graph by modified output nodes.
graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context);