!35219 [AutoParallel]Parallel support subgraph redistribution
Merge pull request !35219 from lichen/parallel_support_subgraph_redistribution
This commit is contained in:
commit
588c8fd928
|
@ -330,7 +330,7 @@ bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
|
||||||
auto micro1_value = GetValue<int64_t>(micro1);
|
auto micro1_value = GetValue<int64_t>(micro1);
|
||||||
auto micro2_value = GetValue<int64_t>(micro2);
|
auto micro2_value = GetValue<int64_t>(micro2);
|
||||||
if (micro1_value == micro2_value) {
|
if (micro1_value == micro2_value) {
|
||||||
if (IsPrimitiveCNode(node1, prim::kPrimStridedSlice)) {
|
if (IsPrimitiveCNode(node1, prim::kPrimStridedSlice) || IsPrimitiveCNode(node2, prim::kPrimStridedSlice)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
auto prim1 = GetCNodePrimitive(cnode1);
|
auto prim1 = GetCNodePrimitive(cnode1);
|
||||||
|
|
|
@ -195,15 +195,23 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
|
||||||
return tuple_index_value->cast<Int64ImmPtr>()->value();
|
return tuple_index_value->cast<Int64ImmPtr>()->value();
|
||||||
}
|
}
|
||||||
|
|
||||||
void RedistributionNextNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
|
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
|
||||||
int64_t get_item_index,
|
int64_t get_item_index,
|
||||||
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
|
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
MS_EXCEPTION_IF_NULL(node_users_map);
|
MS_EXCEPTION_IF_NULL(node_users_map);
|
||||||
auto node_set = (*node_users_map)[cnode];
|
auto node_set = (*node_users_map)[node];
|
||||||
for (auto &node_pair : node_set) {
|
for (auto &node_pair : node_set) {
|
||||||
auto use_cnode = node_pair.first->cast<CNodePtr>();
|
auto use_cnode = node_pair.first->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(use_cnode);
|
MS_EXCEPTION_IF_NULL(use_cnode);
|
||||||
|
if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
|
||||||
|
auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
|
||||||
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
auto fg_parameters = fg->parameters();
|
||||||
|
auto param = fg_parameters[node_pair.second - 1];
|
||||||
|
MS_EXCEPTION_IF_NULL(param);
|
||||||
|
RedistributionNextNode(param, manager, node_users_map, get_item_index, next_nodes);
|
||||||
|
}
|
||||||
if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem)) {
|
if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem)) {
|
||||||
get_item_index = LongToInt(GetTupleGetItemIndex(use_cnode));
|
get_item_index = LongToInt(GetTupleGetItemIndex(use_cnode));
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,7 @@ bool IsControlFlowNode(const AnfNodePtr &node);
|
||||||
int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
|
int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
|
||||||
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
|
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
|
||||||
std::vector<AnfNodePtr> *pre_nodes);
|
std::vector<AnfNodePtr> *pre_nodes);
|
||||||
void RedistributionNextNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
|
void RedistributionNextNode(const AnfNodePtr &cnode, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
|
||||||
int64_t get_item_index,
|
int64_t get_item_index,
|
||||||
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);
|
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue