!47804 [AutoParallel]Parallel support while
Merge pull request !47804 from lichen/parallel_support_protein_predict
commit 057da5a11d
@@ -174,9 +174,13 @@ ParameterUsersInfo FindParameterUsers(const AnfNodePtr &node, bool (*IsCareNode)
    // the node is a ref key node
    return FindRefKeyNodeUsers(cnode_with_refkeys, IsCareNode);
  } else if (node->isa<Parameter>()) {
+   auto param_ptr = node->cast<ParameterPtr>();
+   MS_EXCEPTION_IF_NULL(param_ptr);
    // the node is a parameter node
+   if (param_ptr->has_default()) {
      return FindParameterNodeUsers(node);
+   }
  }

  return parameter_users_info;
}
@@ -745,7 +749,7 @@ static std::pair<AnfNodePtr, bool> FindParameterByFuncGraph(const AnfNodePtr &no
  MS_EXCEPTION_IF_NULL(fg);
  auto fg_parameters = fg->parameters();

- auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr);
+ auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
  auto pre_cnode = pre_node->cast<CNodePtr>();
  for (size_t index = 1; index < pre_cnode->inputs().size(); ++index) {
    auto res = FindParameter(pre_cnode->input(index), pre_cnode->func_graph());
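Note on the `.first` pattern above and in the hunks below: `GetRealKernelNode` now returns a `std::pair<AnfNodePtr, int64_t>` (the signature change appears in a later hunk) instead of a bare node, so call sites that only need the resolved node unwrap the pair. A minimal standalone illustration of the new calling convention, using a hypothetical stand-in function rather than the MindSpore API:

#include <cstdint>
#include <utility>

// Stand-in for the new GetRealKernelNode contract (illustration only):
// .first is the resolved node, .second a still-pending tuple_getitem index.
std::pair<const char *, int64_t> ResolveLike() { return {"real_kernel_node", -1}; }

int main() {
  auto node = ResolveLike().first;  // mirrors GetRealKernelNode(...).first
  return node != nullptr ? 0 : 1;
}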
@@ -590,7 +590,7 @@ static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, con

AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
- auto real_node = GetRealKernelNode(node, -1);
+ auto real_node = GetRealKernelNode(node, -1).first;
  if (!real_node->isa<CNode>()) {
    return real_node;
  }
@@ -795,7 +795,7 @@ bool PipelineTransformer::IsParameterGraph(const AnfNodePtr &node) const {
  // ParameterGraph: graph which return a parameter
  MS_EXCEPTION_IF_NULL(node);
  CNodePtr call_node = nullptr;
- auto real_kernel = GetRealKernelNode(node, -1, &call_node);
+ auto real_kernel = GetRealKernelNode(node, -1, &call_node).first;
  if (call_node != nullptr && real_kernel->isa<Parameter>()) {
    return true;
  }
@@ -806,7 +806,7 @@ AnfNodePtr PipelineTransformer::HandleParameterGraph(const AnfNodePtr &node, con
                                                     int64_t user_stage, const ValuePtr &micro, size_t pos,
                                                     const std::vector<AnfNodePtr> &ops) {
  CNodePtr call_node = nullptr;
- auto argument = GetRealKernelNode(node, -1, &call_node);
+ auto argument = GetRealKernelNode(node, -1, &call_node).first;

  auto use_cnode = use_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(use_cnode);
@@ -482,6 +482,13 @@ static void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution
      IsPrimitiveCNode(cnode, prim::kPrimMakeList)) {
    return;
  }
+ // Find Redistribution next_nodes
+ std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
+ RedistributionNextNode(cnode, manager, node_users_map, -1, -1, &next_nodes);
+ if (next_nodes.empty()) {
+   return;
+ }
+
  // Find Redistribution pre_nodes
  std::vector<AnfNodePtr> pre_nodes;
  RedistributionPreNode(cnode, manager, &pre_nodes);
@@ -489,10 +496,6 @@ static void StepRedistribution(const CNodePtr &cnode, const TensorRedistribution
    MS_LOG(EXCEPTION) << " Don't support Redistribution has multiple pre_node.";
  }

- // Find Redistribution next_nodes
- std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> next_nodes;
- RedistributionNextNode(cnode, manager, node_users_map, -1, &next_nodes);
-
  // Insert Redistribution nodes between pre_nodes and next_nodes
  for (auto &pre_node : pre_nodes) {
    for (auto &next_node : next_nodes) {
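The two hunks above move the next_nodes lookup ahead of the pre_nodes lookup: if nothing downstream consumes the node's layout, StepRedistribution now returns before the multiple-pre_node check can raise. A compilable toy skeleton of the reordered control flow (toy types, not the MindSpore signatures):

#include <vector>

// Toy skeleton of the reordered StepRedistribution (illustration only).
struct ToyRedistribution {
  std::vector<int> FindNextNodes() { return {}; }  // stand-in for RedistributionNextNode
  std::vector<int> FindPreNodes() { return {0}; }  // stand-in for RedistributionPreNode

  void Step() {
    auto next_nodes = FindNextNodes();
    if (next_nodes.empty()) {
      return;  // early exit: no consumer needs a redistributed layout
    }
    auto pre_nodes = FindPreNodes();  // producer check runs only when needed
    (void)pre_nodes;  // placeholder for the insertion loop over pre/next pairs
  }
};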
@@ -881,65 +884,46 @@ static bool FindPreNodes(const AnfNodePtr &node, std::vector<std::string> *uniqu
  return find;
}

static void FindLastNodesUniqueId(const FuncGraphPtr &root, std::vector<std::string> *unique_ids,
                                  std::vector<size_t> *indexes) {
  MS_EXCEPTION_IF_NULL(unique_ids);
  CNodePtr cnode = root->get_return();
  if (!FindPreNodes(cnode, unique_ids, indexes, 0)) {
    MS_LOG(WARNING) << "cannot find the last parallel care node in eval graph";
  }
}

void InsertVirtualOutput(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes) {
  std::vector<std::string> last_forward_node_ids;
  std::vector<size_t> last_indexs;
  auto real_graph = PynativeParallelGraph(root, all_nodes);
  FindLastNodesUniqueId(real_graph, &last_forward_node_ids, &last_indexs);
  MS_LOG(INFO) << "there are " << last_forward_node_ids.size() << " output nodes in eval/predict";
  for (auto &node : all_nodes) {
    // here insert virtualoutput node
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    auto last_node_iter = std::find(last_forward_node_ids.begin(), last_forward_node_ids.end(), cnode->UniqueId());
    if (last_node_iter == last_forward_node_ids.end()) {
      continue;
    }
    for (size_t last_node_index = 0; last_node_index < last_forward_node_ids.size(); ++last_node_index) {
      if (last_forward_node_ids[last_node_index] != cnode->UniqueId()) {
        continue;
      }
      MS_LOG(INFO) << "find last node: " << cnode->fullname_with_scope() << ", the parallel care node is: "
                   << cnode->input(last_indexs[last_node_index])->fullname_with_scope();
      if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
        FuncGraphManagerPtr manager = cnode->func_graph()->manager();
        MS_EXCEPTION_IF_NULL(manager);
        auto node_pair = manager->node_users()[cnode].front();
        if (!node_pair.first->isa<CNode>()) {
          MS_LOG(EXCEPTION) << "the output of tuple_get_item is not a cnode";
        }
        cnode = node_pair.first->cast<CNodePtr>();
        last_indexs[last_node_index] = IntToSize(node_pair.second);
      }
      auto pre_node = cnode->input(last_indexs[last_node_index]);
      Shapes shape_outputs = GetNodeShape(pre_node);
      if (shape_outputs[0].empty()) {
        continue;
      }
      FuncGraphPtr func_graph = node->func_graph();
      MS_EXCEPTION_IF_NULL(func_graph);
      auto out_pair = GetRealKernelNode(real_graph->output(), -1, nullptr, false);
      auto out_node = out_pair.first;
      MS_EXCEPTION_IF_NULL(out_node);
      OperatorParams params;
      OperatorAttrs attrs;
      OperatorArgs args = std::make_pair(attrs, params);
      Operator op = std::make_pair(VIRTUAL_OUTPUT, args);
      InsertNode(op, cnode, last_indexs[last_node_index], pre_node, func_graph, VIRTUAL_OUTPUT);
      auto virtual_output_node = cnode->input(last_indexs[last_node_index]);
      AbstractBasePtr virtual_output_abstract = pre_node->abstract()->Clone();
      if (IsPrimitiveCNode(out_node, prim::kPrimMakeTuple)) {
        auto tuple = out_node->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(tuple);
        for (size_t i = 1; i < tuple->inputs().size(); ++i) {
          auto cur_input = tuple->input(i);
          Shapes shape_outputs = GetNodeShape(cur_input);
          if (shape_outputs[0].empty()) {
            continue;
          }
          InsertNode(op, tuple, i, cur_input, tuple->func_graph(), VIRTUAL_OUTPUT);
          auto virtual_output_abstract = cur_input->abstract()->Clone();
          std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
          virtual_output_abstract->set_shape(virtual_output_shape);
          auto virtual_output_node = tuple->input(i);
          virtual_output_node->set_abstract(virtual_output_abstract);
        }
      } else {
        Shapes shape_outputs = GetNodeShape(out_node);
        if (shape_outputs[0].empty()) {
          return;
        }
        auto node_input = CreateInput(op, out_node, VIRTUAL_OUTPUT);
        auto cur_graph = out_node->cast<CNodePtr>()->func_graph();
        MS_EXCEPTION_IF_NULL(cur_graph);
        auto new_node = cur_graph->NewCNode(node_input);
        auto manager = cur_graph->manager();
        (void)manager->Replace(out_node, new_node);
        auto virtual_output_abstract = out_node->abstract()->Clone();
        std::shared_ptr<abstract::BaseShape> virtual_output_shape = std::make_shared<abstract::Shape>(shape_outputs[0]);
        virtual_output_abstract->set_shape(virtual_output_shape);
        new_node->set_abstract(virtual_output_abstract);
      }
    }
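The reworked InsertVirtualOutput above splices a VIRTUAL_OUTPUT marker in front of each real output of the eval/predict graph (one per element when the graph returns a MakeTuple), then copies the output's shape into the marker's abstract so a layout can be attached. A toy model of that splice, with hypothetical GNode types in place of MindSpore's AnfNode machinery:

#include <memory>
#include <string>
#include <vector>

// Toy model (not MindSpore's API) of the virtual-output splice.
struct GNode {
  std::string op;
  std::vector<std::shared_ptr<GNode>> inputs;
  bool scalar = false;
};
using GNodePtr = std::shared_ptr<GNode>;

GNodePtr WrapVirtualOutput(const GNodePtr &out) {
  if (out->op == "MakeTuple") {
    for (auto &elem : out->inputs) {
      if (elem->scalar) continue;  // scalar outputs carry no layout, skip them
      auto marker = std::make_shared<GNode>(GNode{"VirtualOutput", {elem}});
      elem = marker;               // splice the marker in front of the element
    }
    return out;
  }
  if (out->scalar) return out;
  return std::make_shared<GNode>(GNode{"VirtualOutput", {out}});  // single output
}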
@@ -1612,14 +1596,15 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
}

// if reshape's output connect to several primitive, return the first layout found
-static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape) {
+static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool *next_is_reshape,
+                                                    int make_tuple_index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode->func_graph());
  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(manager);
  AnfNodeIndexSet node_set = manager->node_users()[cnode];
  for (auto &node_pair : node_set) {
-   CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+   auto use_apply = node_pair.first->cast<CNodePtr>();
    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
      continue;
    }
@@ -1627,24 +1612,26 @@ static std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode, bool
      *next_is_reshape = true;
      continue;
    }
-   ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
-   MS_EXCEPTION_IF_NULL(prim_anf_node);
-   PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
-   MS_EXCEPTION_IF_NULL(node_prim);
-   MS_LOG(INFO) << "FindNextLayout prim " << node_prim->name();
-   if (node_prim->name() == DEPEND && node_pair.second != 1) {
+   if (IsPrimitiveCNode(use_apply, prim::kPrimDepend) && node_pair.second != 1) {
      continue;
    }
+   if (IsPrimitiveCNode(use_apply, prim::kPrimMakeTuple)) {
+     make_tuple_index = node_pair.second;
+     return FindNextLayout(use_apply, next_is_reshape, make_tuple_index);
+   }
    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
-     MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
+     if (make_tuple_index != -1) {
+       node_pair.second = make_tuple_index;
+     }
+     MS_LOG(INFO) << "FindNextLayout success node " << use_apply->DebugString();
      *next_is_reshape = false;
      auto layout = GetInputLayoutFromCNode(node_pair);
      return std::make_shared<TensorLayout>(layout);
    }
-   MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
+   MS_LOG(DEBUG) << "FindNextLayout failed node " << use_apply->DebugString() << " " << IsParallelCareNode(use_apply)
                  << " " << use_apply->has_user_data<OperatorInfo>();

-   auto layout_ptr = FindNextLayout(use_apply, next_is_reshape);
+   auto layout_ptr = FindNextLayout(use_apply, next_is_reshape, -1);
    if (layout_ptr) {
      return layout_ptr;
    }
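The hunk above threads a new make_tuple_index through FindNextLayout: when a user is a MakeTuple, the search restarts from the tuple while remembering which element it came through, and that remembered index later overrides the consumer's input slot so the layout of the correct tuple element is read. A small self-contained sketch of that walk, with hypothetical toy types (requires C++17 for the vector of incomplete type):

#include <vector>

// Toy walk (not MindSpore's types) showing how FindNextLayout threads
// make_tuple_index through MakeTuple users.
struct ToyUser {
  bool is_make_tuple = false;
  bool is_parallel_care = false;
  int input_index = 0;         // slot of the consumer fed by this node
  std::vector<ToyUser> users;  // consumers of this node
};

int FindLayoutInputIndex(const ToyUser &node, int make_tuple_index) {
  for (const auto &user : node.users) {
    if (user.is_make_tuple) {
      // restart from the tuple, remembering the element position
      return FindLayoutInputIndex(user, user.input_index);
    }
    if (user.is_parallel_care) {
      // mirrors: if (make_tuple_index != -1) node_pair.second = make_tuple_index;
      return make_tuple_index != -1 ? make_tuple_index : user.input_index;
    }
  }
  return -1;  // no layout-bearing consumer found
}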
@@ -1797,7 +1784,7 @@ static void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
      reshape_info_ptr->SetInputLayout(*prev_layout_ptr);
    }
    bool is_next_reshape = false;
-   auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape);
+   auto next_layout_ptr = FindNextLayout(cnode, &is_next_reshape, -1);
    if (next_layout_ptr) {
      auto reshape_info_ptr = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info_ptr->SetOutputLayout(*next_layout_ptr);
@@ -1827,10 +1814,7 @@ static CNodePtr HandleDependLoss(const CNodePtr &cnode, size_t curr_depth) {
  return cnode;
}

-static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_depth) {
- if (max_depth > MAX_RECURSIVE_DEPTH) {
-   MS_LOG(EXCEPTION) << "Recursive call is larger than 100000.";
- }
+static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
  LossNodeInfo loss_node_info;
  MS_EXCEPTION_IF_NULL(func_graph);
  CNodePtr return_node = func_graph->get_return();
@@ -1838,18 +1822,11 @@ static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_dep
  if (return_node->size() < 2) {
    MS_LOG(EXCEPTION) << "Failure: " << return_node->DebugString() << " size is smaller than 2";
  }
- AnfNodePtr pre_node = return_node->input(1);
+ auto pre_node_pair = GetRealKernelNode(return_node->input(1), -1, nullptr);
+ auto pre_node = pre_node_pair.first;
  MS_EXCEPTION_IF_NULL(pre_node);
  auto pre_cnode = pre_node->cast<CNodePtr>();
  pre_cnode = HandleDependLoss(pre_cnode, 0);
- if (pre_cnode->input(0)->isa<CNode>()) {
-   auto switch_cnode = pre_cnode->input(0)->cast<CNodePtr>();
-   if (IsPrimitiveCNode(switch_cnode, prim::kPrimSwitch)) {
-     MS_EXCEPTION_IF_NULL(switch_cnode);
-     auto switch_graph = GetValueNode<FuncGraphPtr>(switch_cnode->input(2));
-     return FindLossCNode(switch_graph, max_depth + 1);
-   }
- }

  if (pre_cnode == nullptr || !IsValueNode<Primitive>(pre_cnode->input(0))) {
    return loss_node_info;
  }
@@ -1865,21 +1842,11 @@ static LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph, size_t max_dep
    return loss_node_info;
  }

- // size of common cnode is larger than 1
- if (pre_cnode->size() < 2) {
-   MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2";
- }
-
- // return -> tuple_getitem -> loss
- if (current_prim->name() == prim::kTupleGetItem) {
-   auto tuple_index = GetTupleGetItemIndex(pre_cnode);
-   AnfNodePtr pre_pre_node = pre_cnode->input(1);
-   MS_EXCEPTION_IF_NULL(pre_pre_node);
-
-   auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
+ if (pre_node_pair.second != -1) {
    loss_node_info.has_tuple_getitem = true;
-   loss_node_info.dout_index = tuple_index;
-   loss_node_info.loss_node = pre_pre_cnode;
+   loss_node_info.dout_index = pre_node_pair.second;
+   loss_node_info.loss_node = pre_cnode;
    return loss_node_info;
  }
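With GetRealKernelNode now collapsing the return -> tuple_getitem -> loss chain itself, FindLossCNode no longer walks tuple_getitem manually: a returned pair whose .second is not -1 already names the output slot feeding the gradient (dout_index). A sketch of that bookkeeping, with placeholder types standing in for the real LossNodeInfo:

#include <cstdint>
#include <utility>

// Illustration only: deriving dout_index straight from the resolver's pair.
struct LossInfo { bool has_tuple_getitem = false; int64_t dout_index = -1; };

LossInfo MakeLossInfo(const std::pair<const char *, int64_t> &real) {
  LossInfo info;
  if (real.second != -1) {  // mirrors: if (pre_node_pair.second != -1)
    info.has_tuple_getitem = true;
    info.dout_index = real.second;
  }
  return info;
}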
@@ -2127,7 +2094,7 @@ static std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const Fun
    MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
  }
  auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
- auto loss_node_info = FindLossCNode(func_graph, 0);
+ auto loss_node_info = FindLossCNode(func_graph);
  if (loss_node_info.loss_node == nullptr) {
    MS_LOG(WARNING) << "Can not find the loss cnode";
    continue;
@@ -2315,7 +2282,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
static std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> root_forward_nodes;
- auto loss_cnode = FindLossCNode(graph, 0).loss_node;
+ auto loss_cnode = FindLossCNode(graph).loss_node;
  if (loss_cnode == nullptr) {
    return root_forward_nodes;
  }
@@ -112,22 +112,27 @@ TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info)
  return tensor_info;
}

-AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node) {
+std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node,
+                                                 bool ignore_get_item) {
  if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
-     IsPrimitiveCNode(node, prim::kPrimCast)) {
-   return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node);
+     IsPrimitiveCNode(node, prim::kPrimCast) || IsPrimitiveCNode(node, prim::kPrimVirtualDiv)) {
+   return GetRealKernelNode(node->cast<CNodePtr>()->input(1), get_item_index, call_node, ignore_get_item);
  }
- if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
+ if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem) && ignore_get_item) {
    auto cnode = node->cast<CNodePtr>();
    auto cur_get_item_index = LongToInt(GetTupleGetItemIndex(cnode));
    auto tuple_getitem_input = cnode->input(1);
-   auto pass_through_node = GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node);
-   return GetRealKernelNode(pass_through_node, get_item_index, call_node);
+   return GetRealKernelNode(tuple_getitem_input, cur_get_item_index, call_node, ignore_get_item);
  }
  if (get_item_index != -1 && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
    auto make_tuple_cnode = node->cast<CNodePtr>();
    auto make_tuple_input = make_tuple_cnode->input(LongToSize(get_item_index + 1));
-   return GetRealKernelNode(make_tuple_input, -1, call_node);
+   return GetRealKernelNode(make_tuple_input, -1, call_node, ignore_get_item);
  }
+ if (IsControlFlowNode(node)) {
+   auto switch_cnode = node->cast<CNodePtr>()->input(0)->cast<CNodePtr>();
+   auto fg = GetValueNode<FuncGraphPtr>(switch_cnode->input(3));
+   return GetRealKernelNode(fg->output(), get_item_index, call_node, ignore_get_item);
+ }
  if (node->isa<CNode>() && IsValueNode<FuncGraph>(node->cast<CNodePtr>()->input(0))) {
    if (call_node != nullptr && *call_node == nullptr) {
@@ -135,21 +140,33 @@ AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNo
    }
    auto cnode = node->cast<CNodePtr>();
    auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
-   auto output = GetRealKernelNode(graph->output(), get_item_index, call_node);
+   auto output = GetRealKernelNode(graph->output(), get_item_index, call_node, ignore_get_item).first;
    MS_EXCEPTION_IF_NULL(output);
    if (output->isa<Parameter>()) {
      auto parameters = graph->parameters();
      auto pos_iter = std::find(parameters.begin(), parameters.end(), output);
      // If can't find in parameters, the parameter is a fv.
      if (pos_iter == parameters.end()) {
-       return output;
+       return std::make_pair(output, get_item_index);
      }
      auto pos = std::distance(parameters.begin(), pos_iter);
-     return GetRealKernelNode(cnode->input(LongToSize(pos + 1)), -1, call_node);
+     return GetRealKernelNode(cnode->input(LongToSize(pos + 1)), -1, call_node, ignore_get_item);
    }
-   return output;
+   return std::make_pair(output, get_item_index);
  }
- return node;
+ return std::make_pair(node, get_item_index);
}

+static bool IsWhileGraph(const FuncGraphPtr &cur_fg, const FuncGraphPtr &fg) {
+ auto cur_fg_map = cur_fg->func_graph_cnodes_index();
+ for (auto &cur_fg_use : cur_fg_map) {
+   auto temp_node = cur_fg_use.first->first->cast<CNodePtr>();
+   MS_EXCEPTION_IF_NULL(temp_node);
+   if (temp_node->func_graph() == fg) {
+     return true;
+   }
+ }
+ return false;
+}
+
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) {
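Taken together, the two hunks above define the commit's central traversal contract: GetRealKernelNode walks through pass-through ops (Depend, Load, Cast, now also VirtualDiv), folds tuple_getitem/MakeTuple pairs, follows the switch of a control-flow call into the body graph, and returns both the resolved node and any tuple index that could not be folded. A self-contained sketch of that pair-returning recursion over a toy graph (toy types only; the real function also handles call nodes, free variables, and the call_node out-parameter):

#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy graph standing in for MindSpore's AnfNode/CNode types.
enum class Op { kDepend, kTupleGetItem, kMakeTuple, kKernel };

struct ToyNode {
  Op op;
  std::vector<std::shared_ptr<ToyNode>> inputs;  // inputs[0] is the real input
  int64_t item_index = -1;                       // index carried by tuple_getitem
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Returns {resolved node, surviving get_item_index}, mirroring the new
// std::pair<AnfNodePtr, int64_t> signature. With ignore_get_item == false,
// a tuple_getitem is kept and its index travels in .second instead.
std::pair<ToyNodePtr, int64_t> ResolveReal(const ToyNodePtr &node, int64_t get_item_index,
                                           bool ignore_get_item = true) {
  if (node->op == Op::kDepend) {  // pass-through ops are skipped
    return ResolveReal(node->inputs[0], get_item_index, ignore_get_item);
  }
  if (node->op == Op::kTupleGetItem && ignore_get_item) {
    return ResolveReal(node->inputs[0], node->item_index, ignore_get_item);
  }
  if (get_item_index != -1 && node->op == Op::kMakeTuple) {
    // the real code uses input(get_item_index + 1) because input 0 is the primitive
    return ResolveReal(node->inputs[static_cast<size_t>(get_item_index)], -1, ignore_get_item);
  }
  return {node, get_item_index};
}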
@@ -280,8 +297,13 @@ int64_t GetTupleGetItemIndex(const CNodePtr &cnode) {
  return tuple_index_value->cast<Int64ImmPtr>()->value();
}

+static bool IsNoNeedRedistribution(const CNodePtr &use_cnode, int use_index) {
+ return (IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && use_index != 1) || use_cnode->input(0)->isa<CNode>() ||
+        IsPrimitiveCNode(use_cnode, prim::kPrimUpdateState) || IsPrimitiveCNode(use_cnode, prim::kPrimSwitch);
+}
+
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
-                           const NodeUsersMap &node_users_map, int64_t get_item_index,
+                           const NodeUsersMap &node_users_map, int64_t get_item_index, int64_t make_tuple_index,
                            std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  if (node_users_map.count(node) == 0) {
@@ -292,38 +314,68 @@ void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &m
    auto use_cnode = node_pair.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(use_cnode);
    if (IsValueNode<FuncGraph>(use_cnode->input(0))) {
+     auto cur_fg = use_cnode->func_graph();
      auto fg = GetValueNode<FuncGraphPtr>(use_cnode->input(0));
      MS_EXCEPTION_IF_NULL(fg);
+     if (IsWhileGraph(cur_fg, fg)) {
+       continue;
+     }
      auto fg_parameters = fg->parameters();
      auto param = fg_parameters[IntToSize(node_pair.second - 1)];
      MS_EXCEPTION_IF_NULL(param);
-     RedistributionNextNode(param, manager, node_users_map, get_item_index, next_nodes);
+     RedistributionNextNode(param, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
      continue;
    }
+   if (IsPrimitiveCNode(use_cnode, prim::kPrimMakeTuple)) {
+     make_tuple_index = node_pair.second - 1;
+     RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
+     continue;
+   }
    if (IsPrimitiveCNode(use_cnode, prim::kPrimTupleGetItem)) {
-     get_item_index = LongToInt(GetTupleGetItemIndex(use_cnode));
+     auto temp = LongToInt(GetTupleGetItemIndex(use_cnode));
+     if (temp != make_tuple_index && make_tuple_index != -1) {
+       continue;
+     }
+     RedistributionNextNode(use_cnode, manager, node_users_map, temp, -1, next_nodes);
+     continue;
    }
+   if (IsPrimitiveCNode(use_cnode, prim::kPrimReturn)) {
+     auto fg = use_cnode->func_graph();
+     auto fg_map = fg->func_graph_cnodes_index();
+     for (auto &fg_use : fg_map) {
+       auto fg_node = fg_use.first->first->cast<CNodePtr>();
+       constexpr int SWITCH_LAST_INPUT_INDEX = 3;
+       if (IsWhileGraph(fg, fg_node->func_graph()) && fg_use.first->second == SWITCH_LAST_INPUT_INDEX) {
+         RedistributionNextNode(fg_node, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
+       }
+     }
+   }
    // depend, auto monad and control flow op don't need to jump over
-   if ((IsPrimitiveCNode(use_cnode, prim::kPrimDepend) && node_pair.second != 1) ||
-       IsPrimitiveCNode(use_cnode, prim::kPrimUpdateState) || IsPrimitiveCNode(use_cnode, prim::kPrimSwitch)) {
+   if (IsNoNeedRedistribution(use_cnode, node_pair.second)) {
      continue;
    }
    if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
      next_nodes->push_back(std::make_pair(node_pair, get_item_index));
    } else if (use_cnode->input(0)->isa<CNode>()) {
      if (make_tuple_index != -1) {
        auto real_node = GetRealKernelNode(use_cnode->input(1), -1, nullptr);
        if (IsPrimitiveCNode(real_node.first, prim::kPrimMakeTuple)) {
          next_nodes->push_back(std::make_pair(std::make_pair(real_node.first, make_tuple_index + 1), get_item_index));
          make_tuple_index = -1;
          continue;
        } else {
          // search recursively
          RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, next_nodes);
        }
      }
      next_nodes->push_back(std::make_pair(node_pair, get_item_index));
      continue;
    }
    // search recursively
    RedistributionNextNode(use_cnode, manager, node_users_map, get_item_index, make_tuple_index, next_nodes);
  }
}

void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                           std::vector<AnfNodePtr> *pre_nodes) {
  if (IsValueNode<FuncGraph>(cnode->input(0))) {
    auto fg = GetValueNode<FuncGraphPtr>(cnode->input(0));
-   auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr);
+   auto pre_node = GetRealKernelNode(fg->output(), -1, nullptr).first;
    if (!pre_node) {
      return;
    }
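Two of the additions above do the actual while-loop plumbing: IsWhileGraph detects that a callee graph is invoked from inside one of its own user graphs (the while pattern, so the traversal skips re-entering the loop), and the kPrimReturn branch re-enters the search at the Switch call site so redistribution consumers are found across the loop boundary. The MakeTuple/TupleGetItem handshake is the other half: a walk that entered make_tuple at element i only continues through a tuple_getitem with the same index. A one-function illustration of that filter (not MindSpore code):

#include <cstdint>

// Mirrors: if (temp != make_tuple_index && make_tuple_index != -1) { continue; }
bool ContinuesThroughGetItem(int64_t make_tuple_index, int64_t get_item_index) {
  return make_tuple_index == -1 || get_item_index == make_tuple_index;
}

int main() {
  // entered make_tuple at element 1: only tuple_getitem(1) keeps walking
  return ContinuesThroughGetItem(1, 1) && !ContinuesThroughGetItem(1, 0) ? 0 : 1;
}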
@@ -66,11 +66,12 @@ TensorInfo GetInputsTensorInfo(const std::pair<AnfNodePtr, int64_t> &param_info)
AnfNodePtr CheckMakeTupleSplit(const AnfNodePtr &node, const FuncGraphManagerPtr &manager);
bool IsControlFlowNode(const AnfNodePtr &node);
int64_t GetTupleGetItemIndex(const CNodePtr &cnode);
-AnfNodePtr GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index, CNodePtr *call_node = nullptr);
+std::pair<AnfNodePtr, int64_t> GetRealKernelNode(const AnfNodePtr &node, int64_t get_item_index,
+                                                 CNodePtr *call_node = nullptr, bool ignore_get_item = true);
void RedistributionPreNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager,
                           std::vector<AnfNodePtr> *pre_nodes);
void RedistributionNextNode(const AnfNodePtr &node, const FuncGraphManagerPtr &manager,
-                           const NodeUsersMap &node_users_map, int64_t get_item_index,
+                           const NodeUsersMap &node_users_map, int64_t get_item_index, int64_t make_tuple_index,
                            std::vector<std::pair<std::pair<AnfNodePtr, int>, int>> *next_nodes);

// for specific scenarios
@@ -0,0 +1,121 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.train import Model
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
import mindspore.common.dtype as mstype


class DatasetLenet():
    def __init__(self, data, label, length=3):
        self.data = data
        self.label = label
        self.index = 1
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    @staticmethod
    def get_dataset_size():
        return 32

    @staticmethod
    def get_repeat_count():
        return 1

    @staticmethod
    def get_batch_size():
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self

    def reset(self):
        self.index = 0


class MatMulCell(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU().shard(((2, 1),))
        self.weight = Parameter(initializer("ones", [64, 64]), name="param1")

    def construct(self, x):
        out = self.matmul(x, self.weight)
        out = self.relu(out)
        return out


class ConcatCell(nn.Cell):
    def __init__(self):
        super().__init__()
        self.concat = P.Concat().shard(((1, 8), (1, 8)))
        self.relu = P.ReLU()

    def construct(self, x, y):
        out = self.concat((y, x))
        out = self.relu(out)
        return out


class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul().shard(((2, 4), (4, 1)))
        self.weight = Parameter(initializer("ones", [64, 64]), name="param")
        self.index = Parameter(Tensor(0, mstype.int32), requires_grad=False)
        self.cell1 = MatMulCell()
        self.cell2 = ConcatCell()
        self.relu = P.ReLU().shard(((8, 1),))

    def construct(self, x, y):
        out = self.matmul(x, self.weight)
        while self.index < 3:
            out = self.cell1(out)
            self.index += 1
        out = self.cell2(out, x)
        out = self.relu(out)
        return out


def test_parallel_while():
    """
    Feature: test parallel while.
    Description: while + concat.
    Expectation: Successful graph compilation.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    net = Net()
    data = Tensor(np.ones([128, 64]), dtype=ms.float32)
    label = Tensor(np.ones([8, 8]), dtype=ms.float32)
    dataset = DatasetLenet(data, label, 3)
    opt = nn.Lamb(net.trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=opt)
    model.train(2, dataset, dataset_sink_mode=False)