!39769 [AutoParallel] Optimize parallel compile performance (to master)

Merge pull request !39769 from lichen/opt_parallel_compile_master
This commit is contained in:
i-robot 2022-08-08 03:21:23 +00:00 committed by Gitee
commit bbf85a975b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 18 additions and 12 deletions

View File

@ -542,7 +542,8 @@ void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
} }
} }
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution) { void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution,
const NodeUsersMap &node_users_map) {
MS_EXCEPTION_IF_NULL(cnode->func_graph()); MS_EXCEPTION_IF_NULL(cnode->func_graph());
FuncGraphManagerPtr manager = cnode->func_graph()->manager(); FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
@ -559,9 +560,8 @@ void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tenso
} }
// Find Redistribution next_nodes // Find Redistribution next_nodes
auto node_users_map = manager->node_users();
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes; std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
RedistributionNextNode(cnode, manager, &node_users_map, -1, &next_nodes); RedistributionNextNode(cnode, manager, node_users_map, -1, &next_nodes);
// Insert Redistribution nodes between pre_nodes and next_nodes // Insert Redistribution nodes between pre_nodes and next_nodes
for (auto &pre_node : pre_nodes) { for (auto &pre_node : pre_nodes) {
@ -2476,12 +2476,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
} }
} }
auto node_users_map = manager->node_users();
for (auto &node : all_nodes) { for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if (IsValueNode<FuncGraph>(cnode->input(0))) { if (IsValueNode<FuncGraph>(cnode->input(0))) {
StepRedistribution(cnode, tensor_redistribution); StepRedistribution(cnode, tensor_redistribution, node_users_map);
continue; continue;
} }
// the make_tuple is parallel care node, but it may have not operator info // the make_tuple is parallel care node, but it may have not operator info
@ -2502,7 +2503,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
} }
// insert redistribution ops // insert redistribution ops
StepRedistribution(cnode, tensor_redistribution); StepRedistribution(cnode, tensor_redistribution, node_users_map);
} }
// insert backward ops // insert backward ops
if (!IsControlFlowNode(cnode) && (has_backward || IsPynativeParallel())) { if (!IsControlFlowNode(cnode) && (has_backward || IsPynativeParallel())) {

View File

@ -77,7 +77,8 @@ void MarkForwardCNode(const FuncGraphPtr &root);
void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes); void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution); void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution,
const NodeUsersMap &node_users_map);
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);

View File

@ -247,12 +247,14 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
return tuple_index_value->cast<Int64ImmPtr>()->value(); return tuple_index_value->cast<Int64ImmPtr>()->value();
} }
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map, void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
int64_t get_item_index, const NodeUsersMap &node_users_map, int64_t get_item_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) { std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node_users_map); if (node_users_map.count(node) == 0) {
auto node_set = (*node_users_map)[node]; return;
}
auto node_set = node_users_map.at(node);
for (auto &node_pair : node_set) { for (auto &node_pair : node_set) {
auto use_cnode = node_pair.first->cast<CNodePtr>(); auto use_cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(use_cnode); MS_EXCEPTION_IF_NULL(use_cnode);
@ -275,6 +277,8 @@ void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &m
} }
if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) { if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
next_nodes->push_back(std::make_pair(node_pair, get_item_index)); next_nodes->push_back(std::make_pair(node_pair, get_item_index));
} else if (use_cnode->input(0)->isa<CNode>()) {
continue;
} else { } else {
// search recursively // search recursively
RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes); RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes);

View File

@ -57,8 +57,8 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr); AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr);
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
std::vector<AnfNodePtr> *pre_nodes); std::vector<AnfNodePtr> *pre_nodes);
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map, void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
int64_t get_item_index, const NodeUsersMap &node_users_map, int64_t get_item_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes); std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);
// for specific scenarios // for specific scenarios