diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 478a8e3852e..2141ee58623 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1174,29 +1174,36 @@ FuncGraphPtr KernelGraph::GetFuncGraph() { return nullptr; } -void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output, - const AnfNodePtr &front_node) { - if ((backend_graph_output == nullptr) || (front_node == nullptr)) { +void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector &backend_outputs, + const std::vector &front_outputs) { + MS_LOG(INFO) << "Get graph backend output nodes."; + std::vector backend_output_nodes; + for (auto &backend_output : backend_outputs) { + auto temp_backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_output); + backend_output_nodes.insert(backend_output_nodes.end(), temp_backend_outputs.begin(), temp_backend_outputs.end()); + } + + MS_LOG(INFO) << "Get graph front output nodes."; + std::vector front_output_nodes; + for (auto &front_output : front_outputs) { + auto temp_front_outputs = AnfAlgo::GetAllOutputWithIndex(front_output); + front_output_nodes.insert(front_output_nodes.end(), temp_front_outputs.begin(), temp_front_outputs.end()); + } + + if (backend_output_nodes.size() != front_output_nodes.size()) { + MS_LOG(WARNING) << "The size(" << backend_output_nodes.size() << ") of backend outputs: " + << " is not equal to the size(" << front_output_nodes.size() << ") of front outputs."; return; } - auto backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_graph_output); - auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node); - if (backend_outputs.size() != front_outputs.size()) { - MS_LOG(INFO) << "The size(" << backend_outputs.size() - << ") of backend output: " << backend_graph_output->DebugString() << " is not equal to the size(" - << front_outputs.size() << ") of front output: " << front_node->DebugString(); - return; - } - - for (size_t i = 0; i < backend_outputs.size(); ++i) { - auto backend_output = backend_outputs[i]; - auto front_output = front_outputs[i]; - graph_output_to_front_node_map_[backend_output] = front_output; - MS_LOG(INFO) << "Backend output: " << backend_output.first->fullname_with_scope() - << " with index: " << backend_output.second - << " map to front node: " << front_output.first->fullname_with_scope() - << " with index: " << front_output.second; + for (size_t i = 0; i < backend_output_nodes.size(); ++i) { + auto backend_output_node = backend_output_nodes[i]; + auto front_output_node = front_output_nodes[i]; + graph_output_to_front_node_map_[backend_output_node] = front_output_node; + MS_LOG(INFO) << "Backend output: " << backend_output_node.first->fullname_with_scope() + << " with index: " << backend_output_node.second + << " map to front node: " << front_output_node.first->fullname_with_scope() + << " with index: " << front_output_node.second; } } @@ -1209,45 +1216,6 @@ AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput( return AnfWithOutIndex(); } -void KernelGraph::UpdateGraphOutputMap(const std::vector &old_outputs, - const std::vector &new_outputs) { - MS_LOG(INFO) << "The size of old outputs: " << old_outputs.size() - << ", the size of new outputs: " << new_outputs.size(); - if (old_outputs.size() != new_outputs.size()) { - MS_LOG(EXCEPTION) << "The size of old outputs is not equal to the size of new outputs."; - } - - for (size_t i = 0; i < old_outputs.size(); ++i) { - auto old_output = old_outputs[i]; - auto new_output = new_outputs[i]; - if (old_output == new_output) { - continue; - } - // Update the graph output map. - if (graph_output_to_front_node_map_.count(old_output) > 0) { - MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " with index " - << old_output.second << " to " << new_output.first->fullname_with_scope() << " with index " - << new_output.second; - graph_output_to_front_node_map_[new_output] = graph_output_to_front_node_map_[old_output]; - (void)graph_output_to_front_node_map_.erase(old_output); - } - - if (old_output.first == new_output.first) { - continue; - } - // Update the front backend node map. - if ((backend_front_anf_map_.count(old_output.first) > 0) && old_output.first->isa() && - new_output.first->isa()) { - MS_LOG(INFO) << "Replace backend output node " << old_output.first->fullname_with_scope() << " to " - << new_output.first->fullname_with_scope(); - auto front_node = backend_front_anf_map_[old_output.first]; - front_backend_anf_map_[front_node] = new_output.first; - backend_front_anf_map_[new_output.first] = front_node; - (void)backend_front_anf_map_.erase(old_output.first); - } - } -} - AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const { auto iter = front_to_internal_outputs_map_.find(front_node); if (iter != front_to_internal_outputs_map_.end()) { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 71c04ba57c8..9de7263a6c3 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -251,11 +251,9 @@ class KernelGraph : public FuncGraph { FuncGraphPtr GetFuncGraph(); // Cache the backend graph output nodes and corresponding to front nodes with output index into // graph_output_to_front_node_map_. - void CacheGraphOutputToFrontNodeWithIndex(const AnfNodePtr &backend_graph_output, const AnfNodePtr &front_node); + void CacheGraphOutputToFrontNodeWithIndex(const std::vector &backend_outputs, + const std::vector &front_outputs); AnfWithOutIndex GetFrontNodeWithIndexByGraphOutput(const AnfWithOutIndex &backend_graph_output_with_index) const; - // Update the related map of backend graph output nodes by modified backend output nodes. - void UpdateGraphOutputMap(const std::vector &old_outputs, - const std::vector &new_outputs); uint32_t current_epoch() const { return current_epoch_; } void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.cc b/mindspore/ccsrc/runtime/framework/graph_compiler.cc index 9202b0269c0..7a629345287 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.cc @@ -318,15 +318,14 @@ GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePt session_->SetInputNodeUsage(graph, manager); graph->SetOptimizerFlag(); - // Cache the backend graph output nodes to front nodes with output index. - for (auto &output : outputs) { - auto backend_node = graph->GetBackendAnfByFrontAnf(output); - if (backend_node != nullptr) { - graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output); - } - } + auto graph_id = CompileGraphImpl(graph, device_context); - return CompileGraphImpl(graph, device_context); + // Cache the backend graph output nodes to front nodes with output index. + auto backend_node = graph->output(); + MS_EXCEPTION_IF_NULL(backend_node); + graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs); + + return graph_id; } GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const DeviceContext *device_context) { @@ -347,14 +346,16 @@ GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const Device // The graph common optimization. opt::BackendCommonOptimization(root_graph); + auto graph_id = CompileGraphImpl(root_graph, device_context); + // Cache the backend graph output nodes to front nodes with output index. auto output = func_graph->output(); MS_EXCEPTION_IF_NULL(output); - auto backend_node = root_graph->GetBackendAnfByFrontAnf(output); + auto backend_node = root_graph->output(); MS_EXCEPTION_IF_NULL(backend_node); - root_graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output); + root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output}); - return CompileGraphImpl(root_graph, device_context); + return graph_id; } GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const { @@ -376,9 +377,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic graph->set_is_executing_sink(is_executing_sink); graph->set_is_loop_count_sink(is_loop_count_sink); - MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id(); - auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output()); - // Execute optimization pass. device_context->OptimizeGraph(graph); @@ -389,11 +387,6 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic // Adjust kernel graph before run graph. device_context->PreprocessBeforeRunGraph(graph); - MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id(); - auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output()); - // Update the output map of kernel graph by modified output nodes. - graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer); - if (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { // Create device address for all anf nodes of graph. CreateDeviceAddress(graph, device_context);