diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 4bea59edb9d..01b9f230d14 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -375,7 +375,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, // maintain a copy of output vector task->outputs_ = *outputs; // sync run graph without output tensor(int dataset graph) - if (!TensorInVector(outputs)) { + if (!TensorInVector(outputs) && !graph->HasPostGraph()) { task->sync_run_ = true; RunTask(task, true, true); return; diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 20bcb8b2c80..507420a854f 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -279,13 +279,16 @@ class KernelGraph : public FuncGraph { } } - bool IsPreGraphFinished() { return pre_graphs_.size() == pre_graph_finished_count_; } - bool IsPostGraphFinished() { + bool IsPreGraphFinished() const { return pre_graphs_.size() == pre_graph_finished_count_; } + bool IsPostGraphFinished() const { if (first_step_) { return true; } return post_graphs_.size() == post_graph_finished_count_; } + + bool HasPostGraph() const { return !post_graphs_.empty(); } + void IncPreGraphFinishedCount() { pre_graph_finished_count_++; } void IncPostGraphFinishedCount() { post_graph_finished_count_++; } void ResetGraphRunningStatus() { diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 32230204b89..bbeaed863e2 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -406,24 +406,28 @@ std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) { } return GetOriginNodeTarget(inputs[use_index]); } - } else if (IsPrimitiveCNode(node, prim::kPrimUpdateState)) { - const size_t node_inputs_num = 3; - if (inputs.size() >= node_inputs_num) { - return GetOriginNodeTarget(inputs[2]); - } - } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) { + } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) { std::vector real_inputs; - std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(real_inputs)); + const size_t update_state_valid_input_index = 2; + const size_t make_tuple_valid_input_index = 1; + if (IsPrimitiveCNode(node, prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) { + std::copy(inputs.begin() + update_state_valid_input_index, inputs.end(), std::back_inserter(real_inputs)); + } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) { + std::copy(inputs.begin() + make_tuple_valid_input_index, inputs.end(), std::back_inserter(real_inputs)); + } std::string first_input_target = kTargetUnDefined; - bool has_same_target = + bool has_diff_target = std::any_of(std::begin(real_inputs), std::end(real_inputs), [&first_input_target](const AnfNodePtr &n) { auto target = GetOriginNodeTarget(n); - if (target != kTargetUnDefined) { + if (target == kTargetUnDefined) { + return false; + } + if (first_input_target == kTargetUnDefined) { first_input_target = target; } - return target == first_input_target; + return target != first_input_target; }); - if (has_same_target) { + if (!has_diff_target) { return first_input_target; } } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { @@ -446,15 +450,18 @@ std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) { } auto users = manager->node_users()[cnode]; std::string first_user_target = kTargetUnDefined; - bool has_same_target = + bool has_diff_target = std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair &u) { auto target = GetOriginNodeTarget(u.first); - if (target != kTargetUnDefined) { + if (target == kTargetUnDefined) { + return false; + } + if (first_user_target == kTargetUnDefined) { first_user_target = target; } - return target == first_user_target; + return target != first_user_target; }); - if (has_same_target) { + if (!has_diff_target) { return first_user_target; } return kTargetUnDefined;