forked from mindspore-Ecosystem/mindspore
unified runtime fix the bug of graph output update
This commit is contained in:
parent
9720bab9c9
commit
481060bba1
|
@ -320,19 +320,24 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
|
||||||
DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
|
DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute optimization pass.
|
MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id();
|
||||||
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||||
|
|
||||||
|
// Execute optimization pass.
|
||||||
device_context->OptimizeGraph(graph);
|
device_context->OptimizeGraph(graph);
|
||||||
auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
|
||||||
// Update the output map of kernel graph by modified output nodes.
|
|
||||||
graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
|
|
||||||
|
|
||||||
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
|
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
|
||||||
// 'KernelMod' is real executive object of kernel.
|
// 'KernelMod' is real executive object of kernel.
|
||||||
device_context->CreateKernel(graph->execution_order());
|
device_context->CreateKernel(graph->execution_order());
|
||||||
|
|
||||||
|
// Adjust kernel graph before run graph.
|
||||||
device_context->PreprocessBeforeRunGraph(graph);
|
device_context->PreprocessBeforeRunGraph(graph);
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id();
|
||||||
|
auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||||
|
// Update the output map of kernel graph by modified output nodes.
|
||||||
|
graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
|
||||||
|
|
||||||
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
|
||||||
// Create device address for all anf nodes of graph.
|
// Create device address for all anf nodes of graph.
|
||||||
CreateDeviceAddress(graph, device_context);
|
CreateDeviceAddress(graph, device_context);
|
||||||
|
|
|
@ -800,6 +800,10 @@ ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info)
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
|
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||||
|
if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto &graph : graph_compiler_info.graphs_) {
|
for (const auto &graph : graph_compiler_info.graphs_) {
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
|
||||||
|
@ -808,6 +812,8 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
|
||||||
MS_EXCEPTION_IF_NULL(output_kernel);
|
MS_EXCEPTION_IF_NULL(output_kernel);
|
||||||
auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
|
auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
|
||||||
if (origin_output_with_index.first == nullptr) {
|
if (origin_output_with_index.first == nullptr) {
|
||||||
|
MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
|
||||||
|
<< " with index: " << output_with_index.second << " has no actor.";
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -837,7 +843,9 @@ void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_comp
|
||||||
MS_EXCEPTION_IF_NULL(actor);
|
MS_EXCEPTION_IF_NULL(actor);
|
||||||
MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
|
MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
|
||||||
<< " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
|
<< " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
|
||||||
<< " with index:" << actor_output_index;
|
<< " with index:" << actor_output_index
|
||||||
|
<< ", from front node:" << origin_output_with_index.first->fullname_with_scope()
|
||||||
|
<< " with index: " << origin_output_with_index.second;
|
||||||
(void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(actor, actor_output_index));
|
(void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(actor, actor_output_index));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue