Add validity checks for actor inputs.

This commit is contained in:
gaoyong10 2022-02-25 11:41:03 +08:00
parent c3b451f3c9
commit 38b01471a2
6 changed files with 106 additions and 15 deletions

View File

@ -55,7 +55,11 @@ bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context
if (data_iter == input_op_datas_.end()) {
return false;
}
if (data_iter->second.size() != input_datas_num_) {
if (data_iter->second.size() < input_datas_num_) {
return false;
} else if (data_iter->second.size() > input_datas_num_) {
MS_LOG(ERROR) << "Invalid input data num:" << data_iter->second.size() << " need:" << input_datas_num_
<< " for actor:" << GetAID();
return false;
}
}
@ -65,7 +69,11 @@ bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context
if (control_iter == input_op_controls_.end()) {
return false;
}
if (control_iter->second.size() != input_controls_num_) {
if (control_iter->second.size() < input_controls_num_) {
return false;
} else if (control_iter->second.size() > input_controls_num_) {
MS_LOG(ERROR) << "Invalid input control num:" << control_iter->second.size() << " need:" << input_controls_num_
<< " for actor:" << GetAID();
return false;
}
}

View File

@ -377,6 +377,9 @@ void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
// Dumps one stack actor to `ofs`: its name, the three stack input counts exposed by the actor
// (data, partial, control), and then whatever DumpControlActor writes for the same actor.
// Raises through MS_EXCEPTION_IF_NULL when `actor` is null.
void DumpStackActor(const StackActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  // Name line first, then one line per stack input counter.
  ofs << "\tactor_name:" << actor->GetAID().Name() << '\n'
      << "\t\tinput stack data num:" << actor->input_stack_data_num() << '\n'
      << "\t\tinput stack partial num:" << actor->input_stack_partials_num() << '\n'
      << "\t\tinput stack control num:" << actor->input_stack_controls_num() << '\n';
  // The remainder of the dump is delegated to DumpControlActor.
  DumpControlActor(actor, ofs);
}

View File

@ -161,7 +161,11 @@ bool ControlActor::CheckRunningCondition(const OpContext<DeviceTensor> *context)
if (partial_iter == input_op_partials_.end()) {
return false;
}
if (partial_iter->second.size() != input_partials_num_) {
if (partial_iter->second.size() < input_partials_num_) {
return false;
} else if (partial_iter->second.size() > input_partials_num_) {
MS_LOG(ERROR) << "Invalid input partial num:" << partial_iter->second.size() << " need:" << input_partials_num_
<< " for actor:" << GetAID();
return false;
}
}

View File

