forked from mindspore-Ecosystem/mindspore
!24651 [AutoParallel]fix_pipeline_with_no_loss_bug_master
Merge pull request !24651 from lichen/fix_pipeline_with_no_loss_bug_master
This commit is contained in:
commit
72fc71fc07
|
@ -27,15 +27,25 @@
|
|||
#include "frontend/parallel/device_manager.h"
|
||||
#include "frontend/parallel/context.h"
|
||||
#include "frontend/parallel/step_parallel.h"
|
||||
#include "utils/parallel_node_check.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace parallel {
|
||||
const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd,
|
||||
prim::kPrimSoftmaxCrossEntropyWithLogits};
|
||||
const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
|
||||
prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits,
|
||||
prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape};
|
||||
|
||||
static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
|
||||
for (auto &prim : END_NODE_BLACK_LIST) {
|
||||
if (IsPrimitiveCNode(cnode, prim)) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
||||
return true;
|
||||
}
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (IsInParallelBlackList(prim)) {
|
||||
return true;
|
||||
}
|
||||
for (auto &prim_node : END_NODE_BLACK_LIST) {
|
||||
if (IsPrimitiveCNode(cnode, prim_node)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -414,16 +424,24 @@ void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, con
|
|||
}
|
||||
}
|
||||
|
||||
AnfNodePtr GetPreNode(const AnfNodePtr &node, size_t max_depth) {
|
||||
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (max_depth > MAX_RECURSIVE_DEPTH) {
|
||||
MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
|
||||
std::vector<AnfNodePtr> node_queue = {node};
|
||||
while (!node_queue.empty()) {
|
||||
auto cur_node = (*node_queue.begin())->cast<CNodePtr>();
|
||||
if (!cur_node) {
|
||||
node_queue.erase(node_queue.begin());
|
||||
continue;
|
||||
}
|
||||
node_queue.erase(node_queue.begin());
|
||||
if (!IsInEndNodeBlackList(cur_node)) {
|
||||
MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
|
||||
return cur_node;
|
||||
}
|
||||
node_queue.insert(node_queue.end(), cur_node->inputs().begin() + 1, cur_node->inputs().end());
|
||||
}
|
||||
if (IsInEndNodeBlackList(cnode)) {
|
||||
return GetPreNode(cnode->input(1), max_depth + 1);
|
||||
}
|
||||
return cnode;
|
||||
MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
|
||||
}
|
||||
|
||||
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
|
||||
|
@ -448,7 +466,8 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
|
|||
if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
|
||||
continue;
|
||||
}
|
||||
auto end_node = GetPreNode(temp_node, 0);
|
||||
auto end_node = GetPreNode(temp_node);
|
||||
MS_EXCEPTION_IF_NULL(end_node);
|
||||
auto end_cnode = end_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(end_cnode);
|
||||
auto end_prim = GetCNodePrimitive(end_node);
|
||||
|
|
|
@ -57,7 +57,7 @@ void Reorder(const FuncGraphPtr &root);
|
|||
void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||
void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
|
||||
void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth);
|
||||
AnfNodePtr GetPreNode(const AnfNodePtr &node, size_t max_depth);
|
||||
AnfNodePtr GetPreNode(const AnfNodePtr &node);
|
||||
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
|
||||
void SetStridedSliceStrategy(const AnfNodePtr &node);
|
||||
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
|
||||
|
|
Loading…
Reference in New Issue