forked from mindspore-Ecosystem/mindspore
Add validity check for actor input.
This commit is contained in:
parent
c3b451f3c9
commit
38b01471a2
|
@ -55,7 +55,11 @@ bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context
|
|||
if (data_iter == input_op_datas_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_datas_num_) {
|
||||
if (data_iter->second.size() < input_datas_num_) {
|
||||
return false;
|
||||
} else if (data_iter->second.size() > input_datas_num_) {
|
||||
MS_LOG(ERROR) << "Invalid input data num:" << data_iter->second.size() << " need:" << input_datas_num_
|
||||
<< " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -65,7 +69,11 @@ bool AbstractActor::CheckRunningCondition(const OpContext<DeviceTensor> *context
|
|||
if (control_iter == input_op_controls_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (control_iter->second.size() != input_controls_num_) {
|
||||
if (control_iter->second.size() < input_controls_num_) {
|
||||
return false;
|
||||
} else if (control_iter->second.size() > input_controls_num_) {
|
||||
MS_LOG(ERROR) << "Invalid input control num:" << control_iter->second.size() << " need:" << input_controls_num_
|
||||
<< " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -377,6 +377,9 @@ void DumpExitActor(const ExitActor *actor, std::ofstream &ofs) {
|
|||
// Writes the description of a stack actor to the dump stream: the actor name, then the number of input stack
// data, input stack partials and input stack controls, and finally the information emitted by DumpControlActor.
// The actor pointer is checked with MS_EXCEPTION_IF_NULL before it is dereferenced.
void DumpStackActor(const StackActor *actor, std::ofstream &ofs) {
  MS_EXCEPTION_IF_NULL(actor);
  // One chained insertion; the fields are written in the same order as separate statements would produce.
  ofs << "\tactor_name:" << actor->GetAID().Name() << '\n'
      << "\t\tinput stack data num:" << actor->input_stack_data_num() << '\n'
      << "\t\tinput stack partial num:" << actor->input_stack_partials_num() << '\n'
      << "\t\tinput stack control num:" << actor->input_stack_controls_num() << '\n';
  // The remaining part of the dump is delegated to the control actor dumper.
  DumpControlActor(actor, ofs);
}
|
||||
|
||||
|
|
|
@ -161,7 +161,11 @@ bool ControlActor::CheckRunningCondition(const OpContext<DeviceTensor> *context)
|
|||
if (partial_iter == input_op_partials_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (partial_iter->second.size() != input_partials_num_) {
|
||||
if (partial_iter->second.size() < input_partials_num_) {
|
||||
return false;
|
||||
} else if (partial_iter->second.size() > input_partials_num_) {
|
||||
MS_LOG(ERROR) << "Invalid input partial num:" << partial_iter->second.size() << " need:" << input_partials_num_
|
||||
<< " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -138,6 +138,15 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
|
|||
return false;
|
||||
}
|
||||
|
||||
if (CheckStackDataRunningCondition(context) && CheckStackPartialRunningCondition(context) &&
|
||||
CheckStackControlRunningCondition(context)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool StackActor::CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto iter = input_branch_ids_.find(context->sequential_num_);
|
||||
bool is_branch_id_available = (iter == input_branch_ids_.end() || iter->second.empty());
|
||||
|
||||
|
@ -146,7 +155,11 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
|
|||
if (data_iter == input_stack_data_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (data_iter->second.size() != input_stack_data_num_) {
|
||||
if (data_iter->second.size() < input_stack_data_num_) {
|
||||
return false;
|
||||
} else if (data_iter->second.size() > input_stack_data_num_) {
|
||||
MS_LOG(ERROR) << "Invalid input stack data num:" << data_iter->second.size() << " need:" << input_stack_data_num_
|
||||
<< " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -155,18 +168,35 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
|
|||
return false;
|
||||
}
|
||||
size_t branch_id_size = iter->second.size();
|
||||
if (std::any_of(data_iter->second.begin(), data_iter->second.end(),
|
||||
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
|
||||
return false;
|
||||
for (const auto &one_stack : data_iter->second) {
|
||||
if (one_stack.second.size() < branch_id_size) {
|
||||
return false;
|
||||
} else if (one_stack.second.size() > branch_id_size) {
|
||||
MS_LOG(ERROR) << "Invalid input stack data num:" << one_stack.second.size()
|
||||
<< " for input index:" << one_stack.first << " need:" << branch_id_size
|
||||
<< " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StackActor::CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto iter = input_branch_ids_.find(context->sequential_num_);
|
||||
bool is_branch_id_available = (iter == input_branch_ids_.end() || iter->second.empty());
|
||||
|
||||
if (input_stack_partials_num_ != 0) {
|
||||
const auto &partial_iter = input_stack_partials_.find(context->sequential_num_);
|
||||
if (partial_iter == input_stack_partials_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (partial_iter->second.size() != input_stack_partials_num_) {
|
||||
if (partial_iter->second.size() < input_stack_partials_num_) {
|
||||
return false;
|
||||
} else if (partial_iter->second.size() > input_stack_partials_num_) {
|
||||
MS_LOG(ERROR) << "Invalid input stack partial num:" << partial_iter->second.size()
|
||||
<< " need:" << input_stack_partials_num_ << " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -175,18 +205,35 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
|
|||
return false;
|
||||
}
|
||||
size_t branch_id_size = iter->second.size();
|
||||
if (std::any_of(partial_iter->second.begin(), partial_iter->second.end(),
|
||||
[branch_id_size](const auto &one_stack) { return one_stack.second.size() != branch_id_size; })) {
|
||||
return false;
|
||||
for (const auto &one_stack : partial_iter->second) {
|
||||
if (one_stack.second.size() < branch_id_size) {
|
||||
return false;
|
||||
} else if (one_stack.second.size() > branch_id_size) {
|
||||
MS_LOG(ERROR) << "Invalid input stack partial num:" << one_stack.second.size()
|
||||
<< " for input index:" << one_stack.first << " need:" << branch_id_size
|
||||
<< " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StackActor::CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const {
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto iter = input_branch_ids_.find(context->sequential_num_);
|
||||
bool is_branch_id_available = (iter == input_branch_ids_.end() || iter->second.empty());
|
||||
|
||||
if (input_stack_controls_num_ != 0) {
|
||||
const auto &control_iter = input_stack_controls_.find(context->sequential_num_);
|
||||
if (control_iter == input_stack_controls_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (control_iter->second.size() != input_stack_controls_num_) {
|
||||
if (control_iter->second.size() < input_stack_controls_num_) {
|
||||
return false;
|
||||
} else if (control_iter->second.size() > input_stack_controls_num_) {
|
||||
MS_LOG(ERROR) << "Invalid input stack control num:" << control_iter->second.size()
|
||||
<< " need:" << input_stack_controls_num_ << " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -195,9 +242,15 @@ bool StackActor::CheckRunningCondition(const OpContext<DeviceTensor> *context) c
|
|||
return false;
|
||||
}
|
||||
size_t branch_id_size = iter->second.size();
|
||||
if (std::any_of(control_iter->second.begin(), control_iter->second.end(),
|
||||
[branch_id_size](const auto &one_stack) { return one_stack.second != branch_id_size; })) {
|
||||
return false;
|
||||
for (const auto &one_stack : control_iter->second) {
|
||||
if (one_stack.second < branch_id_size) {
|
||||
return false;
|
||||
} else if (one_stack.second > branch_id_size) {
|
||||
MS_LOG(ERROR) << "Invalid input stack control num:" << one_stack.second
|
||||
<< " for input actor:" << one_stack.first->Name() << " need:" << branch_id_size
|
||||
<< " for actor:" << GetAID();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
|
|
@ -45,6 +45,9 @@ class StackActor : public ControlActor {
|
|||
void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
|
||||
void RunOpPartial(const OpPartialPtr &partial, size_t position, OpContext<DeviceTensor> *const context) override;
|
||||
void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
|
||||
size_t input_stack_data_num() const { return input_stack_data_num_; }
|
||||
size_t input_stack_partials_num() const { return input_stack_partials_num_; }
|
||||
size_t input_stack_controls_num() const { return input_stack_controls_num_; }
|
||||
|
||||
protected:
|
||||
void Init() override;
|
||||
|
@ -56,6 +59,11 @@ class StackActor : public ControlActor {
|
|||
private:
|
||||
friend class ControlNodeScheduler;
|
||||
|
||||
// Check running condition functions.
|
||||
bool CheckStackDataRunningCondition(const OpContext<DeviceTensor> *context) const;
|
||||
bool CheckStackPartialRunningCondition(const OpContext<DeviceTensor> *context) const;
|
||||
bool CheckStackControlRunningCondition(const OpContext<DeviceTensor> *context) const;
|
||||
|
||||
// The input data and partials records that the stack actor is copied from the input nodes and needs to be
|
||||
// stored in the device tensor in the stack.
|
||||
mindspore::HashMap<int, mindspore::HashMap<size_t, std::stack<DeviceTensor *>>> input_stack_data_;
|
||||
|
|
|
@ -211,6 +211,14 @@ bool IsValidPartialCNode(const AnfNodePtr &node) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CheckExitActorInvalid(const ExitActorPtr &exit_actor) {
|
||||
MS_EXCEPTION_IF_NULL(exit_actor);
|
||||
|
||||
return exit_actor->output_data_arrows().empty() && exit_actor->output_partial_arrows().empty() &&
|
||||
exit_actor->output_control_arrows().empty() && exit_actor->output_branch_control_arrows().empty() &&
|
||||
exit_actor->output_branch_data_arrows().empty() && exit_actor->output_branch_partial_arrows().empty();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<GatherActorPtr> ControlNodeScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
|
||||
|
@ -1564,6 +1572,13 @@ bool ControlNodeScheduler::CheckActorValid(const ActorSet *actor_set) const {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &exit_actor : actor_set->control_actors_->exit_actors_) {
|
||||
MS_EXCEPTION_IF_NULL(exit_actor);
|
||||
if (CheckExitActorInvalid(exit_actor)) {
|
||||
MS_LOG(EXCEPTION) << "Invalid exit actor:" << exit_actor->GetAID();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue