!26984 unified runtime fix execution timeout and no data source actor of control flow

Merge pull request !26984 from limingqi107/new_actor_runtime2
This commit is contained in:
i-robot 2021-12-01 06:45:28 +00:00 committed by Gitee
commit b3c51fc2aa
11 changed files with 187 additions and 76 deletions

View File

@ -255,6 +255,13 @@ void DumpEntranceActor(const EntranceActor *actor, std::ofstream &ofs) {
MS_EXCEPTION_IF_NULL(actor); MS_EXCEPTION_IF_NULL(actor);
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n'; ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
DumpControlActor(actor, ofs); DumpControlActor(actor, ofs);
// List the actors whose control arrows are registered as loop body inputs of this entrance actor.
if (!actor->loop_body_input_control_arrow_aids().empty()) {
  // The header ends with a bare newline: a blank after it would be written in front of the first entry below and
  // shift that line by one column in the dump file.
  ofs << "\t\tinput_loop_body_control_arrow_actors:" << actor->loop_body_input_control_arrow_aids().size() << "\n";
  for (const auto &from_aid : actor->loop_body_input_control_arrow_aids()) {
    ofs << "\t\t\tfrom_actor_name:" << from_aid.Name() << "\n";
  }
}
} }
void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) { void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
@ -376,6 +383,9 @@ void DumpLoopCountActor(const LoopCountActorPtr &actor, std::ofstream &ofs) {
DumpAbstractActor(actor.get(), ofs); DumpAbstractActor(actor.get(), ofs);
ofs << "\t\t\tto_data_prepare_actor:" << actor->data_prepare_aid().Name() << "\n"; ofs << "\t\t\tto_data_prepare_actor:" << actor->data_prepare_aid().Name() << "\n";
for (auto &entrance_aid : actor->entrance_aids()) {
ofs << "\t\t\tto_entrance_actor:" << entrance_aid.Name() << "\n";
}
} }
void DumpOutputActor(const OutputActorPtr &actor, std::ofstream &ofs) { void DumpOutputActor(const OutputActorPtr &actor, std::ofstream &ofs) {

View File

@ -21,23 +21,54 @@ namespace mindspore {
namespace runtime { namespace runtime {
constexpr size_t kEntranceInputStartPos = 1; constexpr size_t kEntranceInputStartPos = 1;
// Records one incoming control arrow for the step identified by context->sequential_num_ and runs the actor as soon
// as CheckRunningCondition() reports that it may run.
// While is_loop_body_execution_ is set the arrow is stored in loop_body_input_op_controls_, otherwise in
// input_op_controls_.
void EntranceActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &step_id = context->sequential_num_;

  // Pick the bucket that matches the current execution phase.
  if (!is_loop_body_execution_) {
    (void)input_op_controls_[step_id].emplace_back(input_control);
  } else {
    (void)loop_body_input_op_controls_[step_id].emplace_back(input_control);
  }

  const bool can_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op control and check running condition:" << can_run
                << ", loop body execution:" << is_loop_body_execution_;
  if (!can_run) {
    return;
  }
  Run(context);
}
void EntranceActor::RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id, void EntranceActor::RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id,
OpContext<DeviceTensor> *const context) { OpContext<DeviceTensor> *const context) {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
auto &sequential_num = context->sequential_num_; auto &sequential_num = context->sequential_num_;
real_parameters_with_branch_id_[sequential_num].emplace(real_parameter_with_branch_id); real_parameters_with_branch_id_[sequential_num].emplace(real_parameter_with_branch_id);
if (CheckRunningCondition(context)) { auto is_run = CheckRunningCondition(context);
MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
<< ") receive the input op data with branch id and check running condition:" << is_run
<< ", loop body execution:" << is_loop_body_execution_;
if (is_run) {
Run(context); Run(context);
} }
} }
// Puts the actor back into the begin-of-step state: the loop body flag is reset and, when loop body control arrows
// are linked to this actor, the controls collected for the loop body are discarded.
// input_control is accepted to match the message signature but is not read here.
void EntranceActor::ClearDataOnStepEnd(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  is_loop_body_execution_ = false;

  // Nothing was linked as a loop body input, so the map is left untouched.
  if (loop_body_input_controls_nums_ == 0) {
    return;
  }
  loop_body_input_op_controls_.clear();
}
void EntranceActor::Run(OpContext<DeviceTensor> *const context) { void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
FetchInput(context); FetchInput(context);
EraseInput(context); EraseInput(context);
SendOutput(context); SendOutput(context);
// The actor needs to be disabled after the actor is running, until no actor is running in the entire funcgraph. // The begin execution of step is false and the others execution of step is true.
is_actor_ready_ = false; is_loop_body_execution_ = true;
} }
void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) { void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
@ -104,33 +135,34 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
} }
} }
// Tells whether the entrance actor is enabled for the step in context.
// Returns true when is_actor_ready_ is set, or when the actor has input control arrows and all of them have been
// received for context->sequential_num_; returns false otherwise.
bool EntranceActor::CheckActorStatus(const OpContext<DeviceTensor> *const context) const {
  if (is_actor_ready_) {
    return true;
  }
  // Without any input control arrow a disabled actor cannot be re-enabled here.
  if (input_controls_num_ == 0) {
    return false;
  }
  const auto &received = input_op_controls_.find(context->sequential_num_);
  return received != input_op_controls_.end() && received->second.size() == input_controls_num_;
}
bool EntranceActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const { bool EntranceActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
// When the entrance actor is in the disabled state, it cannot be run. // Check the running condition in the begin execution of step.
if (!CheckActorStatus(context)) { // The input controls and input data exist the begin execution of root graph, and there will only be one of the two.
return false; if (!is_loop_body_execution_) {
if (input_controls_num_ != 0) {
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
if ((control_iter != input_op_controls_.end()) && (control_iter->second.size() == input_controls_num_)) {
return true;
}
}
// Data comes from the data source actor.
if (input_datas_num_ != 0) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) {
return true;
}
}
} }
// Data comes from the data source actor. // Check the controls in the loop body execution of step.
if (input_datas_num_ != 0) { if (is_loop_body_execution_ && (loop_body_input_controls_nums_ != 0)) {
const auto &data_iter = input_op_datas_.find(context->sequential_num_); const auto &control_iter = loop_body_input_op_controls_.find(context->sequential_num_);
if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) { if ((control_iter == loop_body_input_op_controls_.end()) ||
return true; (control_iter->second.size() != loop_body_input_controls_nums_)) {
return false;
} }
} }
@ -149,7 +181,6 @@ void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
const auto &data_iter = input_op_datas_.find(sequential_num); const auto &data_iter = input_op_datas_.find(sequential_num);
if (data_iter != input_op_datas_.end()) { if (data_iter != input_op_datas_.end()) {
input_op_datas_.erase(data_iter); input_op_datas_.erase(data_iter);
return;
} }
const auto &control_iter = input_op_controls_.find(sequential_num); const auto &control_iter = input_op_controls_.find(sequential_num);
@ -157,14 +188,17 @@ void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
input_op_controls_.erase(control_iter); input_op_controls_.erase(control_iter);
} }
const auto &iter = real_parameters_with_branch_id_.find(sequential_num); const auto &loop_body_control_iter = loop_body_input_op_controls_.find(sequential_num);
if (iter == real_parameters_with_branch_id_.end() || iter->second.empty()) { if (loop_body_control_iter != loop_body_input_op_controls_.end()) {
MS_LOG(ERROR) << "Cannot find input in batch op result for actor:" << GetAID(); loop_body_input_op_controls_.erase(loop_body_control_iter);
} }
iter->second.pop(); const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
if (iter->second.empty()) { if (iter != real_parameters_with_branch_id_.end()) {
real_parameters_with_branch_id_.erase(sequential_num); iter->second.pop();
if (iter->second.empty()) {
real_parameters_with_branch_id_.erase(sequential_num);
}
} }
} }
} // namespace runtime } // namespace runtime

