From 312b26080b3fa9df69c4a39bcb9b6097d2dc9878 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 30 Nov 2021 12:45:11 +0800 Subject: [PATCH] unified runtime fix execution timeout and no data source actor of control flow --- .../runtime/framework/actor/actor_dump.cc | 10 ++ .../actor/control_flow/entrance_actor.cc | 98 +++++++++++------ .../actor/control_flow/entrance_actor.h | 23 ++-- .../actor/control_flow/exit_actor.cc | 4 +- .../framework/actor/loop_count_actor.cc | 6 + .../framework/actor/loop_count_actor.h | 4 + .../framework/control_node_scheduler.cc | 103 ++++++++++++------ .../framework/control_node_scheduler.h | 7 +- .../runtime/framework/graph_scheduler.cc | 5 + .../ccsrc/runtime/framework/graph_scheduler.h | 2 + mindspore/ccsrc/vm/backend.cc | 1 + 11 files changed, 187 insertions(+), 76 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc b/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc index fb9d5341d9d..8d8f2c97bcb 100644 --- a/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc +++ b/mindspore/ccsrc/runtime/framework/actor/actor_dump.cc @@ -255,6 +255,13 @@ void DumpEntranceActor(const EntranceActor *actor, std::ofstream &ofs) { MS_EXCEPTION_IF_NULL(actor); ofs << "\tactor_name:" << actor->GetAID().Name() << '\n'; DumpControlActor(actor, ofs); + + if (actor->loop_body_input_control_arrow_aids().size() > 0) { + ofs << "\t\tinput_loop_body_control_arrow_actors:" << actor->loop_body_input_control_arrow_aids().size() << "\n "; + for (const auto &loop_body_input_control_arrow_aid : actor->loop_body_input_control_arrow_aids()) { + ofs << "\t\t\tfrom_actor_name:" << loop_body_input_control_arrow_aid.Name() << "\n"; + } + } } void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) { @@ -376,6 +383,9 @@ void DumpLoopCountActor(const LoopCountActorPtr &actor, std::ofstream &ofs) { DumpAbstractActor(actor.get(), ofs); ofs << "\t\t\tto_data_prepare_actor:" << actor->data_prepare_aid().Name() << "\n"; + for (auto &entrance_aid : actor->entrance_aids()) { + ofs << "\t\t\tto_entrance_actor:" << entrance_aid.Name() << "\n"; + } } void DumpOutputActor(const OutputActorPtr &actor, std::ofstream &ofs) { diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc b/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc index ed1322f65f8..6fba3abbba9 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.cc @@ -21,23 +21,54 @@ namespace mindspore { namespace runtime { constexpr size_t kEntranceInputStartPos = 1; +void EntranceActor::RunOpControl(AID *const input_control, OpContext *const context) { + MS_EXCEPTION_IF_NULL(context); + auto &sequential_num = context->sequential_num_; + if (is_loop_body_execution_) { + (void)loop_body_input_op_controls_[sequential_num].emplace_back(input_control); + } else { + (void)input_op_controls_[sequential_num].emplace_back(input_control); + } + + auto is_run = CheckRunningCondition(context); + MS_LOG(DEBUG) << "Actor(" << GetAID().Name() + << ") receive the input op control and check running condition:" << is_run + << ", loop body execution:" << is_loop_body_execution_; + if (is_run) { + Run(context); + } +} + void EntranceActor::RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id, OpContext *const context) { MS_EXCEPTION_IF_NULL(context); auto &sequential_num = context->sequential_num_; real_parameters_with_branch_id_[sequential_num].emplace(real_parameter_with_branch_id); - if (CheckRunningCondition(context)) { + auto is_run = CheckRunningCondition(context); + MS_LOG(DEBUG) << "Actor(" << GetAID().Name() + << ") receive the input op data with branch id and check running condition:" << is_run + << ", loop body execution:" << is_loop_body_execution_; + if (is_run) { Run(context); } } +void EntranceActor::ClearDataOnStepEnd(AID *const input_control, OpContext *const context) { + MS_EXCEPTION_IF_NULL(context); + is_loop_body_execution_ = false; + + if (loop_body_input_controls_nums_ != 0) { + loop_body_input_op_controls_.clear(); + } +} + void EntranceActor::Run(OpContext *const context) { FetchInput(context); EraseInput(context); SendOutput(context); - // The actor needs to be disabled after the actor is running, until no actor is running in the entire funcgraph. - is_actor_ready_ = false; + // The begin execution of step is false and the others execution of step is true. + is_loop_body_execution_ = true; } void EntranceActor::FetchInput(OpContext *const context) { @@ -104,33 +135,34 @@ void EntranceActor::FetchInput(OpContext *const context) { } } -bool EntranceActor::CheckActorStatus(const OpContext *const context) const { - if (is_actor_ready_) { - return true; - } - // During operation, entrance actor can be enabled only when receives all control arrows. - if (input_controls_num_ != 0) { - const auto &control_iter = input_op_controls_.find(context->sequential_num_); - if (control_iter != input_op_controls_.end() && control_iter->second.size() == input_controls_num_) { - return true; - } - } - return false; -} - bool EntranceActor::CheckRunningCondition(const OpContext *context) const { MS_EXCEPTION_IF_NULL(context); - // When the entrance actor is in the disabled state, it cannot be run. - if (!CheckActorStatus(context)) { - return false; + // Check the running condition in the begin execution of step. + // The input controls and input data exist the begin execution of root graph, and there will only be one of the two. + if (!is_loop_body_execution_) { + if (input_controls_num_ != 0) { + const auto &control_iter = input_op_controls_.find(context->sequential_num_); + if ((control_iter != input_op_controls_.end()) && (control_iter->second.size() == input_controls_num_)) { + return true; + } + } + + // Data comes from the data source actor. + if (input_datas_num_ != 0) { + const auto &data_iter = input_op_datas_.find(context->sequential_num_); + if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) { + return true; + } + } } - // Data comes from the data source actor. - if (input_datas_num_ != 0) { - const auto &data_iter = input_op_datas_.find(context->sequential_num_); - if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) { - return true; + // Check the controls in the loop body execution of step. + if (is_loop_body_execution_ && (loop_body_input_controls_nums_ != 0)) { + const auto &control_iter = loop_body_input_op_controls_.find(context->sequential_num_); + if ((control_iter == loop_body_input_op_controls_.end()) || + (control_iter->second.size() != loop_body_input_controls_nums_)) { + return false; } } @@ -149,7 +181,6 @@ void EntranceActor::EraseInput(const OpContext *const context) { const auto &data_iter = input_op_datas_.find(sequential_num); if (data_iter != input_op_datas_.end()) { input_op_datas_.erase(data_iter); - return; } const auto &control_iter = input_op_controls_.find(sequential_num); @@ -157,14 +188,17 @@ void EntranceActor::EraseInput(const OpContext *const context) { input_op_controls_.erase(control_iter); } - const auto &iter = real_parameters_with_branch_id_.find(sequential_num); - if (iter == real_parameters_with_branch_id_.end() || iter->second.empty()) { - MS_LOG(ERROR) << "Cannot find input in batch op result for actor:" << GetAID(); + const auto &loop_body_control_iter = loop_body_input_op_controls_.find(sequential_num); + if (loop_body_control_iter != loop_body_input_op_controls_.end()) { + loop_body_input_op_controls_.erase(loop_body_control_iter); } - iter->second.pop(); - if (iter->second.empty()) { - real_parameters_with_branch_id_.erase(sequential_num); + const auto &iter = real_parameters_with_branch_id_.find(sequential_num); + if (iter != real_parameters_with_branch_id_.end()) { + iter->second.pop(); + if (iter->second.empty()) { + real_parameters_with_branch_id_.erase(sequential_num); + } } } } // namespace runtime diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h b/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h index 839813ff593..17a2410d22c 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/entrance_actor.h @@ -40,9 +40,17 @@ class EntranceActor : public ControlActor { input_device_tensors_.resize(parameters.size()); } ~EntranceActor() override = default; + + void RunOpControl(AID *const input_control, OpContext *const context) override; + void RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id, OpContext *const context); + // Clear the data which are generated in the loop body execution. + void ClearDataOnStepEnd(AID *const input_control, OpContext *const context); + + const std::vector &loop_body_input_control_arrow_aids() const { return loop_body_input_control_arrow_aids_; } + protected: void Run(OpContext *const context) override; void FetchInput(OpContext *const context) override; @@ -52,13 +60,14 @@ class EntranceActor : public ControlActor { private: friend class ControlNodeScheduler; - // Check if actor is enable. During operation, entrance actor can be enabled only when receives all control arrows. - bool CheckActorStatus(const OpContext *const context) const; - - // Is actor ready indicates whether the entrance actor can be executed. In the control flow, the subgraph is an - // atomic operation, and execution can only continue after the output of the corresponding exit actor is completed. - // At this time, the exit actor will notify the entrance actor to change the ready to true. - bool is_actor_ready_{true}; + // Indicate whether the entrance actor is the execution of loop body. In the control flow, the subgraph can be + // triggered to execute in two ways: one is the begin execution of step, another is the execution of loop body. + // The input controls are different in the two ways. + bool is_loop_body_execution_{false}; + // The dependent of loop body input actors. + mindspore::HashMap> loop_body_input_op_controls_; + std::vector loop_body_input_control_arrow_aids_; + size_t loop_body_input_controls_nums_{0}; // Input data with branch id. mindspore::HashMap> real_parameters_with_branch_id_; diff --git a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc index 21de9bb9a7c..974606eab20 100644 --- a/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/control_flow/exit_actor.cc @@ -84,8 +84,8 @@ void ExitActor::SendOutput(OpContext *const context) { } auto output_partial = input_partials_[partial_arrow->from_output_index_]; MS_EXCEPTION_IF_NULL(output_partial->func_graph_); - Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial, - IntToSize(partial_arrow->to_input_index_), context); + ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial, + IntToSize(partial_arrow->to_input_index_), context); } } } diff --git a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc index 813d5631e65..d6ad83cbc61 100644 --- a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.cc @@ -20,6 +20,7 @@ #include "runtime/framework/actor/memory_manager_actor.h" #include "runtime/framework/actor/recorder_actor.h" #include "runtime/framework/actor/debug_actor.h" +#include "runtime/framework/actor/control_flow/entrance_actor.h" #include "mindrt/include/async/async.h" #include "utils/log_adapter.h" @@ -70,6 +71,11 @@ void LoopCountActor::SendOutput(OpContext *const context) { ActorDispatcher::Send(output_control, &OpActor::RunOpControl, from_aid, context); } + // Send to EntranceActor to clear the data which are generated in the loop body execution. + for (auto &entrance_aid : entrance_aids_) { + ActorDispatcher::Send(entrance_aid, &EntranceActor::ClearDataOnStepEnd, from_aid, context); + } + // The LoopCountActor exits. if (current_count_ == loop_count_) { current_count_ = 0; diff --git a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h index f695738cd36..52d98f6abc7 100644 --- a/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h +++ b/mindspore/ccsrc/runtime/framework/actor/loop_count_actor.h @@ -52,6 +52,7 @@ class LoopCountActor : public DebugAwareActor { // Get the member. size_t loop_count() const { return loop_count_; } const AID &data_prepare_aid() const { return data_prepare_aid_; } + const std::vector &entrance_aids() const { return entrance_aids_; } protected: void Run(OpContext *const context) override; @@ -59,6 +60,7 @@ class LoopCountActor : public DebugAwareActor { private: friend class GraphScheduler; + friend class ControlNodeScheduler; void IncreaseLoopCount(OpContext *const context); @@ -68,7 +70,9 @@ class LoopCountActor : public DebugAwareActor { // The total running count represents the toal step running count. size_t total_running_count_; + // The actors which need be handled separately by loop count actor. AID data_prepare_aid_; + std::vector entrance_aids_; }; using LoopCountActorPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc index 16f89a89ab3..4b8b4279514 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc @@ -380,12 +380,12 @@ void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &gr // Link data arrows and partial arrows between control actors. LinkArrowForControlActor(actor_set->control_actors_.get(), graph_compiler_info); + // Link arrows from host data source actor or data prepare actor to entrance actor of root graph. + LinkArrowForRootGraphEntranceActor(graph_compiler_info); + // Link output data arrows from control actors to output actor. LinkDataArrowForOutputActor(actor_set, graph_compiler_info); - // Link data arrows from host data source actor to control actors. - LinkDataArrowForHostDSActor(graph_compiler_info); - // Link data arrows from entrance actors to kernel actors. LinkDataArrowForKernelActor(graph_compiler_info); @@ -397,6 +397,19 @@ void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &gr // Link control arrows for no input and no output kernel actor. LinkControlArrowForKernelActor(actor_set, graph_compiler_info); + + LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info); +} + +void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) { + if (control_actor_set == nullptr) { + return; + } + + for (auto &exit_actor : control_actor_set->exit_actors_) { + MS_EXCEPTION_IF_NULL(exit_actor); + exit_actor->created_device_tensors_.clear(); + } } void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const control_actor_set, @@ -657,21 +670,11 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info) { - // Get the exit actor of root graph, In control flow, the final output is always sent by the exit of the root graph. MS_EXCEPTION_IF_NULL(actor_set); auto control_actor_set = actor_set->control_actors_.get(); MS_EXCEPTION_IF_NULL(control_actor_set); const auto &parser = graph_compiler_info.control_node_parser_; MS_EXCEPTION_IF_NULL(parser); - const auto &root_graph = parser->root_func_graph_; - MS_EXCEPTION_IF_NULL(root_graph); - const auto &exit_actor_name = root_graph->ToString() + kExitActorNameSuffix; - auto actor = FetchActor(exit_actor_name); - MS_EXCEPTION_IF_NULL(actor); - MS_EXCEPTION_IF_NULL(actor_set->loop_count_actor_); - auto root_exit_actor = dynamic_cast(actor); - // link control arrow from root exit actor to loop count actor. - LinkControlArrowForExitActor(root_exit_actor, actor_set->loop_count_actor_.get(), kMainBranchID); // Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph // stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the @@ -682,9 +685,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor const auto &func_graph = graph_to_nodes.first; MS_EXCEPTION_IF_NULL(func_graph); auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix; - actor = FetchActor(actor_name); - MS_EXCEPTION_IF_NULL(actor); - auto entrance_actor = dynamic_cast(actor); + auto entrance_actor = dynamic_cast(FetchActor(actor_name)); MS_EXCEPTION_IF_NULL(entrance_actor); const auto &nodes = graph_to_nodes.second; @@ -696,11 +697,9 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor } else { actor_name = GetActorName(node); } - actor = FetchActor(actor_name); - MS_EXCEPTION_IF_NULL(actor); - auto from_actor = dynamic_cast(actor); + auto from_actor = dynamic_cast(FetchActor(actor_name)); MS_EXCEPTION_IF_NULL(from_actor); - LinkControlArrow(from_actor, entrance_actor); + LinkLoopBodyControlArrow(from_actor, entrance_actor); } } @@ -720,9 +719,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor const FuncGraphPtr &func_graph = control_actor->node_->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix; - actor = FetchActor(actor_name); - MS_EXCEPTION_IF_NULL(actor); - const auto &entrance_actor = dynamic_cast(actor); + const auto &entrance_actor = dynamic_cast(FetchActor(actor_name)); MS_EXCEPTION_IF_NULL(entrance_actor); LinkControlArrow(entrance_actor, control_actor); } @@ -757,6 +754,31 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor } } +void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set, + const GraphCompilerInfo &graph_compiler_info) { + MS_EXCEPTION_IF_NULL(actor_set); + auto loop_count_actor = actor_set->loop_count_actor_; + MS_EXCEPTION_IF_NULL(loop_count_actor); + + // The final output is always sent by the exit of the root graph in control flow. + const auto &parser = graph_compiler_info.control_node_parser_; + MS_EXCEPTION_IF_NULL(parser); + const auto &root_graph = parser->root_func_graph_; + MS_EXCEPTION_IF_NULL(root_graph); + auto exit_actor_name = root_graph->ToString() + kExitActorNameSuffix; + auto root_exit_actor = dynamic_cast(FetchActor(exit_actor_name)); + MS_EXCEPTION_IF_NULL(root_exit_actor); + // link control arrow from root exit actor to loop count actor. + LinkControlArrowForExitActor(root_exit_actor, loop_count_actor.get(), kMainBranchID); + + // The entrance actor will generate some data in the loop body execution, so need clear on the end of step. + MS_EXCEPTION_IF_NULL(actor_set->control_actors_); + for (auto &entrance_actor : actor_set->control_actors_->entrance_actors_) { + MS_EXCEPTION_IF_NULL(entrance_actor); + (void)loop_count_actor->entrance_aids_.emplace_back(entrance_actor->GetAID()); + } +} + void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info) { const auto &parser = graph_compiler_info.control_node_parser_; @@ -1024,21 +1046,26 @@ void ControlNodeScheduler::LinkDataArrowForOutputActor(ActorSet *const actor_set actor_set->output_actor_->device_contexts_ = iter->second; } -void ControlNodeScheduler::LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info) { - // In control flow, the host data source actor sends all the input to the entrance actor of the root graph. - const auto &host_ds_actor_name = graph_compiler_info.name_ + "_HostDSActor"; - auto actor = FetchActor(host_ds_actor_name); - MS_EXCEPTION_IF_NULL(actor); - const auto host_ds_actor = dynamic_cast(actor); - MS_EXCEPTION_IF_NULL(host_ds_actor); - +void ControlNodeScheduler::LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info) { + MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_); const auto &root_graph = graph_compiler_info.control_node_parser_->root_func_graph_; MS_EXCEPTION_IF_NULL(root_graph); const auto &entrance_actor_name = root_graph->ToString() + kEntranceActorNameSuffix; - actor = FetchActor(entrance_actor_name); - MS_EXCEPTION_IF_NULL(actor); - auto to_actor = dynamic_cast(actor); + auto to_actor = dynamic_cast(FetchActor(entrance_actor_name)); + MS_EXCEPTION_IF_NULL(to_actor); + const auto &host_ds_actor_name = graph_compiler_info.name_ + "_HostDSActor"; + auto host_ds_actor = dynamic_cast(FetchActor(host_ds_actor_name)); + // No host data source actor scenario. + if (host_ds_actor == nullptr) { + const auto &data_prepare_actor_name = graph_compiler_info.name_ + "_DataPrepareActor"; + auto data_prepare_actor = FetchActor(data_prepare_actor_name); + MS_EXCEPTION_IF_NULL(data_prepare_actor); + LinkControlArrow(data_prepare_actor, to_actor); + return; + } + + // The host data source actor sends all the input to the entrance actor of the root graph. for (size_t i = 0; i < to_actor->formal_parameters_.size(); ++i) { const auto &formal_parameter = to_actor->formal_parameters_[i]; MS_EXCEPTION_IF_NULL(formal_parameter.first); @@ -1076,6 +1103,14 @@ void ControlNodeScheduler::LinkControlArrow(AbstractActor *from_actor, AbstractA (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID()); } +void ControlNodeScheduler::LinkLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor) { + MS_EXCEPTION_IF_NULL(from_actor); + MS_EXCEPTION_IF_NULL(to_actor); + (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID()); + to_actor->loop_body_input_controls_nums_++; + (void)to_actor->loop_body_input_control_arrow_aids_.emplace_back(from_actor->GetAID()); +} + void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor, size_t from_index, size_t to_index, int branch_id) { MS_EXCEPTION_IF_NULL(exit_actor); diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.h b/mindspore/ccsrc/runtime/framework/control_node_scheduler.h index 5b4494bc6f7..00b60e6cdf3 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.h @@ -43,6 +43,9 @@ class ControlNodeScheduler { bool CheckActorValid(const ControlActorSetPtr &control_actor_set); + // The control flow actor will generate some data in the loop body execution, so need clear on the end of execution. + void ClearActorData(const ControlActorSet *control_actor_set); + private: // Interface to create control actors. std::vector BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info); @@ -77,14 +80,16 @@ class ControlNodeScheduler { void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info); void LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, bool is_call_input_graph, ControlActor *const entrance_actor); + void LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info); + void LinkControlArrowForLoopCountActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); - void LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info); void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info); void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node, const ControlNodeParserPtr &parser); // Interface tool to link arrows between actors. void LinkControlArrow(AbstractActor *from_actor, AbstractActor *to_actor); + void LinkLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor); // Data arrow with branch id is only exists from gather actor to entrance actor. void LinkDataWithBranchIDArrow(GatherActor *const gather_actor, EntranceActor *const entrance_actor, const FuncGraphPtr &func_graph); diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc index b280b6caa62..abf9ab5bc1c 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.cc @@ -215,6 +215,11 @@ void GraphScheduler::Clear() { ClearAllActors(); } +void GraphScheduler::ClearActorData(const ActorSet *actor_set) { + MS_EXCEPTION_IF_NULL(actor_set); + control_node_scheduler_.ClearActorData(actor_set->control_actors_.get()); +} + using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &, const KernelWithIndex &, const KernelGraphPtr &); static std::map kKernelTypeToLinkFunc; diff --git a/mindspore/ccsrc/runtime/framework/graph_scheduler.h b/mindspore/ccsrc/runtime/framework/graph_scheduler.h index 5a37f998175..4757f9ca7b1 100644 --- a/mindspore/ccsrc/runtime/framework/graph_scheduler.h +++ b/mindspore/ccsrc/runtime/framework/graph_scheduler.h @@ -57,6 +57,8 @@ class GraphScheduler { // Clear the members. void Clear(); void Clear(const ActorInfo &actor_info, const std::vector &graphs) noexcept; + // The control flow actors will generate some data in the loop body execution, so need clear on the end of execution. + void ClearActorData(const ActorSet *actor_set); // Transform graph to actor DAG, contains build and link. ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info); diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index ba3b8729bd3..524f14f6376 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -882,6 +882,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, size_t output_position = 0; ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs); } + runtime::GraphScheduler::GetInstance().ClearActorData(actor_set); MS_LOG(INFO) << "Status record: end run actor: " << actor_info; }