!19175 fix ps hang-up bug
Merge pull request !19175 from kisnwang/fix-ps-hange-bug
This commit is contained in:
commit
f29fe5c51c
|
@ -375,7 +375,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
|
||||||
// maintain a copy of output vector
|
// maintain a copy of output vector
|
||||||
task->outputs_ = *outputs;
|
task->outputs_ = *outputs;
|
||||||
// sync run graph without output tensor(int dataset graph)
|
// sync run graph without output tensor(int dataset graph)
|
||||||
if (!TensorInVector(outputs)) {
|
if (!TensorInVector(outputs) && !graph->HasPostGraph()) {
|
||||||
task->sync_run_ = true;
|
task->sync_run_ = true;
|
||||||
RunTask(task, true, true);
|
RunTask(task, true, true);
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -279,13 +279,16 @@ class KernelGraph : public FuncGraph {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; }
|
bool IsPreGraphFinished() const { return pre_graphs_.size() == pre_graph_finished_count_; }
|
||||||
bool IsPostGraphFinished() {
|
bool IsPostGraphFinished() const {
|
||||||
if (first_step_) {
|
if (first_step_) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return post_graphs_.size() == post_graph_finished_count_;
|
return post_graphs_.size() == post_graph_finished_count_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HasPostGraph() const { return !post_graphs_.empty(); }
|
||||||
|
|
||||||
void IncPreGraphFinishedCount() { pre_graph_finished_count_++; }
|
void IncPreGraphFinishedCount() { pre_graph_finished_count_++; }
|
||||||
void IncPostGraphFinishedCount() { post_graph_finished_count_++; }
|
void IncPostGraphFinishedCount() { post_graph_finished_count_++; }
|
||||||
void ResetGraphRunningStatus() {
|
void ResetGraphRunningStatus() {
|
||||||
|
|
|
@ -406,24 +406,28 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
return GetOriginNodeTarget(inputs[use_index]);
|
return GetOriginNodeTarget(inputs[use_index]);
|
||||||
}
|
}
|
||||||
} else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
|
||||||
const size_t node_inputs_num = 3;
|
|
||||||
if (inputs.size() >= node_inputs_num) {
|
|
||||||
return GetOriginNodeTarget(inputs[2]);
|
|
||||||
}
|
|
||||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
|
||||||
std::vector<AnfNodePtr> real_inputs;
|
std::vector<AnfNodePtr> real_inputs;
|
||||||
std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs));
|
const size_t update_state_valid_input_index = 2;
|
||||||
|
const size_t make_tuple_valid_input_index = 1;
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
|
||||||
|
std::copy(inputs.begin() + update_state_valid_input_index, inputs.end(), std::back_inserter(real_inputs));
|
||||||
|
} else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
|
||||||
|
std::copy(inputs.begin() + make_tuple_valid_input_index, inputs.end(), std::back_inserter(real_inputs));
|
||||||
|
}
|
||||||
std::string first_input_target = kTargetUnDefined;
|
std::string first_input_target = kTargetUnDefined;
|
||||||
bool has_same_target =
|
bool has_diff_target =
|
||||||
std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) {
|
std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) {
|
||||||
auto target = GetOriginNodeTarget(n);
|
auto target = GetOriginNodeTarget(n);
|
||||||
if (target != kTargetUnDefined) {
|
if (target == kTargetUnDefined) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (first_input_target == kTargetUnDefined) {
|
||||||
first_input_target = target;
|
first_input_target = target;
|
||||||
}
|
}
|
||||||
return target == first_input_target;
|
return target != first_input_target;
|
||||||
});
|
});
|
||||||
if (has_same_target) {
|
if (!has_diff_target) {
|
||||||
return first_input_target;
|
return first_input_target;
|
||||||
}
|
}
|
||||||
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
} else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||||
|
@ -446,15 +450,18 @@ std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
|
||||||
}
|
}
|
||||||
auto users = manager->node_users()[cnode];
|
auto users = manager->node_users()[cnode];
|
||||||
std::string first_user_target = kTargetUnDefined;
|
std::string first_user_target = kTargetUnDefined;
|
||||||
bool has_same_target =
|
bool has_diff_target =
|
||||||
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
|
std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
|
||||||
auto target = GetOriginNodeTarget(u.first);
|
auto target = GetOriginNodeTarget(u.first);
|
||||||
if (target != kTargetUnDefined) {
|
if (target == kTargetUnDefined) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (first_user_target == kTargetUnDefined) {
|
||||||
first_user_target = target;
|
first_user_target = target;
|
||||||
}
|
}
|
||||||
return target == first_user_target;
|
return target != first_user_target;
|
||||||
});
|
});
|
||||||
if (has_same_target) {
|
if (!has_diff_target) {
|
||||||
return first_user_target;
|
return first_user_target;
|
||||||
}
|
}
|
||||||
return kTargetUnDefined;
|
return kTargetUnDefined;
|
||||||
|
|
Loading…
Reference in New Issue