View File

@ -40,9 +40,17 @@ class EntranceActor : public ControlActor {
input_device_tensors_.resize(parameters.size()); input_device_tensors_.resize(parameters.size());
} }
~EntranceActor() override = default; ~EntranceActor() override = default;
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
void RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id, void RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id,
OpContext<DeviceTensor> *const context); OpContext<DeviceTensor> *const context);
// Clear the data which are generated in the loop body execution.
void ClearDataOnStepEnd(AID *const input_control, OpContext<DeviceTensor> *const context);
const std::vector<AID> &loop_body_input_control_arrow_aids() const { return loop_body_input_control_arrow_aids_; }
protected: protected:
void Run(OpContext<DeviceTensor> *const context) override; void Run(OpContext<DeviceTensor> *const context) override;
void FetchInput(OpContext<DeviceTensor> *const context) override; void FetchInput(OpContext<DeviceTensor> *const context) override;
@ -52,13 +60,14 @@ class EntranceActor : public ControlActor {
private: private:
friend class ControlNodeScheduler; friend class ControlNodeScheduler;
// Check if actor is enable. During operation, entrance actor can be enabled only when receives all control arrows. // Indicate whether the entrance actor is the execution of loop body. In the control flow, the subgraph can be
bool CheckActorStatus(const OpContext<DeviceTensor> *const context) const; // triggered to execute in two ways: one is the begin execution of step, another is the execution of loop body.
// The input controls are different in the two ways.
// Is actor ready indicates whether the entrance actor can be executed. In the control flow, the subgraph is an bool is_loop_body_execution_{false};
// atomic operation, and execution can only continue after the output of the corresponding exit actor is completed. // The dependent of loop body input actors.
// At this time, the exit actor will notify the entrance actor to change the ready to true. mindspore::HashMap<int, std::vector<AID *>> loop_body_input_op_controls_;
bool is_actor_ready_{true}; std::vector<AID> loop_body_input_control_arrow_aids_;
size_t loop_body_input_controls_nums_{0};
// Input data with branch id. // Input data with branch id.
mindspore::HashMap<int, std::queue<OpRealParameterWithBranchID>> real_parameters_with_branch_id_; mindspore::HashMap<int, std::queue<OpRealParameterWithBranchID>> real_parameters_with_branch_id_;

View File

@ -84,8 +84,8 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
} }
auto output_partial = input_partials_[partial_arrow->from_output_index_]; auto output_partial = input_partials_[partial_arrow->from_output_index_];
MS_EXCEPTION_IF_NULL(output_partial->func_graph_); MS_EXCEPTION_IF_NULL(output_partial->func_graph_);
Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial, ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial,
IntToSize(partial_arrow->to_input_index_), context); IntToSize(partial_arrow->to_input_index_), context);
} }
} }
} }

