unified runtime fix the bug of graph output update

This commit is contained in:
limingqi107 2021-07-31 17:31:51 +08:00
parent 9720bab9c9
commit 481060bba1
2 changed files with 18 additions and 5 deletions

View File

@ -320,19 +320,24 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id())); DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
} }
// Execute optimization pass. MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id();
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output()); auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
// Execute optimization pass.
device_context->OptimizeGraph(graph); device_context->OptimizeGraph(graph);
auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
// Update the output map of kernel graph by modified output nodes.
graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel, // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
// 'KernelMod' is real executive object of kernel. // 'KernelMod' is real executive object of kernel.
device_context->CreateKernel(graph->execution_order()); device_context->CreateKernel(graph->execution_order());
// Adjust kernel graph before run graph.
device_context->PreprocessBeforeRunGraph(graph); device_context->PreprocessBeforeRunGraph(graph);
MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id();
auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
// Update the output map of kernel graph by modified output nodes.
graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) { if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
// Create device address for all anf nodes of graph. // Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context); CreateDeviceAddress(graph, device_context);

View File

@ -800,6 +800,10 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
} }
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) { void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
return;
}
for (const auto &graph : graph_compiler_info.graphs_) { for (const auto &graph : graph_compiler_info.graphs_) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output()); auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
@ -808,6 +812,8 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
MS_EXCEPTION_IF_NULL(output_kernel); MS_EXCEPTION_IF_NULL(output_kernel);
auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index); auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
if (origin_output_with_index.first == nullptr) { if (origin_output_with_index.first == nullptr) {
MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
<< " with index: " << output_with_index.second << " has no actor.";
continue; continue;
} }
@ -837,7 +843,9 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
MS_EXCEPTION_IF_NULL(actor); MS_EXCEPTION_IF_NULL(actor);
MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope() MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
<< " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name() << " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
<< " with index:" << actor_output_index; << " with index:" << actor_output_index
<< ", from front node:" << origin_output_with_index.first->fullname_with_scope()
<< " with index: " << origin_output_with_index.second;
(void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(actor, actor_output_index)); (void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(actor, actor_output_index));
} }
} }