!35219 [AutoParallel]Parallel support subgraph redistribution

Merge pull request !35219 from lichen/parallel_support_subgraph_redistribution
This commit is contained in:
i-robot 2022-05-31 08:57:39 +00:00 committed by Gitee
commit 588c8fd928
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 13 additions and 5 deletions

View File

@ -330,7 +330,7 @@ bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
auto micro1_value = GetValue<int64_t>(micro1);
auto micro2_value = GetValue<int64_t>(micro2);
if (micro1_value == micro2_value) {
if (IsPrimitiveCNode(node1, prim::kPrimStridedSlice)) {
if (IsPrimitiveCNode(node1, prim::kPrimStridedSlice) || IsPrimitiveCNode(node2, prim::kPrimStridedSlice)) {
return true;
}
auto prim1 = GetCNodePrimitive(cnode1);

View File

@ -195,15 +195,23 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
return tuple_index_value->cast<Int64ImmPtr>()->value();
}
void RedistributionNextNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
int64_t get_item_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node_users_map);
auto node_set = (*node_users_map)[cnode];
auto node_set = (*node_users_map)[node];
for (auto &node_pair : node_set) {
auto use_cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(use_cnode);
if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
MS_EXCEPTION_IF_NULL(fg);
auto fg_parameters = fg->parameters();
auto param = fg_parameters[node_pair.second - 1];
MS_EXCEPTION_IF_NULL(param);
RedistributionNextNode(param, manager, node_users_map, get_item_index, next_nodes);
}
if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem)) {
get_item_index = LongToInt(GetTupleGetItemIndex(use_cnode));
}

View File

@ -46,7 +46,7 @@ bool IsControlFlowNode(const AnfNodePtr &node);
int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
std::vector<AnfNodePtr> *pre_nodes);
void RedistributionNextNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
void RedistributionNextNode(const AnfNodePtr &cnode, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
int64_t get_item_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);