@ -138,6 +138,15 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
return false;
}
if (CheckStackDataRunningCondition(context) && CheckStackPartialRunningCondition(context) &&
CheckStackControlRunningCondition(context)) {
return true;
}
return false;
}
bool StackActor::CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
auto iter = input_branch_ids_.find(context->sequential_num_);
bool is_branch_id_available = (iter == input_branch_ids_.end() || iter->second.empty());
@ -146,7 +155,11 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
if (data_iter == input_stack_data_.end()) {
return false;
}
if (data_iter->second.size() != input_stack_data_num_) {
if (data_iter->second.size() < input_stack_data_num_) {
return false;
} else if (data_iter->second.size() > input_stack_data_num_) {
MS_LOG(ERROR) << "Invalid input stack data num:" << data_iter->second.size() << " need:" << input_stack_data_num_
<< " for actor:" << GetAID();
return false;
}
@ -155,18 +168,35 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
return false;
}
size_t branch_id_size = iter->second.size();
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
return false;
for (const auto &one_stack : data_iter->second) {
if (one_stack.second.size() < branch_id_size) {
return false;
} else if (one_stack.second.size() > branch_id_size) {
MS_LOG(ERROR) << "Invalid input stack data num:" << one_stack.second.size()
<< " for input index:" << one_stack.first << " need:" << branch_id_size
<< " for actor:" << GetAID();
return false;
}
}
}
return true;
}
bool StackActor::CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
auto iter = input_branch_ids_.find(context->sequential_num_);
bool is_branch_id_available = (iter == input_branch_ids_.end() || iter->second.empty());
if (input_stack_partials_num_ != 0) {
const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
if (partial_iter == input_stack_partials_.end()) {
return false;
}
if (partial_iter->second.size() != input_stack_partials_num_) {
if (partial_iter->second.size() < input_stack_partials_num_) {
return false;
} else if (partial_iter->second.size() > input_stack_partials_num_) {
MS_LOG(ERROR) << "Invalid input stack partial num:" << partial_iter->second.size()
<< " need:" << input_stack_partials_num_ << " for actor:" << GetAID();
return false;
}
@ -175,18 +205,35 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
return false;
}
size_t branch_id_size = iter->second.size();
if (std::any_of(partial_iter->second.begin(), partial_iter->second.end(),
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
return false;
for (const auto &one_stack : partial_iter->second) {
if (one_stack.second.size() < branch_id_size) {
return false;
} else if (one_stack.second.size() > branch_id_size) {
MS_LOG(ERROR) << "Invalid input stack partial num:" << one_stack.second.size()
<< " for input index:" << one_stack.first << " need:" << branch_id_size
<< " for actor:" << GetAID();
return false;
}
}
}
return true;
}
bool StackActor::CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const {
MS_EXCEPTION_IF_NULL(context);
auto iter = input_branch_ids_.find(context->sequential_num_);
bool is_branch_id_available = (iter == input_branch_ids_.end() || iter->second.empty());
if (input_stack_controls_num_ != 0) {
const auto &control_iter = input_stack_controls_.find(context->sequential_num_);
if (control_iter == input_stack_controls_.end()) {
return false;
}
if (control_iter->second.size() != input_stack_controls_num_) {
if (control_iter->second.size() < input_stack_controls_num_) {
return false;
} else if (control_iter->second.size() > input_stack_controls_num_) {
MS_LOG(ERROR) << "Invalid input stack control num:" << control_iter->second.size()
<< " need:" << input_stack_controls_num_ << " for actor:" << GetAID();
return false;
}
@ -195,9 +242,15 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
return false;
}
size_t branch_id_size = iter->second.size();
if (std::any_of(control_iter->second.begin(), control_iter->second.end(),
[branch_id_size](const auto &one_stack) { return one_stack.second != branch_id_size; })) {
return false;
for (const auto &one_stack : control_iter->second) {
if (one_stack.second < branch_id_size) {
return false;
} else if (one_stack.second > branch_id_size) {
MS_LOG(ERROR) << "Invalid input stack control num:" << one_stack.second
<< " for input actor:" << one_stack.first->Name() << " need:" << branch_id_size
<< " for actor:" << GetAID();
return false;
}
}
}
return true;

View File

@ -45,6 +45,9 @@ class StackActor : public ControlActor {
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
void RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) override;
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
// Read-only accessor for input_stack_data_num_, the count that the number of received stack data
// entries is compared against when the running condition is checked.
size_t input_stack_data_num() const {
  return input_stack_data_num_;
}
// Read-only accessor for input_stack_partials_num_, the count that the number of received stack
// partial entries is compared against when the running condition is checked.
size_t input_stack_partials_num() const {
  return input_stack_partials_num_;
}
// Read-only accessor for input_stack_controls_num_, the count that the number of received stack
// control entries is compared against when the running condition is checked.
size_t input_stack_controls_num() const {
  return input_stack_controls_num_;
}
protected:
void Init() override;
@ -56,6 +59,11 @@ class StackActor : public ControlActor {
private:
friend class ControlNodeScheduler;
// Check running condition functions.
// Each helper looks up one category of stack input by context->sequential_num_ and returns false when the
// entries received so far do not match the expected count for that category. Receiving more entries than
// expected is additionally reported through MS_LOG(ERROR).
// Stack data inputs, checked against input_stack_data_num_.
bool CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const;
// Stack partial inputs, checked against input_stack_partials_num_.
bool CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const;
// Stack control inputs, checked against input_stack_controls_num_.
bool CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const;
// Records of the input data and partials that are copied from the input nodes of the stack actor and need to be
// stored as device tensors in the stack.
mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_stack_data_;

View File

@ -211,6 +211,14 @@ bool IsValidPartialCNode(const AnfNodePtr &node) {
}
return true;
}
// Returns true when the exit actor has no output arrow of any kind, i.e. every one of its data, partial,
// control, branch-control, branch-data and branch-partial arrow collections is empty.
// Raises through MS_EXCEPTION_IF_NULL when `exit_actor` is null.
bool CheckExitActorInvalid(const ExitActorPtr &exit_actor) {
  MS_EXCEPTION_IF_NULL(exit_actor);
  // Any plain (non-branch) output arrow is enough to make the actor valid.
  if (!exit_actor->output_data_arrows().empty() || !exit_actor->output_partial_arrows().empty() ||
      !exit_actor->output_control_arrows().empty()) {
    return false;
  }
  // Otherwise the actor is invalid only if the branch outputs are all empty as well.
  const bool has_branch_output = !exit_actor->output_branch_control_arrows().empty() ||
                                 !exit_actor->output_branch_data_arrows().empty() ||
                                 !exit_actor->output_branch_partial_arrows().empty();
  return !has_branch_output;
}
} // namespace
std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
@ -1564,6 +1572,13 @@ bool ControlNodeScheduler::CheckActorValid(const ActorSet *actor_set) const {
}
}
}
for (const auto &exit_actor : actor_set->control_actors_->exit_actors_) {
MS_EXCEPTION_IF_NULL(exit_actor);
if (CheckExitActorInvalid(exit_actor)) {
MS_LOG(EXCEPTION) << "Invalid exit actor:" << exit_actor->GetAID();
}
}
return true;
}