From 5ff1124e71db68ff51fdc1e1ec9365729ce2e86e Mon Sep 17 00:00:00 2001 From: lichenever Date: Mon, 11 Oct 2021 11:35:54 +0800 Subject: [PATCH] fix_pipeline_with_no_loss_bug --- .../graph_util/pipeline_split_utils.cc | 43 +++++++++++++------ .../graph_util/pipeline_split_utils.h | 2 +- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc index 7d2aa80e946..bc947db4429 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc @@ -27,15 +27,25 @@ #include "frontend/parallel/device_manager.h" #include "frontend/parallel/context.h" #include "frontend/parallel/step_parallel.h" +#include "utils/parallel_node_check.h" namespace mindspore { namespace parallel { -const std::set END_NODE_BLACK_LIST = {prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, - prim::kPrimSoftmaxCrossEntropyWithLogits}; +const std::set END_NODE_BLACK_LIST = { + prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits, + prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape}; static bool IsInEndNodeBlackList(const CNodePtr &cnode) { - for (auto &prim : END_NODE_BLACK_LIST) { - if (IsPrimitiveCNode(cnode, prim)) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsValueNode(cnode->input(0))) { + return true; + } + auto prim = GetValueNode(cnode->input(0)); + if (IsInParallelBlackList(prim)) { + return true; + } + for (auto &prim_node : END_NODE_BLACK_LIST) { + if (IsPrimitiveCNode(cnode, prim_node)) { return true; } } @@ -414,16 +424,24 @@ void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, con } } -AnfNodePtr GetPreNode(const AnfNodePtr &node, size_t max_depth) { +AnfNodePtr GetPreNode(const AnfNodePtr &node) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (max_depth > MAX_RECURSIVE_DEPTH) { - MS_LOG(EXCEPTION) << "Recursive call is larger than 100000."; + std::vector node_queue = {node}; + while (!node_queue.empty()) { + auto cur_node = (*node_queue.begin())->cast(); + if (!cur_node) { + node_queue.erase(node_queue.begin()); + continue; + } + node_queue.erase(node_queue.begin()); + if (!IsInEndNodeBlackList(cur_node)) { + MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString(); + return cur_node; + } + node_queue.insert(node_queue.end(), cur_node->inputs().begin() + 1, cur_node->inputs().end()); } - if (IsInEndNodeBlackList(cnode)) { - return GetPreNode(cnode->input(1), max_depth + 1); - } - return cnode; + MS_LOG(EXCEPTION) << "Get Pipeline End node failed."; } void LastStageEndNode(const std::vector &all_nodes, const FuncGraphManagerPtr &manager) { @@ -448,7 +466,8 @@ void LastStageEndNode(const std::vector &all_nodes, const FuncGraphM if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) { continue; } - auto end_node = GetPreNode(temp_node, 0); + auto end_node = GetPreNode(temp_node); + MS_EXCEPTION_IF_NULL(end_node); auto end_cnode = end_node->cast(); MS_EXCEPTION_IF_NULL(end_cnode); auto end_prim = GetCNodePrimitive(end_node); diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h index 90a845c4ebb..f67df4a79ff 100755 --- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.h @@ -57,7 +57,7 @@ void Reorder(const FuncGraphPtr &root); void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); void HandleMicroBatch(const std::vector &all_nodes, const FuncGraphManagerPtr &manager); void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth); -AnfNodePtr GetPreNode(const AnfNodePtr &node, size_t max_depth); +AnfNodePtr GetPreNode(const AnfNodePtr &node); void LastStageEndNode(const std::vector &all_nodes, const FuncGraphManagerPtr &manager); void SetStridedSliceStrategy(const AnfNodePtr &node); void ParameterStartNode(const std::vector &all_nodes, const FuncGraphManagerPtr &manager);