!19175 Fix PS hang bug

Merge pull request !19175 from kisnwang/fix-ps-hange-bug
This commit is contained in:
i-robot 2021-07-02 01:29:53 +00:00 committed by Gitee
commit f29fe5c51c
3 changed files with 28 additions and 18 deletions

View File

@ -375,7 +375,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
// maintain a copy of output vector // maintain a copy of output vector
task->outputs_ = *outputs; task->outputs_ = *outputs;
// sync run graph without output tensor(int dataset graph) // sync run graph without output tensor(int dataset graph)
if (!TensorInVector(outputs)) { if (!TensorInVector(outputs) && !graph->HasPostGraph()) {
task->sync_run_ = true; task->sync_run_ = true;
RunTask(task, true, true); RunTask(task, true, true);
return; return;

View File

@ -279,13 +279,16 @@ class KernelGraph : public FuncGraph {
} }
} }
bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; } bool IsPreGraphFinished() const { return pre_graphs_.size() == pre_graph_finished_count_; }
bool IsPostGraphFinished() { bool IsPostGraphFinished() const {
if (first_step_) { if (first_step_) {
return true; return true;
} }
return post_graphs_.size() == post_graph_finished_count_; return post_graphs_.size() == post_graph_finished_count_;
} }
bool HasPostGraph() const { return !post_graphs_.empty(); }
void IncPreGraphFinishedCount() { pre_graph_finished_count_++; } void IncPreGraphFinishedCount() { pre_graph_finished_count_++; }
void IncPostGraphFinishedCount() { post_graph_finished_count_++; } void IncPostGraphFinishedCount() { post_graph_finished_count_++; }
void ResetGraphRunningStatus() { void ResetGraphRunningStatus() {

View File

@ -406,24 +406,28 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
} }
return GetOriginNodeTarget(inputs[use_index]); return GetOriginNodeTarget(inputs[use_index]);
} }
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) { } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
const size_t node_inputs_num = 3;
if (inputs.size() >= node_inputs_num) {
return GetOriginNodeTarget(inputs[2]);
}
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> real_inputs; std::vector<AnfNodePtr> real_inputs;
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs)); const size_t update_state_valid_input_index = 2;
const size_t make_tuple_valid_input_index = 1;
if (IsPrimitiveCNode(node, prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
std::copy(inputs.begin() + update_state_valid_input_index, inputs.end(), std::back_inserter(real_inputs));
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
std::copy(inputs.begin() + make_tuple_valid_input_index, inputs.end(), std::back_inserter(real_inputs));
}
std::string first_input_target = kTargetUnDefined; std::string first_input_target = kTargetUnDefined;
bool has_same_target = bool has_diff_target =
std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) { std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) {
auto target = GetOriginNodeTarget(n); auto target = GetOriginNodeTarget(n);
if (target != kTargetUnDefined) { if (target == kTargetUnDefined) {
return false;
}
if (first_input_target == kTargetUnDefined) {
first_input_target = target; first_input_target = target;
} }
return target == first_input_target; return target != first_input_target;
}); });
if (has_same_target) { if (!has_diff_target) {
return first_input_target; return first_input_target;
} }
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
@ -446,15 +450,18 @@ std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
} }
auto users = manager->node_users()[cnode]; auto users = manager->node_users()[cnode];
std::string first_user_target = kTargetUnDefined; std::string first_user_target = kTargetUnDefined;
bool has_same_target = bool has_diff_target =
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) { std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
auto target = GetOriginNodeTarget(u.first); auto target = GetOriginNodeTarget(u.first);
if (target != kTargetUnDefined) { if (target == kTargetUnDefined) {
return false;
}
if (first_user_target == kTargetUnDefined) {
first_user_target = target; first_user_target = target;
} }
return target == first_user_target; return target != first_user_target;
}); });
if (has_same_target) { if (!has_diff_target) {
return first_user_target; return first_user_target;
} }
return kTargetUnDefined; return kTargetUnDefined;