!19175 fix ps hunge up bug

Merge pull request !19175 from kisnwang/fix-ps-hange-bug
This commit is contained in:
i-robot 2021-07-02 01:29:53 +00:00 committed by Gitee
commit f29fe5c51c
3 changed files with 28 additions and 18 deletions

View File

@ -375,7 +375,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
// maintain a copy of output vector
task->outputs_ = *outputs;
// sync run graph without output tensor(int dataset graph)
if (!TensorInVector(outputs)) {
if (!TensorInVector(outputs) && !graph->HasPostGraph()) {
task->sync_run_ = true;
RunTask(task, true, true);
return;

View File

@ -279,13 +279,16 @@ class KernelGraph : public FuncGraph {
}
}
bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; }
bool IsPostGraphFinished() {
bool IsPreGraphFinished() const { return pre_graphs_.size() == pre_graph_finished_count_; }
bool IsPostGraphFinished() const {
if (first_step_) {
return true;
}
return post_graphs_.size() == post_graph_finished_count_;
}
bool HasPostGraph() const { return !post_graphs_.empty(); }
void IncPreGraphFinishedCount() { pre_graph_finished_count_++; }
void IncPostGraphFinishedCount() { post_graph_finished_count_++; }
void ResetGraphRunningStatus() {

View File

@ -406,24 +406,28 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
}
return GetOriginNodeTarget(inputs[use_index]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
const size_t node_inputs_num = 3;
if (inputs.size() >= node_inputs_num) {
return GetOriginNodeTarget(inputs[2]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
std::vector<AnfNodePtr> real_inputs;
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
const size_t update_state_valid_input_index = 2;
const size_t make_tuple_valid_input_index = 1;
if (IsPrimitiveCNode(node, prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
std::copy(inputs.begin() + update_state_valid_input_index, inputs.end(), std::back_inserter(real_inputs));
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
std::copy(inputs.begin() + make_tuple_valid_input_index, inputs.end(), std::back_inserter(real_inputs));
}
std::string first_input_target = kTargetUnDefined;
bool has_same_target =
bool has_diff_target =
std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) {
auto target = GetOriginNodeTarget(n);
if (target != kTargetUnDefined) {
if (target == kTargetUnDefined) {
return false;
}
if (first_input_target == kTargetUnDefined) {
first_input_target = target;
}
return target == first_input_target;
return target != first_input_target;
});
if (has_same_target) {
if (!has_diff_target) {
return first_input_target;
}
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
@ -446,15 +450,18 @@ std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
}
auto users = manager->node_users()[cnode];
std::string first_user_target = kTargetUnDefined;
bool has_same_target =
bool has_diff_target =
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
auto target = GetOriginNodeTarget(u.first);
if (target != kTargetUnDefined) {
if (target == kTargetUnDefined) {
return false;
}
if (first_user_target == kTargetUnDefined) {
first_user_target = target;
}
return target == first_user_target;
return target != first_user_target;
});
if (has_same_target) {
if (!has_diff_target) {
return first_user_target;
}
return kTargetUnDefined;