forked from mindspore-Ecosystem/mindspore
!18838 fix bug of actor runtime host and device
Merge pull request !18838 from limingqi107/actor_runtime
This commit is contained in:
commit
119daba37b
|
@ -1129,9 +1129,6 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
|
|||
AnfAlgo::VisitKernelWithReturnType(front_node_with_index.first, front_node_with_index.second, false);
|
||||
auto front_output_node = front_output_with_index.first;
|
||||
MS_EXCEPTION_IF_NULL(front_output_node);
|
||||
MS_LOG(INFO) << "Link data arrow for internal parameter:" << internal_parameter->fullname_with_scope()
|
||||
<< ", corresponding front node:" << front_output_node->fullname_with_scope()
|
||||
<< " with output index:" << front_output_with_index.second;
|
||||
if (IsPersistentDeviceTensor(front_output_node)) {
|
||||
to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_output_node.get());
|
||||
return;
|
||||
|
@ -1141,6 +1138,11 @@ void GraphScheduler::LinkDataArrowForInternalParameter(const AnfNodePtr &interna
|
|||
<< ", internal parameter:" << internal_parameter->fullname_with_scope();
|
||||
}
|
||||
auto actor_pair = graph_output_to_actor_[front_output_with_index];
|
||||
MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
|
||||
<< ", corresponding front node:" << front_output_node->fullname_with_scope()
|
||||
<< " with index:" << front_output_with_index.second
|
||||
<< ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second
|
||||
<< ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second;
|
||||
|
||||
if (IsDeviceQueueDSActor(front_output_node)) {
|
||||
auto from_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor_pair.first);
|
||||
|
@ -1688,19 +1690,24 @@ void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<Ke
|
|||
MS_EXCEPTION_IF_NULL(another_device_context);
|
||||
copy_actor->output_device_context_ = another_device_context;
|
||||
|
||||
// LInk from copy actor to kernel actor users.
|
||||
if (kernel_actor->output_control_arrows_.size() == 0) {
|
||||
MS_LOG(INFO) << "The kernel actor has no control arrow:" << kernel_actor->GetAID().Name();
|
||||
}
|
||||
MS_LOG(INFO) << "The kernel actor: " << kernel_actor->GetAID().Name()
|
||||
<< "has control arrows number:" << kernel_actor->output_control_arrows_.size();
|
||||
// Link from copy actor to kernel actor users.
|
||||
for (auto &output_contorl : kernel_actor->output_control_arrows_) {
|
||||
copy_actor->output_control_arrows_.emplace_back(output_contorl);
|
||||
auto to_actor = FetchActor(output_contorl.Name());
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
if (output_contorl.Name().find("_LoopCountActor") != string::npos) {
|
||||
auto real_to_actor = dynamic_cast<LoopCountActor *>(to_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_to_actor);
|
||||
real_to_actor->branch_id_to_input_controls_num_[kMainBranchID]++;
|
||||
} else if (output_contorl.Name().find("copy_from") != string::npos) {
|
||||
auto real_to_actor = dynamic_cast<CopyActor *>(to_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_to_actor);
|
||||
real_to_actor->input_controls_num_++;
|
||||
} else {
|
||||
auto real_to_actor = dynamic_cast<KernelActor *>(to_actor);
|
||||
MS_EXCEPTION_IF_NULL(real_to_actor);
|
||||
real_to_actor->input_controls_num_++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -579,24 +579,22 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
|
|||
const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
runtime::GraphScheduler::GetInstance().PrepareRun(actor_set, graph_compiler_info, input_tensors);
|
||||
|
||||
// PreExecuteGraph
|
||||
// Debugger pre-execute graph.
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
auto debugger = Debugger::GetInstance();
|
||||
if (debugger) {
|
||||
debugger->Debugger::PreExecuteGraphDebugger(graph_compiler_info.graphs_);
|
||||
if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
|
||||
Debugger::GetInstance()->PreExecuteGraphDebugger(graph_compiler_info.graphs_);
|
||||
}
|
||||
#endif
|
||||
if (!runtime::GraphScheduler::GetInstance().Run(actor_set)) {
|
||||
MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
|
||||
}
|
||||
|
||||
// PostExecuteGraph
|
||||
// Debugger post-execute graph.
|
||||
#ifdef ENABLE_DEBUGGER
|
||||
if (debugger) {
|
||||
debugger->Debugger::PostExecuteGraphDebugger(graph_compiler_info.graphs_);
|
||||
if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
|
||||
Debugger::GetInstance()->PostExecuteGraphDebugger(graph_compiler_info.graphs_);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Sync device stream.
|
||||
const auto &first_device_context = graph_compiler_info.device_contexts_[0];
|
||||
MS_EXCEPTION_IF_NULL(first_device_context);
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
@ -147,8 +148,8 @@ class MindRTBackend : public Backend {
|
|||
// When compiling FuncGraph, it is divided according to the control nodes, yielding the control nodes and several
|
||||
// node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to
|
||||
// the corresponding device_context.
|
||||
std::unordered_map<GraphId, DeviceContext *> graph_id_to_device_context_;
|
||||
std::unordered_map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
|
||||
std::map<GraphId, DeviceContext *> graph_id_to_device_context_;
|
||||
std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
|
||||
std::vector<AnfNodePtr> control_nodes_;
|
||||
|
||||
std::unordered_map<ActorInfo, std::unique_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_;
|
||||
|
|
Loading…
Reference in New Issue