View File

@ -20,6 +20,7 @@
#include "runtime/framework/actor/memory_manager_actor.h" #include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/recorder_actor.h" #include "runtime/framework/actor/recorder_actor.h"
#include "runtime/framework/actor/debug_actor.h" #include "runtime/framework/actor/debug_actor.h"
#include "runtime/framework/actor/control_flow/entrance_actor.h"
#include "mindrt/include/async/async.h" #include "mindrt/include/async/async.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -70,6 +71,11 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
ActorDispatcher::Send(output_control, &OpActor::RunOpControl, from_aid, context); ActorDispatcher::Send(output_control, &OpActor::RunOpControl, from_aid, context);
} }
// Send to EntranceActor to clear the data which are generated in the loop body execution.
for (auto &entrance_aid : entrance_aids_) {
ActorDispatcher::Send(entrance_aid, &EntranceActor::ClearDataOnStepEnd, from_aid, context);
}
// The LoopCountActor exits. // The LoopCountActor exits.
if (current_count_ == loop_count_) { if (current_count_ == loop_count_) {
current_count_ = 0; current_count_ = 0;

View File

@ -52,6 +52,7 @@ class LoopCountActor : public DebugAwareActor {
// Get the member. // Get the member.
size_t loop_count() const { return loop_count_; } size_t loop_count() const { return loop_count_; }
const AID &data_prepare_aid() const { return data_prepare_aid_; } const AID &data_prepare_aid() const { return data_prepare_aid_; }
const std::vector<AID> &entrance_aids() const { return entrance_aids_; }
protected: protected:
void Run(OpContext<DeviceTensor> *const context) override; void Run(OpContext<DeviceTensor> *const context) override;
@ -59,6 +60,7 @@ class LoopCountActor : public DebugAwareActor {
private: private:
friend class GraphScheduler; friend class GraphScheduler;
friend class ControlNodeScheduler;
void IncreaseLoopCount(OpContext<DeviceTensor> *const context); void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
@ -68,7 +70,9 @@ class LoopCountActor : public DebugAwareActor {
// The total running count represents the total step running count. // The total running count represents the total step running count.
size_t total_running_count_; size_t total_running_count_;
// The actors which need be handled separately by loop count actor.
AID data_prepare_aid_; AID data_prepare_aid_;
std::vector<AID> entrance_aids_;
}; };
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>; using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;

View File

@ -380,12 +380,12 @@ void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &gr
// Link data arrows and partial arrows between control actors. // Link data arrows and partial arrows between control actors.
LinkArrowForControlActor(actor_set->control_actors_.get(), graph_compiler_info); LinkArrowForControlActor(actor_set->control_actors_.get(), graph_compiler_info);
// Link arrows from host data source actor or data prepare actor to entrance actor of root graph.
LinkArrowForRootGraphEntranceActor(graph_compiler_info);
// Link output data arrows from control actors to output actor. // Link output data arrows from control actors to output actor.
LinkDataArrowForOutputActor(actor_set, graph_compiler_info); LinkDataArrowForOutputActor(actor_set, graph_compiler_info);
// Link data arrows from host data source actor to control actors.
LinkDataArrowForHostDSActor(graph_compiler_info);
// Link data arrows from entrance actors to kernel actors. // Link data arrows from entrance actors to kernel actors.
LinkDataArrowForKernelActor(graph_compiler_info); LinkDataArrowForKernelActor(graph_compiler_info);
@ -397,6 +397,19 @@ void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &gr
// Link control arrows for no input and no output kernel actor. // Link control arrows for no input and no output kernel actor.
LinkControlArrowForKernelActor(actor_set, graph_compiler_info); LinkControlArrowForKernelActor(actor_set, graph_compiler_info);
LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info);
}
// Empties created_device_tensors_ of every exit actor in control_actor_set.
// A null control_actor_set is accepted and treated as "nothing to clear"; a null exit actor inside the set is
// reported through MS_EXCEPTION_IF_NULL.
void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) {
if (control_actor_set == nullptr) {
return;
}
for (auto &exit_actor : control_actor_set->exit_actors_) {
MS_EXCEPTION_IF_NULL(exit_actor);
exit_actor->created_device_tensors_.clear();
}
} }
void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const control_actor_set, void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const control_actor_set,
@ -657,21 +670,11 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc
void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor_set, void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor_set,
const GraphCompilerInfo &graph_compiler_info) { const GraphCompilerInfo &graph_compiler_info) {
// Get the exit actor of root graph, In control flow, the final output is always sent by the exit of the root graph.
MS_EXCEPTION_IF_NULL(actor_set); MS_EXCEPTION_IF_NULL(actor_set);
auto control_actor_set = actor_set->control_actors_.get(); auto control_actor_set = actor_set->control_actors_.get();
MS_EXCEPTION_IF_NULL(control_actor_set); MS_EXCEPTION_IF_NULL(control_actor_set);
const auto &parser = graph_compiler_info.control_node_parser_; const auto &parser = graph_compiler_info.control_node_parser_;
MS_EXCEPTION_IF_NULL(parser); MS_EXCEPTION_IF_NULL(parser);
const auto &root_graph = parser->root_func_graph_;
MS_EXCEPTION_IF_NULL(root_graph);
const auto &exit_actor_name = root_graph->ToString() + kExitActorNameSuffix;
auto actor = FetchActor(exit_actor_name);
MS_EXCEPTION_IF_NULL(actor);
MS_EXCEPTION_IF_NULL(actor_set->loop_count_actor_);
auto root_exit_actor = dynamic_cast<ExitActor *>(actor);
// link control arrow from root exit actor to loop count actor.
LinkControlArrowForExitActor(root_exit_actor, actor_set->loop_count_actor_.get(), kMainBranchID);
// Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph // Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
// stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the // stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
@ -682,9 +685,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
const auto &func_graph = graph_to_nodes.first; const auto &func_graph = graph_to_nodes.first;
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix; auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
actor = FetchActor(actor_name); auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
MS_EXCEPTION_IF_NULL(actor);
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
MS_EXCEPTION_IF_NULL(entrance_actor); MS_EXCEPTION_IF_NULL(entrance_actor);
const auto &nodes = graph_to_nodes.second; const auto &nodes = graph_to_nodes.second;
@ -696,11 +697,9 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
} else { } else {
actor_name = GetActorName(node); actor_name = GetActorName(node);
} }
actor = FetchActor(actor_name); auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
MS_EXCEPTION_IF_NULL(actor);
auto from_actor = dynamic_cast<ControlActor *>(actor);
MS_EXCEPTION_IF_NULL(from_actor); MS_EXCEPTION_IF_NULL(from_actor);
LinkControlArrow(from_actor, entrance_actor); LinkLoopBodyControlArrow(from_actor, entrance_actor);
} }
} }
@ -720,9 +719,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
const FuncGraphPtr &func_graph = control_actor->node_->func_graph(); const FuncGraphPtr &func_graph = control_actor->node_->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix; const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
actor = FetchActor(actor_name); const auto &entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
MS_EXCEPTION_IF_NULL(actor);
const auto &entrance_actor = dynamic_cast<EntranceActor *>(actor);
MS_EXCEPTION_IF_NULL(entrance_actor); MS_EXCEPTION_IF_NULL(entrance_actor);
LinkControlArrow(entrance_actor, control_actor); LinkControlArrow(entrance_actor, control_actor);
} }
@ -757,6 +754,31 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
} }
} }
// Connects the loop count actor of actor_set to the control flow actors:
//   1. links a control arrow from the exit actor of the root graph to the loop count actor (main branch);
//   2. registers the AID of every entrance actor in the loop count actor's entrance_aids_.
void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set,
                                                             const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  auto *loop_count_actor = actor_set->loop_count_actor_.get();
  MS_EXCEPTION_IF_NULL(loop_count_actor);

  // In control flow the final output is always sent by the exit actor of the root graph.
  const auto &parser = graph_compiler_info.control_node_parser_;
  MS_EXCEPTION_IF_NULL(parser);
  const auto &root_graph = parser->root_func_graph_;
  MS_EXCEPTION_IF_NULL(root_graph);
  const auto root_exit_name = root_graph->ToString() + kExitActorNameSuffix;
  auto *root_exit_actor = dynamic_cast<ExitActor *>(FetchActor(root_exit_name));
  MS_EXCEPTION_IF_NULL(root_exit_actor);
  LinkControlArrowForExitActor(root_exit_actor, loop_count_actor, kMainBranchID);

  // Entrance actors collect data during loop body execution; the loop count actor keeps their AIDs so that it can
  // address them at the end of a step.
  MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
  for (const auto &entrance : actor_set->control_actors_->entrance_actors_) {
    MS_EXCEPTION_IF_NULL(entrance);
    (void)loop_count_actor->entrance_aids_.emplace_back(entrance->GetAID());
  }
}
void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set, void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set,
const GraphCompilerInfo &graph_compiler_info) { const GraphCompilerInfo &graph_compiler_info) {
const auto &parser = graph_compiler_info.control_node_parser_; const auto &parser = graph_compiler_info.control_node_parser_;
@ -1024,21 +1046,26 @@ void ControlNodeScheduler::LinkDataArrowForOutputActor(ActorSet *const actor_set
actor_set->output_actor_->device_contexts_ = iter->second; actor_set->output_actor_->device_contexts_ = iter->second;
} }
void ControlNodeScheduler::LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info) { void ControlNodeScheduler::LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info) {
// In control flow, the host data source actor sends all the input to the entrance actor of the root graph. MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
const auto &host_ds_actor_name = graph_compiler_info.name_ + "_HostDSActor";
auto actor = FetchActor(host_ds_actor_name);
MS_EXCEPTION_IF_NULL(actor);
const auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
MS_EXCEPTION_IF_NULL(host_ds_actor);
const auto &root_graph = graph_compiler_info.control_node_parser_->root_func_graph_; const auto &root_graph = graph_compiler_info.control_node_parser_->root_func_graph_;
MS_EXCEPTION_IF_NULL(root_graph); MS_EXCEPTION_IF_NULL(root_graph);
const auto &entrance_actor_name = root_graph->ToString() + kEntranceActorNameSuffix; const auto &entrance_actor_name = root_graph->ToString() + kEntranceActorNameSuffix;
actor = FetchActor(entrance_actor_name); auto to_actor = dynamic_cast<EntranceActor *>(FetchActor(entrance_actor_name));
MS_EXCEPTION_IF_NULL(actor); MS_EXCEPTION_IF_NULL(to_actor);
auto to_actor = dynamic_cast<EntranceActor *>(actor);
const auto &host_ds_actor_name = graph_compiler_info.name_ + "_HostDSActor";
auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(host_ds_actor_name));
// No host data source actor scenario.
if (host_ds_actor == nullptr) {
const auto &data_prepare_actor_name = graph_compiler_info.name_ + "_DataPrepareActor";
auto data_prepare_actor = FetchActor(data_prepare_actor_name);
MS_EXCEPTION_IF_NULL(data_prepare_actor);
LinkControlArrow(data_prepare_actor, to_actor);
return;
}
// The host data source actor sends all the input to the entrance actor of the root graph.
for (size_t i = 0; i < to_actor->formal_parameters_.size(); ++i) { for (size_t i = 0; i < to_actor->formal_parameters_.size(); ++i) {
const auto &formal_parameter = to_actor->formal_parameters_[i]; const auto &formal_parameter = to_actor->formal_parameters_[i];
MS_EXCEPTION_IF_NULL(formal_parameter.first); MS_EXCEPTION_IF_NULL(formal_parameter.first);
@ -1076,6 +1103,14 @@ void ControlNodeScheduler::LinkControlArrow(AbstractActor *from_actor, AbstractA
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID()); (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
} }
// Links a control arrow from from_actor to the entrance actor to_actor and books it as a loop body input:
// the arrow is appended to the sender's output_control_arrows_, the receiver's loop body control counter is
// incremented and the sender's AID is appended to loop_body_input_control_arrow_aids_.
void ControlNodeScheduler::LinkLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);

  // Sender side.
  (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());

  // Receiver side.
  ++to_actor->loop_body_input_controls_nums_;
  (void)to_actor->loop_body_input_control_arrow_aids_.emplace_back(from_actor->GetAID());
}
void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor,
size_t from_index, size_t to_index, int branch_id) { size_t from_index, size_t to_index, int branch_id) {
MS_EXCEPTION_IF_NULL(exit_actor); MS_EXCEPTION_IF_NULL(exit_actor);

View File

@ -43,6 +43,9 @@ class ControlNodeScheduler {
bool CheckActorValid(const ControlActorSetPtr &control_actor_set); bool CheckActorValid(const ControlActorSetPtr &control_actor_set);
// The control flow actor will generate some data in the loop body execution, so need clear on the end of execution.
void ClearActorData(const ControlActorSet *control_actor_set);
private: private:
// Interface to create control actors. // Interface to create control actors.
std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info); std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
@ -77,14 +80,16 @@ class ControlNodeScheduler {
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info); void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);
void LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, bool is_call_input_graph, void LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, bool is_call_input_graph,
ControlActor *const entrance_actor); ControlActor *const entrance_actor);
void LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowForLoopCountActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node, void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
const ControlNodeParserPtr &parser); const ControlNodeParserPtr &parser);
// Interface tool to link arrows between actors. // Interface tool to link arrows between actors.
void LinkControlArrow(AbstractActor *from_actor, AbstractActor *to_actor); void LinkControlArrow(AbstractActor *from_actor, AbstractActor *to_actor);
void LinkLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor);
// Data arrow with branch id is only exists from gather actor to entrance actor. // Data arrow with branch id is only exists from gather actor to entrance actor.
void LinkDataWithBranchIDArrow(GatherActor *const gather_actor, EntranceActor *const entrance_actor, void LinkDataWithBranchIDArrow(GatherActor *const gather_actor, EntranceActor *const entrance_actor,
const FuncGraphPtr &func_graph); const FuncGraphPtr &func_graph);

View File

@ -215,6 +215,11 @@ void GraphScheduler::Clear() {
ClearAllActors(); ClearAllActors();
} }
// Forwards the control actors of actor_set to ControlNodeScheduler::ClearActorData.
// actor_set must not be null; its control_actors_ pointer is passed on as it is.
void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
  MS_EXCEPTION_IF_NULL(actor_set);
  const auto *control_actor_set = actor_set->control_actors_.get();
  control_node_scheduler_.ClearActorData(control_actor_set);
}
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &, using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &,
const KernelWithIndex &, const KernelGraphPtr &); const KernelWithIndex &, const KernelGraphPtr &);
static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc; static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;

View File

@ -57,6 +57,8 @@ class GraphScheduler {
// Clear the members. // Clear the members.
void Clear(); void Clear();
void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept; void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept;
// The control flow actors will generate some data in the loop body execution, so need clear on the end of execution.
void ClearActorData(const ActorSet *actor_set);
// Transform graph to actor DAG, contains build and link. // Transform graph to actor DAG, contains build and link.
ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info); ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);

View File

@ -882,6 +882,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
size_t output_position = 0; size_t output_position = 0;
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs); ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
} }
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
MS_LOG(INFO) << "Status record: end run actor: " << actor_info; MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
} }