Unified runtime: fix the execution timeout and the no-data-source-actor scenario in control flow
This commit is contained in:
parent
db5ef1136f
commit
312b26080b
|
@ -255,6 +255,13 @@ void DumpEntranceActor(const EntranceActor *actor, std::ofstream &ofs) {
|
|||
MS_EXCEPTION_IF_NULL(actor);
|
||||
ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
|
||||
DumpControlActor(actor, ofs);
|
||||
|
||||
if (actor->loop_body_input_control_arrow_aids().size() > 0) {
|
||||
ofs << "\t\tinput_loop_body_control_arrow_actors:" << actor->loop_body_input_control_arrow_aids().size() << "\n ";
|
||||
for (const auto &loop_body_input_control_arrow_aid : actor->loop_body_input_control_arrow_aids()) {
|
||||
ofs << "\t\t\tfrom_actor_name:" << loop_body_input_control_arrow_aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
|
||||
|
@ -376,6 +383,9 @@ void DumpLoopCountActor(const LoopCountActorPtr &actor, std::ofstream &ofs) {
|
|||
DumpAbstractActor(actor.get(), ofs);
|
||||
|
||||
ofs << "\t\t\tto_data_prepare_actor:" << actor->data_prepare_aid().Name() << "\n";
|
||||
for (auto &entrance_aid : actor->entrance_aids()) {
|
||||
ofs << "\t\t\tto_entrance_actor:" << entrance_aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
void DumpOutputActor(const OutputActorPtr &actor, std::ofstream &ofs) {
|
||||
|
|
|
@ -21,23 +21,54 @@ namespace mindspore {
|
|||
namespace runtime {
|
||||
constexpr size_t kEntranceInputStartPos = 1;
|
||||
|
||||
// Receive one input control arrow for the step identified by context->sequential_num_.
// The control is stored in loop_body_input_op_controls_ when the actor is in the loop body execution and in
// input_op_controls_ otherwise. The actor is run as soon as CheckRunningCondition() reports true.
void EntranceActor::RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  const auto &sequential_num = context->sequential_num_;

  // The two kinds of execution collect their controls in separate containers.
  if (!is_loop_body_execution_) {
    (void)input_op_controls_[sequential_num].emplace_back(input_control);
  } else {
    (void)loop_body_input_op_controls_[sequential_num].emplace_back(input_control);
  }

  const bool can_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op control and check running condition:" << can_run
                << ", loop body execution:" << is_loop_body_execution_;
  if (!can_run) {
    return;
  }
  Run(context);
}
|
||||
|
||||
// Receive one set of real parameters together with its branch id for the step identified by
// context->sequential_num_, queue it, and run the actor if CheckRunningCondition() reports true.
//
// real_parameter_with_branch_id: the input to be queued for this step.
// context: the op context of the current step, must not be null.
void EntranceActor::RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id,
                                                   OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  auto &sequential_num = context->sequential_num_;
  real_parameters_with_branch_id_[sequential_num].emplace(real_parameter_with_branch_id);

  // Evaluate the running condition exactly once, so that the logged value is the one that decides whether to run.
  auto is_run = CheckRunningCondition(context);
  MS_LOG(DEBUG) << "Actor(" << GetAID().Name()
                << ") receive the input op data with branch id and check running condition:" << is_run
                << ", loop body execution:" << is_loop_body_execution_;
  if (is_run) {
    Run(context);
  }
}
|
||||
|
||||
// Reset the loop body state of the entrance actor: the loop body execution flag is set back to false and the
// collected loop body input controls are dropped. The input_control argument is not read.
void EntranceActor::ClearDataOnStepEnd(AID *const input_control, OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);

  // Back to the begin-execution state; Run() switches the flag to true again.
  is_loop_body_execution_ = false;

  // Nothing has been collected when the actor has no loop body input control.
  if (loop_body_input_controls_nums_ == 0) {
    return;
  }
  loop_body_input_op_controls_.clear();
}
|
||||
|
||||
// Execute the entrance actor once: fetch the inputs, erase them, send the outputs and then update the state flags.
void EntranceActor::Run(OpContext<DeviceTensor> *const context) {
  FetchInput(context);
  EraseInput(context);
  SendOutput(context);

  // Only the begin execution of step sees false, every later execution of the step is a loop body execution.
  is_loop_body_execution_ = true;
  // Once the actor has run it stays disabled until no actor is running in the entire funcgraph.
  is_actor_ready_ = false;
}
|
||||
|
||||
void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
|
||||
|
@ -104,33 +135,34 @@ void EntranceActor::FetchInput(OpContext<DeviceTensor> *const context) {
|
|||
}
|
||||
}
|
||||
|
||||
bool EntranceActor::CheckActorStatus(const OpContext<DeviceTensor> *const context) const {
|
||||
if (is_actor_ready_) {
|
||||
return true;
|
||||
}
|
||||
// During operation, entrance actor can be enabled only when receives all control arrows.
|
||||
if (input_controls_num_ != 0) {
|
||||
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
|
||||
if (control_iter != input_op_controls_.end() && control_iter->second.size() == input_controls_num_) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool EntranceActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
|
||||
// When the entrance actor is in the disabled state, it cannot be run.
|
||||
if (!CheckActorStatus(context)) {
|
||||
return false;
|
||||
// Check the running condition in the begin execution of step.
|
||||
// The input controls and input data exist the begin execution of root graph, and there will only be one of the two.
|
||||
if (!is_loop_body_execution_) {
|
||||
if (input_controls_num_ != 0) {
|
||||
const auto &control_iter = input_op_controls_.find(context->sequential_num_);
|
||||
if ((control_iter != input_op_controls_.end()) && (control_iter->second.size() == input_controls_num_)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Data comes from the data source actor.
|
||||
if (input_datas_num_ != 0) {
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Data comes from the data source actor.
|
||||
if (input_datas_num_ != 0) {
|
||||
const auto &data_iter = input_op_datas_.find(context->sequential_num_);
|
||||
if (data_iter != input_op_datas_.end() && data_iter->second.size() == input_datas_num_) {
|
||||
return true;
|
||||
// Check the controls in the loop body execution of step.
|
||||
if (is_loop_body_execution_ && (loop_body_input_controls_nums_ != 0)) {
|
||||
const auto &control_iter = loop_body_input_op_controls_.find(context->sequential_num_);
|
||||
if ((control_iter == loop_body_input_op_controls_.end()) ||
|
||||
(control_iter->second.size() != loop_body_input_controls_nums_)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,7 +181,6 @@ void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
|
|||
const auto &data_iter = input_op_datas_.find(sequential_num);
|
||||
if (data_iter != input_op_datas_.end()) {
|
||||
input_op_datas_.erase(data_iter);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto &control_iter = input_op_controls_.find(sequential_num);
|
||||
|
@ -157,14 +188,17 @@ void EntranceActor::EraseInput(const OpContext<DeviceTensor> *const context) {
|
|||
input_op_controls_.erase(control_iter);
|
||||
}
|
||||
|
||||
const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
|
||||
if (iter == real_parameters_with_branch_id_.end() || iter->second.empty()) {
|
||||
MS_LOG(ERROR) << "Cannot find input in batch op result for actor:" << GetAID();
|
||||
const auto &loop_body_control_iter = loop_body_input_op_controls_.find(sequential_num);
|
||||
if (loop_body_control_iter != loop_body_input_op_controls_.end()) {
|
||||
loop_body_input_op_controls_.erase(loop_body_control_iter);
|
||||
}
|
||||
|
||||
iter->second.pop();
|
||||
if (iter->second.empty()) {
|
||||
real_parameters_with_branch_id_.erase(sequential_num);
|
||||
const auto &iter = real_parameters_with_branch_id_.find(sequential_num);
|
||||
if (iter != real_parameters_with_branch_id_.end()) {
|
||||
iter->second.pop();
|
||||
if (iter->second.empty()) {
|
||||
real_parameters_with_branch_id_.erase(sequential_num);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace runtime
|
||||
|
|
|
@ -40,9 +40,17 @@ class EntranceActor : public ControlActor {
|
|||
input_device_tensors_.resize(parameters.size());
|
||||
}
|
||||
~EntranceActor() override = default;
|
||||
|
||||
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
void RunOpRealParameterWithBranchID(OpRealParameterWithBranchID real_parameter_with_branch_id,
|
||||
OpContext<DeviceTensor> *const context);
|
||||
|
||||
// Clear the data which are generated in the loop body execution.
|
||||
void ClearDataOnStepEnd(AID *const input_control, OpContext<DeviceTensor> *const context);
|
||||
|
||||
const std::vector<AID> &loop_body_input_control_arrow_aids() const { return loop_body_input_control_arrow_aids_; }
|
||||
|
||||
protected:
|
||||
void Run(OpContext<DeviceTensor> *const context) override;
|
||||
void FetchInput(OpContext<DeviceTensor> *const context) override;
|
||||
|
@ -52,13 +60,14 @@ class EntranceActor : public ControlActor {
|
|||
private:
|
||||
friend class ControlNodeScheduler;
|
||||
|
||||
// Check if actor is enable. During operation, entrance actor can be enabled only when receives all control arrows.
|
||||
bool CheckActorStatus(const OpContext<DeviceTensor> *const context) const;
|
||||
|
||||
// Is actor ready indicates whether the entrance actor can be executed. In the control flow, the subgraph is an
|
||||
// atomic operation, and execution can only continue after the output of the corresponding exit actor is completed.
|
||||
// At this time, the exit actor will notify the entrance actor to change the ready to true.
|
||||
bool is_actor_ready_{true};
|
||||
// Indicate whether the entrance actor is the execution of loop body. In the control flow, the subgraph can be
|
||||
// triggered to execute in two ways: one is the begin execution of step, another is the execution of loop body.
|
||||
// The input controls are different in the two ways.
|
||||
bool is_loop_body_execution_{false};
|
||||
// The dependent of loop body input actors.
|
||||
mindspore::HashMap<int, std::vector<AID *>> loop_body_input_op_controls_;
|
||||
std::vector<AID> loop_body_input_control_arrow_aids_;
|
||||
size_t loop_body_input_controls_nums_{0};
|
||||
|
||||
// Input data with branch id.
|
||||
mindspore::HashMap<int, std::queue<OpRealParameterWithBranchID>> real_parameters_with_branch_id_;
|
||||
|
|
|
@ -84,8 +84,8 @@ void ExitActor::SendOutput(OpContext<DeviceTensor> *const context) {
|
|||
}
|
||||
auto output_partial = input_partials_[partial_arrow->from_output_index_];
|
||||
MS_EXCEPTION_IF_NULL(output_partial->func_graph_);
|
||||
Async(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial,
|
||||
IntToSize(partial_arrow->to_input_index_), context);
|
||||
ActorDispatcher::Send(partial_arrow->to_op_id_, &ControlActor::RunOpPartial, output_partial,
|
||||
IntToSize(partial_arrow->to_input_index_), context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "runtime/framework/actor/memory_manager_actor.h"
|
||||
#include "runtime/framework/actor/recorder_actor.h"
|
||||
#include "runtime/framework/actor/debug_actor.h"
|
||||
#include "runtime/framework/actor/control_flow/entrance_actor.h"
|
||||
#include "mindrt/include/async/async.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
|
@ -70,6 +71,11 @@ void LoopCountActor::SendOutput(OpContext<DeviceTensor> *const context) {
|
|||
ActorDispatcher::Send(output_control, &OpActor::RunOpControl, from_aid, context);
|
||||
}
|
||||
|
||||
// Send to EntranceActor to clear the data which are generated in the loop body execution.
|
||||
for (auto &entrance_aid : entrance_aids_) {
|
||||
ActorDispatcher::Send(entrance_aid, &EntranceActor::ClearDataOnStepEnd, from_aid, context);
|
||||
}
|
||||
|
||||
// The LoopCountActor exits.
|
||||
if (current_count_ == loop_count_) {
|
||||
current_count_ = 0;
|
||||
|
|
|
@ -52,6 +52,7 @@ class LoopCountActor : public DebugAwareActor {
|
|||
// Get the member.
|
||||
size_t loop_count() const { return loop_count_; }
|
||||
const AID &data_prepare_aid() const { return data_prepare_aid_; }
|
||||
const std::vector<AID> &entrance_aids() const { return entrance_aids_; }
|
||||
|
||||
protected:
|
||||
void Run(OpContext<DeviceTensor> *const context) override;
|
||||
|
@ -59,6 +60,7 @@ class LoopCountActor : public DebugAwareActor {
|
|||
|
||||
private:
|
||||
friend class GraphScheduler;
|
||||
friend class ControlNodeScheduler;
|
||||
|
||||
void IncreaseLoopCount(OpContext<DeviceTensor> *const context);
|
||||
|
||||
|
@ -68,7 +70,9 @@ class LoopCountActor : public DebugAwareActor {
|
|||
// The total running count represents the total step running count.
|
||||
size_t total_running_count_;
|
||||
|
||||
// The actors which need be handled separately by loop count actor.
|
||||
AID data_prepare_aid_;
|
||||
std::vector<AID> entrance_aids_;
|
||||
};
|
||||
|
||||
using LoopCountActorPtr = std::shared_ptr<LoopCountActor>;
|
||||
|
|
|
@ -380,12 +380,12 @@ void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &gr
|
|||
// Link data arrows and partial arrows between control actors.
|
||||
LinkArrowForControlActor(actor_set->control_actors_.get(), graph_compiler_info);
|
||||
|
||||
// Link arrows from host data source actor or data prepare actor to entrance actor of root graph.
|
||||
LinkArrowForRootGraphEntranceActor(graph_compiler_info);
|
||||
|
||||
// Link output data arrows from control actors to output actor.
|
||||
LinkDataArrowForOutputActor(actor_set, graph_compiler_info);
|
||||
|
||||
// Link data arrows from host data source actor to control actors.
|
||||
LinkDataArrowForHostDSActor(graph_compiler_info);
|
||||
|
||||
// Link data arrows from entrance actors to kernel actors.
|
||||
LinkDataArrowForKernelActor(graph_compiler_info);
|
||||
|
||||
|
@ -397,6 +397,19 @@ void ControlNodeScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &gr
|
|||
|
||||
// Link control arrows for no input and no output kernel actor.
|
||||
LinkControlArrowForKernelActor(actor_set, graph_compiler_info);
|
||||
|
||||
LinkControlArrowForLoopCountActor(actor_set, graph_compiler_info);
|
||||
}
|
||||
|
||||
void ControlNodeScheduler::ClearActorData(const ControlActorSet *control_actor_set) {
|
||||
if (control_actor_set == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto &exit_actor : control_actor_set->exit_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(exit_actor);
|
||||
exit_actor->created_device_tensors_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void ControlNodeScheduler::LinkArrowForControlActor(ControlActorSet *const control_actor_set,
|
||||
|
@ -657,21 +670,11 @@ void ControlNodeScheduler::LinkArrowByKernel(const AnfNodePtr &kernel, ControlAc
|
|||
|
||||
void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor_set,
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
// Get the exit actor of root graph, In control flow, the final output is always sent by the exit of the root graph.
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
auto control_actor_set = actor_set->control_actors_.get();
|
||||
MS_EXCEPTION_IF_NULL(control_actor_set);
|
||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||
MS_EXCEPTION_IF_NULL(parser);
|
||||
const auto &root_graph = parser->root_func_graph_;
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
const auto &exit_actor_name = root_graph->ToString() + kExitActorNameSuffix;
|
||||
auto actor = FetchActor(exit_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
MS_EXCEPTION_IF_NULL(actor_set->loop_count_actor_);
|
||||
auto root_exit_actor = dynamic_cast<ExitActor *>(actor);
|
||||
// link control arrow from root exit actor to loop count actor.
|
||||
LinkControlArrowForExitActor(root_exit_actor, actor_set->loop_count_actor_.get(), kMainBranchID);
|
||||
|
||||
// Since only one set of real parameters are allowed to be executed in funcgraph at the same time, when the funcgraph
|
||||
// stops running, it is necessary to send the control arrow to the corresponding entrance actor at the exit of the
|
||||
|
@ -682,9 +685,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
|||
const auto &func_graph = graph_to_nodes.first;
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
|
||||
actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto entrance_actor = dynamic_cast<EntranceActor *>(actor);
|
||||
auto entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
|
||||
MS_EXCEPTION_IF_NULL(entrance_actor);
|
||||
|
||||
const auto &nodes = graph_to_nodes.second;
|
||||
|
@ -696,11 +697,9 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
|||
} else {
|
||||
actor_name = GetActorName(node);
|
||||
}
|
||||
actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto from_actor = dynamic_cast<ControlActor *>(actor);
|
||||
auto from_actor = dynamic_cast<ControlActor *>(FetchActor(actor_name));
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
LinkControlArrow(from_actor, entrance_actor);
|
||||
LinkLoopBodyControlArrow(from_actor, entrance_actor);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -720,9 +719,7 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
|||
const FuncGraphPtr &func_graph = control_actor->node_->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
const auto &actor_name = func_graph->ToString() + kEntranceActorNameSuffix;
|
||||
actor = FetchActor(actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
const auto &entrance_actor = dynamic_cast<EntranceActor *>(actor);
|
||||
const auto &entrance_actor = dynamic_cast<EntranceActor *>(FetchActor(actor_name));
|
||||
MS_EXCEPTION_IF_NULL(entrance_actor);
|
||||
LinkControlArrow(entrance_actor, control_actor);
|
||||
}
|
||||
|
@ -757,6 +754,31 @@ void ControlNodeScheduler::LinkControlArrowForControlActor(ActorSet *const actor
|
|||
}
|
||||
}
|
||||
|
||||
void ControlNodeScheduler::LinkControlArrowForLoopCountActor(const ActorSet *actor_set,
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
MS_EXCEPTION_IF_NULL(actor_set);
|
||||
auto loop_count_actor = actor_set->loop_count_actor_;
|
||||
MS_EXCEPTION_IF_NULL(loop_count_actor);
|
||||
|
||||
// The final output is always sent by the exit of the root graph in control flow.
|
||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||
MS_EXCEPTION_IF_NULL(parser);
|
||||
const auto &root_graph = parser->root_func_graph_;
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
auto exit_actor_name = root_graph->ToString() + kExitActorNameSuffix;
|
||||
auto root_exit_actor = dynamic_cast<ExitActor *>(FetchActor(exit_actor_name));
|
||||
MS_EXCEPTION_IF_NULL(root_exit_actor);
|
||||
// link control arrow from root exit actor to loop count actor.
|
||||
LinkControlArrowForExitActor(root_exit_actor, loop_count_actor.get(), kMainBranchID);
|
||||
|
||||
// The entrance actor will generate some data in the loop body execution, so need clear on the end of step.
|
||||
MS_EXCEPTION_IF_NULL(actor_set->control_actors_);
|
||||
for (auto &entrance_actor : actor_set->control_actors_->entrance_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(entrance_actor);
|
||||
(void)loop_count_actor->entrance_aids_.emplace_back(entrance_actor->GetAID());
|
||||
}
|
||||
}
|
||||
|
||||
void ControlNodeScheduler::LinkControlArrowForKernelActor(ActorSet *const actor_set,
|
||||
const GraphCompilerInfo &graph_compiler_info) {
|
||||
const auto &parser = graph_compiler_info.control_node_parser_;
|
||||
|
@ -1024,21 +1046,26 @@ void ControlNodeScheduler::LinkDataArrowForOutputActor(ActorSet *const actor_set
|
|||
actor_set->output_actor_->device_contexts_ = iter->second;
|
||||
}
|
||||
|
||||
void ControlNodeScheduler::LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
// In control flow, the host data source actor sends all the input to the entrance actor of the root graph.
|
||||
const auto &host_ds_actor_name = graph_compiler_info.name_ + "_HostDSActor";
|
||||
auto actor = FetchActor(host_ds_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
const auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
|
||||
MS_EXCEPTION_IF_NULL(host_ds_actor);
|
||||
|
||||
void ControlNodeScheduler::LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
|
||||
const auto &root_graph = graph_compiler_info.control_node_parser_->root_func_graph_;
|
||||
MS_EXCEPTION_IF_NULL(root_graph);
|
||||
const auto &entrance_actor_name = root_graph->ToString() + kEntranceActorNameSuffix;
|
||||
actor = FetchActor(entrance_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(actor);
|
||||
auto to_actor = dynamic_cast<EntranceActor *>(actor);
|
||||
auto to_actor = dynamic_cast<EntranceActor *>(FetchActor(entrance_actor_name));
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
|
||||
const auto &host_ds_actor_name = graph_compiler_info.name_ + "_HostDSActor";
|
||||
auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(host_ds_actor_name));
|
||||
// No host data source actor scenario.
|
||||
if (host_ds_actor == nullptr) {
|
||||
const auto &data_prepare_actor_name = graph_compiler_info.name_ + "_DataPrepareActor";
|
||||
auto data_prepare_actor = FetchActor(data_prepare_actor_name);
|
||||
MS_EXCEPTION_IF_NULL(data_prepare_actor);
|
||||
LinkControlArrow(data_prepare_actor, to_actor);
|
||||
return;
|
||||
}
|
||||
|
||||
// The host data source actor sends all the input to the entrance actor of the root graph.
|
||||
for (size_t i = 0; i < to_actor->formal_parameters_.size(); ++i) {
|
||||
const auto &formal_parameter = to_actor->formal_parameters_[i];
|
||||
MS_EXCEPTION_IF_NULL(formal_parameter.first);
|
||||
|
@ -1076,6 +1103,14 @@ void ControlNodeScheduler::LinkControlArrow(AbstractActor *from_actor, AbstractA
|
|||
(void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
|
||||
}
|
||||
|
||||
// Link a loop body control arrow from from_actor to the entrance actor to_actor: the arrow is appended to the
// output control arrows of from_actor and registered as a loop body input control of to_actor.
void ControlNodeScheduler::LinkLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);

  // Sender side.
  (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());

  // Receiver side: keep the AID of the sender and count the expected loop body controls.
  (void)to_actor->loop_body_input_control_arrow_aids_.emplace_back(from_actor->GetAID());
  ++to_actor->loop_body_input_controls_nums_;
}
|
||||
|
||||
void ControlNodeScheduler::LinkDataArrowForExitActor(ExitActor *const exit_actor, AbstractActor *const to_actor,
|
||||
size_t from_index, size_t to_index, int branch_id) {
|
||||
MS_EXCEPTION_IF_NULL(exit_actor);
|
||||
|
|
|
@ -43,6 +43,9 @@ class ControlNodeScheduler {
|
|||
|
||||
bool CheckActorValid(const ControlActorSetPtr &control_actor_set);
|
||||
|
||||
// The control flow actors generate some data in the loop body execution, which needs to be cleared at the end of execution.
|
||||
void ClearActorData(const ControlActorSet *control_actor_set);
|
||||
|
||||
private:
|
||||
// Interface to create control actors.
|
||||
std::vector<SwitchActorPtr> BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
|
@ -77,14 +80,16 @@ class ControlNodeScheduler {
|
|||
void LinkDataArrowForKernelActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkDataArrowByKernelGraph(const KernelGraphPtr &graph, bool is_call_input_graph,
|
||||
ControlActor *const entrance_actor);
|
||||
void LinkArrowForRootGraphEntranceActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkControlArrowForLoopCountActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkDataArrowForOutputActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkDataArrowForHostDSActor(const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkControlArrowForKernelActor(ActorSet *const actor_set, const GraphCompilerInfo &graph_compiler_info);
|
||||
void LinkControlArrowByAutoMonad(ControlActor *to_actor, const AnfNodePtr &from_node,
|
||||
const ControlNodeParserPtr &parser);
|
||||
|
||||
// Interface tool to link arrows between actors.
|
||||
void LinkControlArrow(AbstractActor *from_actor, AbstractActor *to_actor);
|
||||
void LinkLoopBodyControlArrow(AbstractActor *from_actor, EntranceActor *to_actor);
|
||||
// Data arrow with branch id only exists from gather actor to entrance actor.
|
||||
void LinkDataWithBranchIDArrow(GatherActor *const gather_actor, EntranceActor *const entrance_actor,
|
||||
const FuncGraphPtr &func_graph);
|
||||
|
|
|
@ -215,6 +215,11 @@ void GraphScheduler::Clear() {
|
|||
ClearAllActors();
|
||||
}
|
||||
|
||||
// Clear the data held by the control actors of the given actor set by forwarding the control actor set to the
// control node scheduler.
void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
  MS_EXCEPTION_IF_NULL(actor_set);
  const auto *control_actor_set = actor_set->control_actors_.get();
  control_node_scheduler_.ClearActorData(control_actor_set);
}
|
||||
|
||||
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &,
|
||||
const KernelWithIndex &, const KernelGraphPtr &);
|
||||
static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;
|
||||
|
|
|
@ -57,6 +57,8 @@ class GraphScheduler {
|
|||
// Clear the members.
|
||||
void Clear();
|
||||
void Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept;
|
||||
// The control flow actors generate some data in the loop body execution, which needs to be cleared at the end of execution.
|
||||
void ClearActorData(const ActorSet *actor_set);
|
||||
|
||||
// Transform graph to actor DAG, contains build and link.
|
||||
ActorSet *Transform(const GraphCompilerInfo &graph_compiler_info);
|
||||
|
|
|
@ -882,6 +882,7 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args,
|
|||
size_t output_position = 0;
|
||||
ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
|
||||
}
|
||||
runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
|
||||
MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue