!47804 [AutoParallel] Parallel support while

Merge pull request !47804 from lichen/parallel_support_protein_predict
i-robot 2023-01-18 01:56:51 +00:00 committed by Gitee
commit 057da5a11d
6 changed files with 269 additions and 124 deletions

View File

@ -174,8 +174,12 @@ ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)
// the node is a ref key node
return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
} else if (node->isa<Parameter>()) {
auto param_ptr = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
// the node is a parameter node
return FindParameterNodeUsers(node);
if (param_ptr->has_default()) {
return FindParameterNodeUsers(node);
}
}
return parameter_users_info;
@ -745,7 +749,7 @@ static std::pair<AnfNodePtr, bool> FindParameterByFuncGraph(const AnfNodePtr &no
MS_EXCEPTION_IF_NULL(fg);
auto fg_parameters = fg->parameters();
auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr);
auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
auto pre_cnode = pre_node->cast<CNodePtr>();
for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
auto res = FindParameter(pre_cnode->input(index), pre_cnode->func_graph());

View File

@ -590,7 +590,7 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con
AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto real_node = GetRealKernelNode(node, -1);
auto real_node = GetRealKernelNode(node, -1).first;
if (!real_node->isa<CNode>()) {
return real_node;
}
@ -795,7 +795,7 @@ bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) const {
// ParameterGraph: graph which return a parameter
MS_EXCEPTION_IF_NULL(node);
CNodePtr call_node = nullptr;
auto real_kernel = GetRealKernelNode(node, -1, &call_node);
auto real_kernel = GetRealKernelNode(node, -1, &call_node).first;
if (call_node != nullptr && real_kernel->isa<Parameter>()) {
return true;
}
@ -806,7 +806,7 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
int64_t user_stage, const ValuePtr &micro, size_t pos,
const std::vector<AnfNodePtr> &ops) {
CNodePtr call_node = nullptr;
auto argument = GetRealKernelNode(node, -1, &call_node);
auto argument = GetRealKernelNode(node, -1, &call_node).first;
auto use_cnode = use_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(use_cnode);

View File

@ -482,6 +482,13 @@ static void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution
IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
return;
}
// Find Redistribution next_nodes
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
RedistributionNextNode(cnode, manager, node_users_map, -1, -1, &next_nodes);
if (next_nodes.empty()) {
return;
}
// Find Redistribution pre_nodes
std::vector<AnfNodePtr> pre_nodes;
RedistributionPreNode(cnode, manager, &pre_nodes);
@ -489,10 +496,6 @@ static void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution
MS_LOG(EXCEPTION) << " Don't support Redistribution has multiple pre_node.";
}
// Find Redistribution next_nodes
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
RedistributionNextNode(cnode, manager, node_users_map, -1, &next_nodes);
// Insert Redistribution nodes between pre_nodes and next_nodes
for (auto &pre_node : pre_nodes) {
for (auto &next_node : next_nodes) {
@ -881,65 +884,46 @@ static bool FindPreNodes(const AnfNodePtr &node, std::vector<std::string> *uniqu
return find;
}
static void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
std::vector<size_t> *indexes) {
MS_EXCEPTION_IF_NULL(unique_ids);
CNodePtr cnode = root->get_return();
if (!FindPreNodes(cnode, unique_ids, indexes, 0)) {
MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
}
}
void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
std::vector<std::string> last_forward_node_ids;
std::vector<size_t> last_indexs;
auto real_graph = PynativeParallelGraph(root, all_nodes);
FindLastNodesUniqueId(real_graph, &last_forward_node_ids, &last_indexs);
MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
for (auto &node : all_nodes) {
// here insert virtualoutput node
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
if (last_node_iter == last_forward_node_ids.end()) {
continue;
}
for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
continue;
}
MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
<< cnode->input(last_indexs[last_node_index])->fullname_with_scope();
if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_pair = manager->node_users()[cnode].front();
if (!node_pair.first->isa<CNode>()) {
MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
}
cnode = node_pair.first->cast<CNodePtr>();
last_indexs[last_node_index] = IntToSize(node_pair.second);
}
auto pre_node = cnode->input(last_indexs[last_node_index]);
Shapes shape_outputs = GetNodeShape(pre_node);
auto out_pair = GetRealKernelNode(real_graph->output(), -1, nullptr, false);
auto out_node = out_pair.first;
MS_EXCEPTION_IF_NULL(out_node);
OperatorParams params;
OperatorAttrs attrs;
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
auto tuple = out_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple);
for (size_t i = 1; i < tuple->inputs().size(); ++i) {
auto cur_input = tuple->input(i);
Shapes shape_outputs = GetNodeShape(cur_input);
if (shape_outputs[0].empty()) {
continue;
}
FuncGraphPtr func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
OperatorParams params;
OperatorAttrs attrs;
OperatorArgs args = std::make_pair(attrs, params);
Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
InsertNode(op, tuple, i, cur_input, tuple->func_graph(), VIRTUAL_OUTPUT);
auto virtual_output_abstract = cur_input->abstract()->Clone();
std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
virtual_output_abstract->set_shape(virtual_output_shape);
auto virtual_output_node = tuple->input(i);
virtual_output_node->set_abstract(virtual_output_abstract);
}
} else {
Shapes shape_outputs = GetNodeShape(out_node);
if (shape_outputs[0].empty()) {
return;
}
auto node_input = CreateInput(op, out_node, VIRTUAL_OUTPUT);
auto cur_graph = out_node->cast<CNodePtr>()->func_graph();
MS_EXCEPTION_IF_NULL(cur_graph);
auto new_node = cur_graph->NewCNode(node_input);
auto manager = cur_graph->manager();
(void)manager->Replace(out_node, new_node);
auto virtual_output_abstract = out_node->abstract()->Clone();
std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
virtual_output_abstract->set_shape(virtual_output_shape);
new_node->set_abstract(virtual_output_abstract);
}
}
@ -1612,14 +1596,15 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
}
// if reshape's output connect to several primitive, return the first layout found
static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape) {
static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape,
int make_tuple_index) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(cnode->func_graph());
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[cnode];
for (auto &node_pair : node_set) {
CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
auto use_apply = node_pair.first->cast<CNodePtr>();
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
continue;
}
@ -1627,24 +1612,26 @@ static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool
*next_is_reshape = true;
continue;
}
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(node_prim);
MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
if (node_prim->name() == DEPEND && node_pair.second != 1) {
if (IsPrimitiveCNode(use_apply, prim::kPrimDepend) && node_pair.second != 1) {
continue;
}
if (IsPrimitiveCNode(use_apply, prim::kPrimMakeTuple)) {
make_tuple_index = node_pair.second;
return FindNextLayout(use_apply, next_is_reshape, make_tuple_index);
}
if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
if (make_tuple_index != -1) {
node_pair.second = make_tuple_index;
}
MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString();
*next_is_reshape = false;
auto layout = GetInputLayoutFromCNode(node_pair);
return std::make_shared<TensorLayout>(layout);
}
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
MS_LOG(DEBUG) << "FindNextLayout failed node " << use_apply->DebugString() << " " << IsParallelCareNode(use_apply)
<< " " << use_apply->has_user_data<OperatorInfo>();
auto layout_ptr = FindNextLayout(use_apply, next_is_reshape);
auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, -1);
if (layout_ptr) {
return layout_ptr;
}
@ -1797,7 +1784,7 @@ static void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
}
bool is_next_reshape = false;
auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape);
auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape, -1);
if (next_layout_ptr) {
auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
@ -1827,10 +1814,7 @@ static CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
return cnode;
}
static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {
if (max_depth > MAX_RECURSIVE_DEPTH) {
MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
}
static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
LossNodeInfo loss_node_info;
MS_EXCEPTION_IF_NULL(func_graph);
CNodePtr return_node = func_graph->get_return();
@ -1838,18 +1822,11 @@ static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_dep
if (return_node->size() < 2) {
MS_LOG(EXCEPTION) << "Failure: " << return_node->DebugString() << " size is smaller than 2";
}
AnfNodePtr pre_node = return_node->input(1);
auto pre_node_pair = GetRealKernelNode(return_node->input(1), -1, nullptr);
auto pre_node = pre_node_pair.first;
MS_EXCEPTION_IF_NULL(pre_node);
auto pre_cnode = pre_node->cast<CNodePtr>();
pre_cnode = HandleDependLoss(pre_cnode, 0);
if (pre_cnode->input(0)->isa<CNode>()) {
auto switch_cnode = pre_cnode->input(0)->cast<CNodePtr>();
if (IsPrimitiveCNode(switch_cnode, prim::kPrimSwitch)) {
MS_EXCEPTION_IF_NULL(switch_cnode);
auto switch_graph = GetValueNode<FuncGraphPtr>(switch_cnode->input(2));
return FindLossCNode(switch_graph, max_depth + 1);
}
}
if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
return loss_node_info;
}
@ -1865,21 +1842,11 @@ static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_dep
return loss_node_info;
}
// size of common cnode is larger than 1
if (pre_cnode->size() < 2) {
MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
}
// return -> tuple_getitem -> loss
if (current_prim->name() == prim::kTupleGetItem) {
auto tuple_index = GetTupleGetItemIndex(pre_cnode);
AnfNodePtr pre_pre_node = pre_cnode->input(1);
MS_EXCEPTION_IF_NULL(pre_pre_node);
auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
if (pre_node_pair.second != -1) {
loss_node_info.has_tuple_getitem = true;
loss_node_info.dout_index = tuple_index;
loss_node_info.loss_node = pre_pre_cnode;
loss_node_info.dout_index = pre_node_pair.second;
loss_node_info.loss_node = pre_cnode;
return loss_node_info;
}
@ -2127,7 +2094,7 @@ static std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const Fun
MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
}
auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
auto loss_node_info = FindLossCNode(func_graph, 0);
auto loss_node_info = FindLossCNode(func_graph);
if (loss_node_info.loss_node == nullptr) {
MS_LOG(WARNING) << "Can not find the loss cnode";
continue;
@ -2315,7 +2282,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
static std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> root_forward_nodes;
auto loss_cnode = FindLossCNode(graph, 0).loss_node;
auto loss_cnode = FindLossCNode(graph).loss_node;
if (loss_cnode == nullptr) {
return root_forward_nodes;
}

View File

@ -112,22 +112,27 @@ TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info)
return tensor_info;
}
AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node) {
std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node,
bool ignore_get_item) {
if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
IsPrimitiveCNode(node, prim::kPrimCast)) {
return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node);
IsPrimitiveCNode(node, prim::kPrimCast) || IsPrimitiveCNode(node, prim::kPrimVirtualDiv)) {
return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node, ignore_get_item);
}
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) && ignore_get_item) {
auto cnode = node->cast<CNodePtr>();
auto cur_get_item_index = LongToInt(GetTupleGetItemIndex(cnode));
auto tuple_getitem_input = cnode->input(1);
auto pass_through_node = GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node);
return GetRealKernelNode(pass_through_node, get_item_index, call_node);
return GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node, ignore_get_item);
}
if (get_item_index != -1 && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
auto make_tuple_cnode = node->cast<CNodePtr>();
auto make_tuple_input = make_tuple_cnode->input(LongToSize(get_item_index + 1));
return GetRealKernelNode(make_tuple_input, -1, call_node);
return GetRealKernelNode(make_tuple_input, -1, call_node, ignore_get_item);
}
if (IsControlFlowNode(node)) {
auto switch_cnode = node->cast<CNodePtr>()->input(0)->cast<CNodePtr>();
auto fg = GetValueNode<FuncGraphPtr>(switch_cnode->input(3));
return GetRealKernelNode(fg->output(), get_item_index, call_node, ignore_get_item);
}
if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0))) {
if (call_node != nullptr && *call_node == nullptr) {
@ -135,21 +140,33 @@ AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNo
}
auto cnode = node->cast<CNodePtr>();
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
auto output = GetRealKernelNode(graph->output(), get_item_index, call_node);
auto output = GetRealKernelNode(graph->output(), get_item_index, call_node, ignore_get_item).first;
MS_EXCEPTION_IF_NULL(output);
if (output->isa<Parameter>()) {
auto parameters = graph->parameters();
auto pos_iter = std::find(parameters.begin(), parameters.end(), output);
// If can't find in parameters, the parameter is a fv.
if (pos_iter == parameters.end()) {
return output;
return std::make_pair(output, get_item_index);
}
auto pos = std::distance(parameters.begin(), pos_iter);
return GetRealKernelNode(cnode->input(LongToSize(pos + 1)), -1, call_node);
return GetRealKernelNode(cnode->input(LongToSize(pos + 1)), -1, call_node, ignore_get_item);
}
return output;
return std::make_pair(output, get_item_index);
}
return node;
return std::make_pair(node, get_item_index);
}
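// Whether cur_fg is referenced by a node inside fg, i.e. the two graphs call each other as in a while loop.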
static bool IsWhileGraph(const FuncGraphPtr &cur_fg, const FuncGraphPtr &fg) {
auto cur_fg_map = cur_fg->func_graph_cnodes_index();
for (auto &cur_fg_use : cur_fg_map) {
auto temp_node = cur_fg_use.first->first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(temp_node);
if (temp_node->func_graph() == fg) {
return true;
}
}
return false;
}
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
@ -280,8 +297,13 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
return tuple_index_value->cast<Int64ImmPtr>()->value();
}
static bool IsNoNeedRedistribution(const CNodePtr &use_cnode, int use_index) {
return (IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && use_index != 1) || use_cnode->input(0)->isa<CNode>() ||
IsPrimitiveCNode(use_cnode, prim::kPrimUpdateState) || IsPrimitiveCNode(use_cnode, prim::kPrimSwitch);
}
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
const NodeUsersMap &node_users_map, int64_t get_item_index,
const NodeUsersMap &node_users_map, int64_t get_item_index, int64_t make_tuple_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
MS_EXCEPTION_IF_NULL(node);
if (node_users_map.count(node) == 0) {
@ -292,30 +314,60 @@ void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &m
auto use_cnode = node_pair.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(use_cnode);
if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
auto cur_fg = use_cnode->func_graph();
auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
MS_EXCEPTION_IF_NULL(fg);
if (IsWhileGraph(cur_fg, fg)) {
continue;
}
auto fg_parameters = fg->parameters();
auto param = fg_parameters[IntToSize(node_pair.second - 1)];
MS_EXCEPTION_IF_NULL(param);
RedistributionNextNode(param, manager, node_users_map, get_item_index, next_nodes);
RedistributionNextNode(param, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
continue;
}
if (IsPrimitiveCNode(use_cnode, prim::kPrimMakeTuple)) {
make_tuple_index = node_pair.second - 1;
RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
continue;
}
if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem)) {
get_item_index = LongToInt(GetTupleGetItemIndex(use_cnode));
auto temp = LongToInt(GetTupleGetItemIndex(use_cnode));
if (temp != make_tuple_index && make_tuple_index != -1) {
continue;
}
RedistributionNextNode(use_cnode, manager, node_users_map, temp, -1, next_nodes);
continue;
}
if (IsPrimitiveCNode(use_cnode, prim::kPrimReturn)) {
auto fg = use_cnode->func_graph();
auto fg_map = fg->func_graph_cnodes_index();
for (auto &fg_use : fg_map) {
auto fg_node = fg_use.first->first->cast<CNodePtr>();
constexpr int SWITCH_LAST_INPUT_INDEX = 3;
if (IsWhileGraph(fg, fg_node->func_graph()) && fg_use.first->second == SWITCH_LAST_INPUT_INDEX) {
RedistributionNextNode(fg_node, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
}
}
}
// depend, auto monad and control flow op don't need to jump over
if ((IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && node_pair.second != 1) ||
IsPrimitiveCNode(use_cnode, prim::kPrimUpdateState) || IsPrimitiveCNode(use_cnode, prim::kPrimSwitch)) {
if (IsNoNeedRedistribution(use_cnode, node_pair.second)) {
continue;
}
if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
if (make_tuple_index != -1) {
auto real_node = GetRealKernelNode(use_cnode->input(1), -1, nullptr);
if (IsPrimitiveCNode(real_node.first, prim::kPrimMakeTuple)) {
next_nodes->push_back(std::make_pair(std::make_pair(real_node.first, make_tuple_index + 1), get_item_index));
make_tuple_index = -1;
continue;
}
}
next_nodes->push_back(std::make_pair(node_pair, get_item_index));
} else if (use_cnode->input(0)->isa<CNode>()) {
continue;
} else {
// search recursively
RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes);
}
// search recursively
RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
}
}
@ -323,7 +375,7 @@ void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &man
std::vector<AnfNodePtr> *pre_nodes) {
if (IsValueNode<FuncGraph>(cnode->input(0))) {
auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr);
auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
if (!pre_node) {
return;
}

View File

@ -66,11 +66,12 @@ TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info)
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
bool IsControlFlowNode(const AnfNodePtr &node);
int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr);
std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index,
CNodePtr *call_node = nullptr, bool ignore_get_item = true);
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
std::vector<AnfNodePtr> *pre_nodes);
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
const NodeUsersMap &node_users_map, int64_t get_item_index,
const NodeUsersMap &node_users_map, int64_t get_item_index, int64_t make_tuple_index,
std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);
// for specific scenarios

View File

@ -0,0 +1,121 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.train import Model
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype


class DatasetLenet():
    def __init__(self, data, label, length=3):
        self.data = data
        self.label = label
        self.index = 1
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    @staticmethod
    def get_dataset_size():
        return 32

    @staticmethod
    def get_repeat_count():
        return 1

    @staticmethod
    def get_batch_size():
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self

    def reset(self):
        self.index = 0


class MatMulCell(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU().shard(((2, 1),))
        self.weight = Parameter(initializer("ones", [64, 64]), name="param1")

    def construct(self, x):
        out = self.matmul(x, self.weight)
        out = self.relu(out)
        return out


class ConcatCell(nn.Cell):
    def __init__(self):
        super().__init__()
        self.concat = P.Concat().shard(((1, 8), (1, 8)))
        self.relu = P.ReLU()

    def construct(self, x, y):
        out = self.concat((y, x))
        out = self.relu(out)
        return out


class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul().shard(((2, 4), (4, 1)))
        self.weight = Parameter(initializer("ones", [64, 64]), name="param")
        self.index = Parameter(Tensor(0, mstype.int32), requires_grad=False)
        self.cell1 = MatMulCell()
        self.cell2 = ConcatCell()
        self.relu = P.ReLU().shard(((8, 1),))

    def construct(self, x, y):
        out = self.matmul(x, self.weight)
        # graph-mode while loop: the body repeatedly applies the sharded MatMulCell,
        # which is the control-flow case this commit teaches auto-parallel to handle
        while self.index < 3:
            out = self.cell1(out)
            self.index += 1
        out = self.cell2(out, x)
        out = self.relu(out)
        return out


def test_parallel_while():
    """
    Feature: test parallel while.
    Description: while + concat.
    Expectation: Successful graph compilation.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = Net()
    data = Tensor(np.ones([128, 64]), dtype=ms.float32)
    label = Tensor(np.ones([8, 8]), dtype=ms.float32)
    dataset = DatasetLenet(data, label, 3)
    opt = nn.Lamb(net.trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=opt)
    model.train(2, dataset, dataset_sink_mode=False)
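
To exercise the new test by itself, something like the following works as a minimal sketch (not part of this commit; it assumes the semi_auto_parallel compilation under device_num=8 is simulated, so no multi-device environment is needed):

# hypothetical standalone entry point for running the new while-loop test
if __name__ == "__main__":
    test_parallel_while()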