opt_parallel_compile

This commit is contained in:
lichen 2022-08-04 20:20:14 +08:00
parent dad41b6809
commit 5070982760
4 changed files with 18 additions and 12 deletions

View File

@ -542,7 +542,8 @@ void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
}
}
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution) {
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution,
const NodeUsersMap &node_users_map) {
MS_EXCEPTION_IF_NULL(cnode->func_graph());
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
@ -559,9 +560,8 @@ void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tenso
}
// Find Redistribution next_nodes
auto node_users_map = manager->node_users();
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
RedistributionNextNode(cnode, manager, &node_users_map, -1, &next_nodes);
RedistributionNextNode(cnode, manager, node_users_map, -1, &next_nodes);
// Insert Redistribution nodes between pre_nodes and next_nodes
for (auto &pre_node : pre_nodes) {
@ -2476,12 +2476,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
}
}
auto node_users_map = manager->node_users();
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (IsValueNode<FuncGraph>(cnode->input(0))) {
StepRedistribution(cnode, tensor_redistribution);
StepRedistribution(cnode, tensor_redistribution, node_users_map);
continue;
}
// the make_tuple is parallel care node, but it may have not operator info
@ -2502,7 +2503,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
}
// insert redistribution ops
StepRedistribution(cnode, tensor_redistribution);
StepRedistribution(cnode, tensor_redistribution, node_users_map);
}
// insert backward ops
if (!IsControlFlowNode(cnode) && (has_backward || IsPynativeParallel())) {

View File

@ -77,7 +77,8 @@ void MarkForwardCNode(const FuncGraphPtr &root);
void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution);
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution,
const NodeUsersMap &node_users_map);
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);

View File

@ -247,12 +247,14 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
return tuple_index_value->cast<Int64ImmPtr>()->value();
}
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
int64_t get_item_index,
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
const NodeUsersMap &node_users_map, int64_t get_item_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node_users_map);
auto node_set = (*node_users_map)[node];
if (node_users_map.count(node) == 0) {
return;
}
auto node_set = node_users_map.at(node);
for (auto &node_pair : node_set) {
auto use_cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(use_cnode);
@ -275,6 +277,8 @@ void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &m
}
if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
next_nodes->push_back(std::make_pair(node_pair, get_item_index));
} else if (use_cnode->input(0)->isa<CNode>()) {
continue;
} else {
// search recursively
RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes);

View File

@ -57,8 +57,8 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr);
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
std::vector<AnfNodePtr> *pre_nodes);
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
int64_t get_item_index,
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
const NodeUsersMap &node_users_map, int64_t get_item_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);
// for specific scenarios