forked from mindspore-Ecosystem/mindspore
!7818 add recursion limit for FindParallelCareNode
Merge pull request !7818 from yangzhenzhang/add-recursion-limit
This commit is contained in:
commit
78f795971b
|
@ -429,7 +429,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
|
|||
return false;
|
||||
}
|
||||
if (IsInBlackList(prim)) {
|
||||
MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
|
||||
MS_LOG(DEBUG) << "Parallel don't care node: " << prim->name();
|
||||
return false;
|
||||
}
|
||||
// get_next is not in the forward graph, we need mark the get_next as the forward node
|
||||
|
@ -1286,7 +1286,11 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
|
|||
return shape_all;
|
||||
}
|
||||
|
||||
std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
|
||||
std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node, int32_t recursion_num) {
|
||||
if (recursion_num >= RECURSION_LIMIT) {
|
||||
return std::make_pair(nullptr, 0);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
FuncGraphPtr func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
@ -1308,8 +1312,11 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
|
|||
}
|
||||
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
|
||||
return node_pair;
|
||||
} else if (FindParallelCareNode(node_pair.first).first != nullptr) {
|
||||
return FindParallelCareNode(node_pair.first);
|
||||
} else {
|
||||
auto tmp_pair = FindParallelCareNode(node_pair.first, recursion_num + 1);
|
||||
if (tmp_pair.first != nullptr) {
|
||||
return tmp_pair;
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_pair(nullptr, 0);
|
||||
|
@ -1320,7 +1327,7 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
|
|||
MS_EXCEPTION_IF_NULL(parameter);
|
||||
FuncGraphManagerPtr manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::pair<AnfNodePtr, int> prim_anf_node_pair = FindParallelCareNode(parameter);
|
||||
std::pair<AnfNodePtr, int> prim_anf_node_pair = FindParallelCareNode(parameter, 0);
|
||||
if (prim_anf_node_pair.first != nullptr) {
|
||||
return prim_anf_node_pair;
|
||||
} else {
|
||||
|
|
|
@ -36,6 +36,7 @@ using OperatorInfoPtr = std::shared_ptr<mindspore::parallel::OperatorInfo>;
|
|||
namespace mindspore {
|
||||
namespace parallel {
|
||||
const uint64_t kUSecondInSecond = 1000000;
|
||||
const int32_t RECURSION_LIMIT = 3;
|
||||
|
||||
struct LossNodeInfo {
|
||||
bool has_tuple_getitem = false;
|
||||
|
@ -104,8 +105,6 @@ std::vector<AnfNodePtr> FindParameterByRefKeyNode(const AnfNodePtr &node, const
|
|||
// Extract shape from anfnode
|
||||
std::vector<Shapes> ExtractShape(const CNodePtr &node);
|
||||
|
||||
std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node);
|
||||
|
||||
// Find finally sub graph
|
||||
std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr ¶meter);
|
||||
|
||||
|
|
Loading…
Reference in New Issue