forked from mindspore-Ecosystem/mindspore
opt_parallel_compile
This commit is contained in:
parent
dad41b6809
commit
5070982760
|
@ -542,7 +542,8 @@ void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
}
|
||||
|
||||
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution) {
|
||||
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution,
|
||||
const NodeUsersMap &node_users_map) {
|
||||
MS_EXCEPTION_IF_NULL(cnode->func_graph());
|
||||
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -559,9 +560,8 @@ void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tenso
|
|||
}
|
||||
|
||||
// Find Redistribution next_nodes
|
||||
auto node_users_map = manager->node_users();
|
||||
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
|
||||
RedistributionNextNode(cnode, manager, &node_users_map, -1, &next_nodes);
|
||||
RedistributionNextNode(cnode, manager, node_users_map, -1, &next_nodes);
|
||||
|
||||
// Insert Redistribution nodes between pre_nodes and next_nodes
|
||||
for (auto &pre_node : pre_nodes) {
|
||||
|
@ -2476,12 +2476,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
}
|
||||
}
|
||||
|
||||
auto node_users_map = manager->node_users();
|
||||
for (auto &node : all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (IsValueNode<FuncGraph>(cnode->input(0))) {
|
||||
StepRedistribution(cnode, tensor_redistribution);
|
||||
StepRedistribution(cnode, tensor_redistribution, node_users_map);
|
||||
continue;
|
||||
}
|
||||
// the make_tuple is parallel care node, but it may have not operator info
|
||||
|
@ -2502,7 +2503,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
}
|
||||
|
||||
// insert redistribution ops
|
||||
StepRedistribution(cnode, tensor_redistribution);
|
||||
StepRedistribution(cnode, tensor_redistribution, node_users_map);
|
||||
}
|
||||
// insert backward ops
|
||||
if (!IsControlFlowNode(cnode) && (has_backward || IsPynativeParallel())) {
|
||||
|
|
|
@ -77,7 +77,8 @@ void MarkForwardCNode(const FuncGraphPtr &root);
|
|||
|
||||
void ExceptionIfHasCommunicationOp(const std::vector<AnfNodePtr> &all_nodes);
|
||||
|
||||
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution);
|
||||
void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution &tensor_redistribution,
|
||||
const NodeUsersMap &node_users_map);
|
||||
|
||||
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node);
|
||||
|
||||
|
|
|
@ -247,12 +247,14 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
|
|||
return tuple_index_value->cast<Int64ImmPtr>()->value();
|
||||
}
|
||||
|
||||
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
|
||||
int64_t get_item_index,
|
||||
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
|
||||
const NodeUsersMap &node_users_map, int64_t get_item_index,
|
||||
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node_users_map);
|
||||
auto node_set = (*node_users_map)[node];
|
||||
if (node_users_map.count(node) == 0) {
|
||||
return;
|
||||
}
|
||||
auto node_set = node_users_map.at(node);
|
||||
for (auto &node_pair : node_set) {
|
||||
auto use_cnode = node_pair.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(use_cnode);
|
||||
|
@ -275,6 +277,8 @@ void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &m
|
|||
}
|
||||
if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
|
||||
next_nodes->push_back(std::make_pair(node_pair, get_item_index));
|
||||
} else if (use_cnode->input(0)->isa<CNode>()) {
|
||||
continue;
|
||||
} else {
|
||||
// search recursively
|
||||
RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes);
|
||||
|
|
|
@ -57,8 +57,8 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
|
|||
AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr);
|
||||
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
|
||||
std::vector<AnfNodePtr> *pre_nodes);
|
||||
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager, NodeUsersMap *node_users_map,
|
||||
int64_t get_item_index,
|
||||
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
|
||||
const NodeUsersMap &node_users_map, int64_t get_item_index,
|
||||
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);
|
||||
|
||||
// for specific scenarios
|
||||
|
|
Loading…
Reference in New Issue