fix_pipeline_with_no_loss_bug

This commit is contained in:
lichenever 2021-10-11 11:35:54 +08:00
parent 2c738757c3
commit 5ff1124e71
2 changed files with 32 additions and 13 deletions

View File

@ -27,15 +27,25 @@
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"
#include "utils/parallel_node_check.h"
namespace mindspore {
namespace parallel {
const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd,
prim::kPrimSoftmaxCrossEntropyWithLogits};
const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits,
prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape};
static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
for (auto &prim : END_NODE_BLACK_LIST) {
if (IsPrimitiveCNode(cnode, prim)) {
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
return true;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (IsInParallelBlackList(prim)) {
return true;
}
for (auto &prim_node : END_NODE_BLACK_LIST) {
if (IsPrimitiveCNode(cnode, prim_node)) {
return true;
}
}
@ -414,16 +424,24 @@ void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, con
}
}
AnfNodePtr GetPreNode(const AnfNodePtr &node, size_t max_depth) {
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (max_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
std::vector<AnfNodePtr> node_queue = {node};
while (!node_queue.empty()) {
auto cur_node = (*node_queue.begin())->cast<CNodePtr>();
if (!cur_node) {
node_queue.erase(node_queue.begin());
continue;
}
node_queue.erase(node_queue.begin());
if (!IsInEndNodeBlackList(cur_node)) {
MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
return cur_node;
}
node_queue.insert(node_queue.end(), cur_node->inputs().begin() + 1, cur_node->inputs().end());
}
if (IsInEndNodeBlackList(cnode)) {
return GetPreNode(cnode->input(1), max_depth + 1);
}
return cnode;
MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
}
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager) {
@ -448,7 +466,8 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
continue;
}
auto end_node = GetPreNode(temp_node, 0);
auto end_node = GetPreNode(temp_node);
MS_EXCEPTION_IF_NULL(end_node);
auto end_cnode = end_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(end_cnode);
auto end_prim = GetCNodePrimitive(end_node);

View File

@ -57,7 +57,7 @@ void Reorder(const FuncGraphPtr &root);
void ReorderForPredict(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth);
AnfNodePtr GetPreNode(const AnfNodePtr &node, size_t max_depth);
AnfNodePtr GetPreNode(const AnfNodePtr &node);
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);
void SetStridedSliceStrategy(const AnfNodePtr &node);
void ParameterStartNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager);