!14612 remove ControlDepend

From: @huangbingjian
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-04-07 18:36:29 +08:00 committed by Gitee
commit e2260a2f09
37 changed files with 118 additions and 752 deletions

View File

@ -108,7 +108,7 @@ bool InputCheck(const AnfNodePtr &node) {
MS_LOG(INFO) << "Data->TransData->split, can not optimizer."; MS_LOG(INFO) << "Data->TransData->split, can not optimizer.";
return false; return false;
} }
if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) { if (in_node_name == prim::kPrimDepend->name()) {
return false; return false;
} }
if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) || if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) ||
@ -131,7 +131,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
return false; return false;
} }
for (const auto &item : outputs) { for (const auto &item : outputs) {
if (IsPrimitiveCNode(item, prim::kPrimControlDepend) || IsPrimitiveCNode(item, prim::kPrimDepend)) { if (IsPrimitiveCNode(item, prim::kPrimDepend)) {
MS_LOG(INFO) << "Split has control edge, can not optimizer."; MS_LOG(INFO) << "Split has control edge, can not optimizer.";
return false; return false;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -168,7 +168,7 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c
(void)manager->Replace(output, bn_training_update_v2_outputs[output_index]); (void)manager->Replace(output, bn_training_update_v2_outputs[output_index]);
output_index++; output_index++;
} }
// Return the new node for control depends. // Return the new node.
return bn_training_update_v2; return bn_training_update_v2;
} }
} // namespace opt } // namespace opt

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -201,20 +201,6 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
} }
} }
} }
if (dropout_do_mask1 != nullptr) {
// Dropout is used by ControlDepend in some situation, need to replace ControlDepend.
auto &users = manager->node_users();
iter = users.find(dropout_node);
if (iter != users.end()) {
for (auto &node_index : iter->second) {
auto used_node = node_index.first;
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) {
(void)manager->Replace(used_node, dropout_do_mask1);
break;
}
}
}
}
// CreateDropoutDoMask-backward // CreateDropoutDoMask-backward
if (equiv->find(grad_input_) == equiv->end()) { if (equiv->find(grad_input_) == equiv->end()) {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -426,9 +426,6 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu
} }
auto output_info_list = iter->second; auto output_info_list = iter->second;
for (const auto &output_info : output_info_list) { for (const auto &output_info : output_info_list) {
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
continue;
}
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
output_info.second == kDependAttachNodeIndex) { output_info.second == kDependAttachNodeIndex) {
continue; continue;
@ -908,16 +905,12 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
// find BatchNorm's output which is a Depend or ControlDepend // find BatchNorm's output which is a Depend
for (const auto &node_index : manager->node_users()[old_node]) { for (const auto &node_index : manager->node_users()[old_node]) {
AnfNodePtr output = node_index.first; AnfNodePtr output = node_index.first;
size_t index = IntToSize(node_index.second); size_t index = IntToSize(node_index.second);
MS_EXCEPTION_IF_NULL(output); MS_EXCEPTION_IF_NULL(output);
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
auto control_depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(control_depend);
control_depend->set_input(index, new_node);
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
auto depend = output->cast<CNodePtr>(); auto depend = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend); MS_EXCEPTION_IF_NULL(depend);
depend->set_input(index, new_node); depend->set_input(index, new_node);

View File

@ -210,7 +210,7 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
// Create a new value node of func graph,not kernel graph // Create a new value node of func graph,not kernel graph
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node); ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
// Transfer depend or control_depend to the new node // Transfer depend to the new node
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node); void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);

View File

@ -327,7 +327,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) { const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) {
// Create depend node to hold new control depend node. // Create depend node to hold execution order.
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node}; AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
auto depend_cnode = main_graph->NewCNode(d_inputs); auto depend_cnode = main_graph->NewCNode(d_inputs);
depend_cnode->set_abstract(clean_node->abstract()); depend_cnode->set_abstract(clean_node->abstract());
@ -501,12 +501,11 @@ bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_
const FuncGraphManagerPtr &mng) { const FuncGraphManagerPtr &mng) {
auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false); auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
// If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce // If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce
// node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong! // node and user node. If reduce is Depend node, the origin node may be wrong!
return std::all_of(reduce_users.cbegin(), reduce_users.cend(), return std::all_of(
[&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool { reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
auto &user = user_info.first; auto &user = user_info.first;
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) || if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) &&
IsPrimitiveCNode(user, prim::kPrimControlDepend)) &&
!(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) { !(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
return false; return false;
} else { } else {

View File

@ -123,9 +123,9 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
bool changed = false; bool changed = false;
auto mng = kernel_graph->manager(); auto mng = kernel_graph->manager();
// depend_prior[depend] = pair(prior, controlDependNode) // depend_prior[depend] = pair(prior, behind)
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior; std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
InitDependPrior(todos, &depend_prior); // InitDependPrior(todos, &depend_prior);
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) { for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>(); auto node = (*iter)->cast<CNodePtr>();

View File

@ -657,76 +657,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
#endif #endif
} }
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto cnode = (*iter)->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
continue;
}
auto prior_node = cnode->input(kControlDependPriorIndex);
auto depend_node = cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(depend_node);
std::vector<AnfNodePtr> prior_nodes = {prior_node};
std::vector<AnfNodePtr> depend_nodes = {depend_node};
int depend_mode = 0;
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
}
auto GetOutputNodes = [cnode](const AnfNodePtr &param) -> std::vector<AnfNodePtr> {
std::vector<AnfNodePtr> out_nodes;
auto user_set = param->func_graph()->manager()->node_users()[param];
for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) {
if (iter->first != cnode) {
out_nodes.push_back(iter->first);
}
}
return out_nodes;
};
if (prior_node->isa<Parameter>() && depend_mode == 1) {
prior_nodes = GetOutputNodes(prior_node);
}
if (depend_node->isa<Parameter>()) {
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
}
std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
prior_visited.clear();
std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}
depend_visited.clear();
for (auto &prior : real_prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) {
continue;
}
for (auto &depend : real_depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) {
continue;
}
depend_prior->insert({depend, std::make_pair(prior, cnode)});
}
}
real_prior_nodes.clear();
real_depend_nodes.clear();
}
}
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) { const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri; std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;

View File

@ -75,8 +75,6 @@ std::vector<PrimitivePtr> GetFusibleOpList();
bool IsBasicFuseOp(const AnfNodePtr &node); bool IsBasicFuseOp(const AnfNodePtr &node);
bool IsFusibleOp(const AnfNodePtr &node); bool IsFusibleOp(const AnfNodePtr &node);
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE); void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior, void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs); const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -55,7 +55,7 @@ CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector<Kern
auto item_idx = GetValue<int64_t>(value_node->value()); auto item_idx = GetValue<int64_t>(value_node->value());
pass_vector->push_back(make_pair(cnode, IntToSize(1))); pass_vector->push_back(make_pair(cnode, IntToSize(1)));
return GetRealPrevCNode(cnode->input(1), LongToSize(item_idx), pass_vector); return GetRealPrevCNode(cnode->input(1), LongToSize(item_idx), pass_vector);
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { } else if (IsPrimitive(input0, prim::kPrimDepend)) {
pass_vector->push_back(make_pair(cnode, IntToSize(1))); pass_vector->push_back(make_pair(cnode, IntToSize(1)));
return GetRealPrevCNode(cnode->input(1), 0, pass_vector); return GetRealPrevCNode(cnode->input(1), 0, pass_vector);
} else if (IsPrimitive(input0, prim::kPrimUpdateState)) { } else if (IsPrimitive(input0, prim::kPrimUpdateState)) {
@ -92,8 +92,7 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode
auto pass_size = pass_vector->size(); auto pass_size = pass_vector->size();
for (size_t idx = 1; idx <= pass_size - 1; ++idx) { for (size_t idx = 1; idx <= pass_size - 1; ++idx) {
auto nd = (*pass_vector)[idx].first; auto nd = (*pass_vector)[idx].first;
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) || if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend)) {
AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
has_depend_node = true; has_depend_node = true;
} }
if (users[nd].size() >= 2) { if (users[nd].size() >= 2) {

View File

@ -248,7 +248,7 @@ class AnfRuntimeAlgorithm {
static void InferShape(const CNodePtr &node); static void InferShape(const CNodePtr &node);
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index); static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
// Find control_depend real input nodes. // Find real input nodes.
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited); std::set<AnfNodePtr> *visited);
}; };

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -534,14 +534,17 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul
return_node->set_input(kFirstDataInputIndex, depend_node); return_node->set_input(kFirstDataInputIndex, depend_node);
} }
void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node, void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> prior_node,
NotNull<AnfNodePtr> second_node) { NotNull<AnfNodePtr> behind_node) {
MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString() MS_LOG(INFO) << "Insert control dependence at the end of graph, the prior node is " << prior_node->DebugString()
<< ", the second node is " << second_node->DebugString(); << ", the behind node is " << behind_node->DebugString();
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())), auto manager = kg->manager();
first_node, second_node}; MS_EXCEPTION_IF_NULL(manager);
auto control_depend = kg->NewCNode(inputs); AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node};
InsertDependToGraph(kg, NOT_NULL(control_depend)); auto depend_cnode = kg->NewCNode(inputs);
if (!manager->Replace(behind_node, depend_cnode)) {
MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed.";
}
} }
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node, void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -422,7 +422,7 @@ void KernelGraph::CheckLoop() {
none_zero_nodes[it.first] = it.second; none_zero_nodes[it.first] = it.second;
} }
} }
// if don't consider control depend and loop exit,a exception will be throw // if don't consider loop exit,a exception will be throw
if (!none_zero_nodes.empty()) { if (!none_zero_nodes.empty()) {
MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes); MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes);
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size(); MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
@ -815,61 +815,10 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return output_nodes; return output_nodes;
} }
// update the depend relations of control depend
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
for (const auto &node : depends) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
}
auto prior_node = cnode->input(kControlDependPriorIndex);
auto depend_node = cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(depend_node);
std::vector<AnfNodePtr> prior_nodes = {prior_node};
std::vector<AnfNodePtr> depend_nodes = {depend_node};
int depend_mode = 0;
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
}
MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
<< "], depend_mode :" << depend_mode << ".";
if (prior_node->isa<Parameter>() && depend_mode == 1) {
prior_nodes = GetOutputNodes(prior_node);
}
if (depend_node->isa<Parameter>()) {
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
}
std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}
UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes);
}
}
void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes,
const std::vector<AnfNodePtr> &real_depend_nodes) { const std::vector<AnfNodePtr> &real_depend_nodes) {
for (auto &first_node : real_prior_nodes) { for (auto &first_node : real_prior_nodes) {
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
continue;
}
for (auto &second_node : real_depend_nodes) { for (auto &second_node : real_depend_nodes) {
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
continue;
}
MS_EXCEPTION_IF_NULL(first_node); MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node); MS_EXCEPTION_IF_NULL(second_node);
MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString(); MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
@ -878,35 +827,6 @@ void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real
} }
} }
bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(que);
MS_EXCEPTION_IF_NULL(visited_nodes);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
return false;
}
// set the control depend visited but don't push it into the que
if (visited_nodes->find(node) != visited_nodes->end()) {
return true;
}
(void)visited_nodes->insert(cnode);
// add a 0 depend num to keep the link relations to prepare for finding zero output nodes
auto prior_node = cnode->input(kControlDependPriorIndex);
auto depend_node = cnode->input(kControlDependBehindIndex);
for (const auto &input : cnode->inputs()) {
AddDependEdge(node, input, 0);
}
PushNoVisitedNode(depend_node, que, visited_nodes);
PushNoVisitedNode(prior_node, que, visited_nodes);
return true;
}
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) { void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
MS_EXCEPTION_IF_NULL(seed_nodes); MS_EXCEPTION_IF_NULL(seed_nodes);
node_output_edges_.clear(); node_output_edges_.clear();

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -286,15 +286,11 @@ class KernelGraph : public FuncGraph {
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true); std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
// update node edge list // update node edge list
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes); void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
// add node depend edge by data edge or control depend // add node depend edge by data edge
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num); void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes, void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes,
const std::vector<AnfNodePtr> &real_depend_nodes); const std::vector<AnfNodePtr> &real_depend_nodes);
// handle control depend
std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node); std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node);
bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
std::unordered_set<AnfNodePtr> *visited_nodes);
void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value); AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value);
AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract); AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract);
AnfNodePtr TransCNodeTuple(const CNodePtr &node); AnfNodePtr TransCNodeTuple(const CNodePtr &node);

View File

@ -223,11 +223,9 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret; VectorRef ret;
for (size_t i = 1; i < cnode->inputs().size(); ++i) { for (size_t i = 1; i < cnode->inputs().size(); ++i) {
if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) {
auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node); auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node);
ret.push_back(out); ret.push_back(out);
} }
}
return ret; return ret;
} }
// if is graph return nothing ,the function should return a null anylist // if is graph return nothing ,the function should return a null anylist
@ -386,22 +384,6 @@ bool ExistSummaryNode(const KernelGraph *graph) {
return false; return false;
} }
bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &node_inputs = cnode->inputs();
for (size_t i = 1; i < node_inputs.size(); ++i) {
if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) {
return false;
}
}
return true;
}
void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
std::map<AnfNodePtr, size_t> *parameter_index) { std::map<AnfNodePtr, size_t> *parameter_index) {
size_t index = 0; size_t index = 0;
@ -692,9 +674,6 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) { AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
if (IgnoreCreateParameterForMakeTuple(node)) {
return nullptr;
}
auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract())); auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
auto parameters = AnfAlgo::GetAllOutput(new_parameter); auto parameters = AnfAlgo::GetAllOutput(new_parameter);
std::vector<AnfNodePtr> pre_graph_out = {node}; std::vector<AnfNodePtr> pre_graph_out = {node};
@ -1872,9 +1851,6 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
auto &users = front_func_graph_manager->node_users()[front_node]; auto &users = front_func_graph_manager->node_users()[front_node];
std::vector<AnfNodePtr> result; std::vector<AnfNodePtr> result;
for (auto &user : users) { for (auto &user : users) {
if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) {
continue;
}
if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) {
auto depend_cnode = user.first->cast<CNodePtr>(); auto depend_cnode = user.first->cast<CNodePtr>();
if (depend_cnode == nullptr) { if (depend_cnode == nullptr) {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -84,7 +84,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
} }
} }
std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend, prim::kPrimLoad}; std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimLoad};
for (auto &item : adapter_convert_ops) { for (auto &item : adapter_convert_ops) {
if (IsPrimitiveCNode(node, item)) { if (IsPrimitiveCNode(node, item)) {
return true; return true;
@ -243,8 +243,7 @@ CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t sw
return merge_op; return merge_op;
} }
// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data)) // merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
// control_depend(output_node, square_op)
AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node, AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node,
int64_t switch_idx) { int64_t switch_idx) {
tensor::TensorPtr const_data = GetConstData(); tensor::TensorPtr const_data = GetConstData();
@ -259,54 +258,21 @@ AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr
SetSquareOp(switch_idx, square_op); SetSquareOp(switch_idx, square_op);
} }
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), square_op, output_node};
auto depend_cnode = graph->NewCNode(inputs);
if (!manager->Replace(square_op, depend_cnode)) {
MS_LOG(EXCEPTION) << square_op->DebugString() << ", replace node failed.";
}
CNodePtr merge_op = GetMergeOp(switch_idx); CNodePtr merge_op = GetMergeOp(switch_idx);
if (merge_op == nullptr) { if (merge_op == nullptr) {
merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op); merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op);
SetMergeOp(switch_idx, merge_op); SetMergeOp(switch_idx, merge_op);
} }
std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op}; return merge_op;
auto control_depend_op = graph->NewCNode(control_depend_nodes);
std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op};
auto depend_op = graph->NewCNode(depend_nodes);
return depend_op;
}
// construct a merge output and add dependency with the netoutput node from control_depend
// we need to reserve the control_depend node, besides the generated merge node and control_depend node
CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst,
int64_t switch_idx) {
auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
std::vector<int64_t> shp = {1};
tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
auto *val = static_cast<int64_t *>(const_data->data_c());
*val = 0;
// for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same
// switch the other use the opposite
auto ctrl_data = NewValueNode(const_data);
auto oppsite_ctrl_data = NewValueNode(const_data);
auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);
auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx);
std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
auto square_op = graph->NewCNode(square_nodes);
std::vector<AnfNodePtr> merge_nodes;
merge_nodes.push_back(NewValueNode(PrimMerge));
std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node};
merge_nodes.push_back(graph->NewCNode(make_tuple_nodes));
auto merge_output = graph->NewCNode(merge_nodes);
std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op};
auto cond_dep_output = graph->NewCNode(control_depend_nodes);
std::vector<AnfNodePtr> depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output,
cond_dep_output};
return graph->NewCNode(depended_make_tuple_nodes);
} }
// generate switch nodes for true graph node inputs // generate switch nodes for true graph node inputs
@ -321,26 +287,12 @@ AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNod
return GenerateSwitchDependNode(graph, cond, data, 0); return GenerateSwitchDependNode(graph, cond, data, 0);
} }
// generate switch nodes for true graph node inputs // to judge if the node used in Depend is a net output node
CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
const AnfNodePtr &con_input, const AnfNodePtr &output) {
// for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1);
}
// generate switch nodes for false graph node inputs
CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
const AnfNodePtr &con_input, const AnfNodePtr &output) {
// for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0);
}
// to judge if the node used in ControlDepend is a net output node
bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) { bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
auto uses = manager->node_users()[node]; auto uses = manager->node_users()[node];
bool is_output_node = true; bool is_output_node = true;
for (auto &item : uses) { for (auto &item : uses) {
if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) { if (IsPrimitiveCNode(item.first, prim::kPrimDepend)) {
continue; continue;
} }
is_output_node = false; is_output_node = false;
@ -353,8 +305,7 @@ bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node)
void GenerateReplNodeForDependMakeTuple( void GenerateReplNodeForDependMakeTuple(
const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func, const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
MS_EXCEPTION_IF_NULL(graph->manager()); MS_EXCEPTION_IF_NULL(graph->manager());
auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs(); auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs();
@ -368,26 +319,6 @@ void GenerateReplNodeForDependMakeTuple(
new_make_tuple_nodes.push_back(depended_tuple_input_node); new_make_tuple_nodes.push_back(depended_tuple_input_node);
continue; continue;
} }
if (IsPrimitiveCNode(depended_tuple_input_node->cast<CNodePtr>(), prim::kPrimControlDepend)) {
// only when the control depend input is not square op (the op to use as merge output)
auto control_inputs = depended_tuple_input_node->cast<CNodePtr>()->inputs();
if (control_inputs.size() != 3) {
MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size();
}
// control inputs: primitive, src, dst
auto dst_node = control_inputs[2];
if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node);
MS_EXCEPTION_IF_NULL(gen_node);
auto tuple_inputs = gen_node->inputs();
// add depended tuple inputs to new_make_tuple directly
for (size_t i = 1; i < tuple_inputs.size(); i++) {
new_make_tuple_nodes.push_back(tuple_inputs[i]);
}
}
replace_make_tuple = true;
continue;
}
if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) { if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) {
auto gen_node = generate_func(graph, cond, depended_tuple_input_node); auto gen_node = generate_func(graph, cond, depended_tuple_input_node);
@ -408,8 +339,7 @@ void GenerateReplNodeForDependMakeTuple(
void GenerateRepDepend( void GenerateRepDepend(
const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond, const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node, const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func, const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
auto inputs = node->inputs(); auto inputs = node->inputs();
if (inputs.size() != 3) { if (inputs.size() != 3) {
MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node]."; MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node].";
@ -422,19 +352,7 @@ void GenerateRepDepend(
new_depened_inputs.push_back(inputs[1]); new_depened_inputs.push_back(inputs[1]);
// depended node should be make_tuple or a single depended node // depended node should be make_tuple or a single depended node
if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) { if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) {
GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func); GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func);
} else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) {
// only when the control depend input is not square op (the op to use as merge output)
auto control_inputs = depended_node->cast<CNodePtr>()->inputs();
// control inputs: primitive, src, dst
if (control_inputs.size() != 3) {
MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size();
}
auto dst_node = control_inputs[2];
if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node);
(*repl_node)[depended_node] = gen_node;
}
} else { } else {
// Check if there is only single user for depend_node. // Check if there is only single user for depend_node.
if (graph->manager()->node_users()[depended_node].size() == 1) { if (graph->manager()->node_users()[depended_node].size() == 1) {
@ -448,11 +366,9 @@ void GenerateRepDepend(
// generate depend node for netoutput node, to resolve the stream synchronize problem of ge // generate depend node for netoutput node, to resolve the stream synchronize problem of ge
// traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const) // traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const)
// and add control_depend of graph output node and square node.
FuncGraphPtr TransformGraphDependNode( FuncGraphPtr TransformGraphDependNode(
const FuncGraphPtr &graph, const AnfNodePtr &cond, const FuncGraphPtr &graph, const AnfNodePtr &cond,
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func, const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func) {
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
auto manager = graph->manager(); auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
@ -478,7 +394,7 @@ FuncGraphPtr TransformGraphDependNode(
if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) { if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) {
continue; continue;
} }
GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func); GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func);
} }
} }
ResetSharedOp(); ResetSharedOp();
@ -494,12 +410,12 @@ FuncGraphPtr TransformGraphDependNode(
FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
(void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode); (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode);
return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode); return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode);
} }
FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) { FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
(void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode); (void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode);
return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode); return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode);
} }
// judge if the true and false graph output is compatible(they shall have same tuple size) // judge if the true and false graph output is compatible(they shall have same tuple size)

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -218,11 +218,11 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
if (output_set_iter == node_users.end()) { if (output_set_iter == node_users.end()) {
return false; return false;
} }
for (const auto &node_index_set : output_set_iter->second) {
if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) { if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
[](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) {
return true; return true;
} }
}
return false; return false;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -367,7 +367,6 @@ constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary";
constexpr char DEBUG[] = "Debug"; constexpr char DEBUG[] = "Debug";
constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs"; constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs";
constexpr char INVERTPERMUTATION[] = "InvertPermutation"; constexpr char INVERTPERMUTATION[] = "InvertPermutation";
constexpr char CONTROLDEPEND[] = "ControlDepend";
constexpr char DOT[] = "dot"; constexpr char DOT[] = "dot";
constexpr char IM2COL[] = "im2col"; constexpr char IM2COL[] = "im2col";
constexpr char COL2IM[] = "col2im"; constexpr char COL2IM[] = "col2im";

View File

@ -259,11 +259,9 @@ BaseRef CreateOutputTensors(const AnfNodePtr &output_node, const KernelGraphPtr
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
VectorRef ret; VectorRef ret;
for (size_t i = 1; i < cnode->inputs().size(); ++i) { for (size_t i = 1; i < cnode->inputs().size(); ++i) {
if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) {
auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors); auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors);
ret.push_back(out); ret.push_back(out);
} }
}
return ret; return ret;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -1044,7 +1044,7 @@ bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) {
OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) { OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) {
auto op = Convert(GetRealOpNode(node)); auto op = Convert(GetRealOpNode(node));
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "Convert control depend node to operator failed, " << node->ToString(); MS_LOG(ERROR) << "Convert real op node to operator failed, " << node->ToString();
error_ = FAILED; error_ = FAILED;
return nullptr; return nullptr;
} }
@ -1170,13 +1170,13 @@ void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) {
void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) { void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) {
AutoMonadSetControlInput(node); AutoMonadSetControlInput(node);
if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { if (control_edge_cache_.find(node.get()) == control_edge_cache_.end()) {
return; return;
} }
std::vector<ControlEdge> control_edges = control_depend_cache_[node.get()]; std::vector<ControlEdge> control_edges = control_edge_cache_[node.get()];
if ((control_edges.empty())) { if ((control_edges.empty())) {
MS_LOG(ERROR) << "Get control depend node's src or dest operator failed"; MS_LOG(ERROR) << "Get control edge node's src or dest operator failed";
return; return;
} }
@ -1600,7 +1600,7 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no
for (size_t index = 1; index < node_inputs.size(); index++) { for (size_t index = 1; index < node_inputs.size(); index++) {
auto op = Convert(GetRealOpNode(node_inputs[index])); auto op = Convert(GetRealOpNode(node_inputs[index]));
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "Convert control depend node to operator failed"; MS_LOG(ERROR) << "Convert real op node to operator failed";
error_ = FAILED; error_ = FAILED;
return std::vector<OperatorPtr>({}); return std::vector<OperatorPtr>({});
} }
@ -1611,194 +1611,13 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no
auto op = Convert(GetRealOpNode(node)); auto op = Convert(GetRealOpNode(node));
if (op == nullptr) { if (op == nullptr) {
MS_LOG(ERROR) << "Convert control depend node to operator failed"; MS_LOG(ERROR) << "Convert real op node to operator failed";
error_ = FAILED; error_ = FAILED;
return std::vector<OperatorPtr>({}); return std::vector<OperatorPtr>({});
} }
return std::vector<OperatorPtr>({op}); return std::vector<OperatorPtr>({op});
} }
// get the anf node list for depend
std::vector<AnfNodePtr> DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) {
std::vector<AnfNodePtr> nodes;
// for make tuple, should control depend on the tuple items
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
auto node_inputs = node->cast<CNodePtr>()->inputs();
for (size_t index = 1; index < node_inputs.size(); index++) {
nodes.push_back(GetRealOpNode(node_inputs[index]));
}
return nodes;
}
// for parameter ,find the apply that used the parameter as the control depended node
if (node->isa<Parameter>()) {
auto uses = node->func_graph()->manager()->node_users()[node];
for (auto &use : uses) {
auto use_node = use.first;
if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
nodes.push_back(GetRealOpNode(use_node));
}
}
return nodes;
}
nodes.push_back(GetRealOpNode(node));
return nodes;
}
void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) {
#ifdef DRAW_GE_GRAPH
auto src_depend_nodes = GetDependNodes(src_node);
auto dst_depend_nodes = GetDependNodes(dest_node);
if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) {
for (auto &item : dst_depend_nodes) {
compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()]
<< "[style=\"dotted\"]" << endl;
}
} else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) {
for (auto &item : src_depend_nodes) {
compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()]
<< "[style=\"dotted\"]" << endl;
}
} else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) {
compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()]
<< "[style=\"dotted\"]" << endl;
}
#endif
}
void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node,
const AnfNodePtr &dest_node,
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) {
if (src_node->isa<Parameter>()) {
auto uses = node->func_graph()->manager()->node_users()[src_node];
for (auto &use : uses) {
auto use_node = use.first;
if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) &&
(!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) {
auto converted_list = ConvertDependNode(use_node);
src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end());
}
}
}
if (dest_node->isa<Parameter>()) {
auto uses = node->func_graph()->manager()->node_users()[dest_node];
for (auto &use : uses) {
auto use_node = use.first;
if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) &&
(!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) {
auto converted_list = ConvertDependNode(use_node);
dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end());
}
}
}
}
bool DfGraphConvertor::GetControlDependList(const CNodePtr &node,
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) {
const int CONTROL_DEPEND_INDEX = 0;
const int SRC_NODE_INDEX = 1;
const int DEST_NODE_INDEX = 2;
const int DEPEND_MODE_NORMAL_USE = 0;
const int DEPEND_MODE_ON_PARAMETER_USE = 1;
auto node_inputs = node->inputs();
if (node_inputs.size() <= DEST_NODE_INDEX) {
MS_LOG(WARNING) << "Control depend node input size error";
return false;
}
auto src_node = node_inputs[SRC_NODE_INDEX];
auto dest_node = node_inputs[DEST_NODE_INDEX];
if ((src_node == nullptr) || (dest_node == nullptr)) {
MS_LOG(ERROR) << "Control depend node miss src or dest node";
error_ = FAILED;
return false;
}
AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX];
PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(fn);
ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
int depend_mode = DEPEND_MODE_NORMAL_USE;
if (mode_ptr != nullptr) {
auto mode_int = mode_ptr->cast<Int64ImmPtr>();
MS_EXCEPTION_IF_NULL(mode_int);
depend_mode = mode_int->value();
MS_LOG(DEBUG) << "depend_mode = " << depend_mode;
}
if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) {
GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list);
}
if (src_node->isa<CNode>()) {
auto converted_list = ConvertDependNode(src_node);
src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end());
}
if (dest_node->isa<CNode>()) {
auto converted_list = ConvertDependNode(dest_node);
dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end());
}
if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it";
error_ = SUCCESS;
}
return true;
}
void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
const int SRC_NODE_INDEX = 1;
const int DEST_NODE_INDEX = 2;
if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) {
return;
}
auto node_inputs = node->inputs();
if (node_inputs.size() <= DEST_NODE_INDEX) {
MS_LOG(WARNING) << "Control depend node input size error";
return;
}
auto src_node = node_inputs[SRC_NODE_INDEX];
auto dest_node = node_inputs[DEST_NODE_INDEX];
if ((src_node == nullptr) || (dest_node == nullptr)) {
MS_LOG(ERROR) << "Control depend node miss src or dest node";
error_ = FAILED;
return;
}
std::shared_ptr<std::vector<OperatorPtr>> src_ops_list = std::make_shared<std::vector<OperatorPtr>>();
std::shared_ptr<std::vector<OperatorPtr>> dst_ops_list = std::make_shared<std::vector<OperatorPtr>>();
if (!GetControlDependList(node, src_ops_list, dst_ops_list)) {
MS_LOG(ERROR) << "Get depend list failed";
error_ = FAILED;
return;
}
std::vector<ControlEdge> control_edges;
if (src_ops_list->size() == 1 && dst_ops_list->size() > 1) {
(void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges),
[src_ops_list](const OperatorPtr &op) -> ControlEdge {
return {(*src_ops_list)[0], op};
});
} else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) {
(void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges),
[dst_ops_list](const OperatorPtr &op) -> ControlEdge {
return {op, (*dst_ops_list)[0]};
});
} else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) {
control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]});
} else if (src_ops_list->empty() || dst_ops_list->empty()) {
MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it";
} else {
MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size()
<< " -> dst:" << dst_ops_list->size();
error_ = FAILED;
return;
}
control_depend_cache_[node.get()] = control_edges;
#ifdef DRAW_GE_GRAPH
DrawControlDepend(src_node, dest_node);
#endif
}
bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
// ignore apply node of return // ignore apply node of return
if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() || if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() ||
@ -1818,12 +1637,6 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
return false; return false;
} }
// ControlDepend
if (name == prim::kPrimControlDepend->name()) {
ConvertControlDependNode(node);
return false;
}
return true; return true;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -145,19 +145,11 @@ class DfGraphConvertor {
OperatorPtr ConvertCNode(CNodePtr node); OperatorPtr ConvertCNode(CNodePtr node);
std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node); std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node);
AnfNodePtr GetRealOpNode(AnfNodePtr node); AnfNodePtr GetRealOpNode(AnfNodePtr node);
std::vector<AnfNodePtr> GetDependNodes(const AnfNodePtr &node);
OperatorPtr ConvertParameter(AnfNodePtr node); OperatorPtr ConvertParameter(AnfNodePtr node);
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
OperatorPtr ConvertValueNode(ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node);
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node); void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
void ConvertTupleGetItem(const CNodePtr node); void ConvertTupleGetItem(const CNodePtr node);
void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node,
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list);
bool GetControlDependList(const CNodePtr &node, const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list);
void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node);
void ConvertControlDependNode(const CNodePtr node);
void ConvertMakeTuple(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node);
bool CheckCNode(const std::string &name, const CNodePtr node); bool CheckCNode(const std::string &name, const CNodePtr node);
void TraceOutput(AnfNodePtr node); void TraceOutput(AnfNodePtr node);
@ -195,7 +187,7 @@ class DfGraphConvertor {
std::shared_ptr<DfGraph> broadcast_graph_{nullptr}; std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
std::unordered_map<AnfNode *, DfGraph> branches_map_; std::unordered_map<AnfNode *, DfGraph> branches_map_;
std::unordered_map<AnfNode *, OperatorPtr> op_cache_; std::unordered_map<AnfNode *, OperatorPtr> op_cache_;
std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_; std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_edge_cache_;
std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_; std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_;
/* record "tuple_getitem"<->"out_handler" mapping */ /* record "tuple_getitem"<->"out_handler" mapping */
std::unordered_map<AnfNode *, OutHandler> out_handle_cache_; std::unordered_map<AnfNode *, OutHandler> out_handle_cache_;

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -51,88 +51,6 @@ std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) {
} }
return ""; return "";
} }
bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(behind_node);
MS_EXCEPTION_IF_NULL(graph);
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
if (prior_node->isa<Parameter>()) {
for (auto &user : node_users[prior_node]) {
auto cnode = user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
prior_nodes->emplace_back(cnode);
}
}
} else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
prior_nodes->emplace_back(prior_node);
} else {
return false;
}
if (behind_node->isa<Parameter>()) {
for (auto &user : node_users[behind_node]) {
auto cnode = user.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
depend_nodes->emplace_back(cnode);
}
}
} else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
depend_nodes->emplace_back(behind_node);
} else {
return false;
}
return true;
}
void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
std::map<AnfNodePtr, size_t> *nodes_ref) {
MS_EXCEPTION_IF_NULL(node);
auto input_cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
auto prior_node = input_cnode->input(kControlDependPriorIndex);
auto depend_node = input_cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(depend_node);
auto prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
MS_EXCEPTION_IF_NULL(prim_ptr);
ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
int64_t depend_mode = 0;
if (mode_ptr != nullptr) {
depend_mode = GetValue<int64_t>(mode_ptr);
}
if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
return;
}
std::vector<AnfNodePtr> prior_nodes;
std::vector<AnfNodePtr> behind_nodes;
if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
return;
}
for (auto &first_node : prior_nodes) {
for (auto &second_node : behind_nodes) {
MS_EXCEPTION_IF_NULL(first_node);
MS_EXCEPTION_IF_NULL(second_node);
auto iter = control_edges->find(second_node);
if (iter == control_edges->end()) {
(void)control_edges->insert(
std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
} else {
iter->second.emplace_back(first_node);
}
auto ref_iter = nodes_ref->find(first_node);
if (ref_iter != nodes_ref->end()) {
ref_iter->second++;
} else {
(void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
}
}
}
}
void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref, void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) { std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
@ -149,9 +67,6 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
for (auto &input : cnode->inputs()) { for (auto &input : cnode->inputs()) {
if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
AddControlEdge(graph, input, control_edges, nodes_ref);
}
auto iter = nodes_ref->find(input); auto iter = nodes_ref->find(input);
if (iter != nodes_ref->end()) { if (iter != nodes_ref->end()) {
iter->second++; iter->second++;
@ -479,12 +394,10 @@ void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
} }
GraphSegmentPtr node_segment{nullptr}; GraphSegmentPtr node_segment{nullptr};
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
auto node_iter = node_to_segment.find(node); auto node_iter = node_to_segment.find(node);
if (node_iter != node_to_segment.end()) { if (node_iter != node_to_segment.end()) {
node_segment = node_iter->second; node_segment = node_iter->second;
} }
}
for (auto &input : node_inputs) { for (auto &input : node_inputs) {
if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) { if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) {
GraphSegmentPtr input_segment{nullptr}; GraphSegmentPtr input_segment{nullptr};
@ -615,18 +528,14 @@ void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std::
std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment, std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment,
const std::set<AnfNodePtr> &dynamic_nodes_set) { const std::set<AnfNodePtr> &dynamic_nodes_set) {
SplitDynamicNodesHelper helper; SplitDynamicNodesHelper helper;
bool is_last_node_dynamic = false;
for (auto &node : segment_nodes) { for (auto &node : segment_nodes) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
helper.AddNode(node, is_last_node_dynamic);
continue;
}
auto &inputs = cnode->inputs(); auto &inputs = cnode->inputs();
bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end(); bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end();
bool depend_common_node = false; bool depend_common_node = false;
bool depend_dynamic_node = false; bool depend_dynamic_node = false;
bool is_last_node_dynamic = false;
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) { if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) {
has_dynamic_shape = true; has_dynamic_shape = true;

View File

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -87,26 +87,7 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) { if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
eqv[node] = node; eqv[node] = node;
} else if (eqv.find(node) == eqv.end()) { } else if (eqv.find(node) == eqv.end()) {
if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
eqv[node] = NewValueNode(MakeValue(0));
return eqv[node];
}
bool ignore_make_tuple = false;
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
ignore_make_tuple = true;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &node_inputs = cnode->inputs();
for (size_t i = 1; i < node_inputs.size(); ++i) {
if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) {
ignore_make_tuple = false;
break;
}
}
}
if (!ignore_make_tuple) {
inputs.push_back(node); inputs.push_back(node);
}
eqv[node] = fg->add_parameter(); eqv[node] = fg->add_parameter();
eqv[node]->set_abstract(node->abstract()); eqv[node]->set_abstract(node->abstract());
eqv[node]->set_kernel_info(node->kernel_info_ptr()); eqv[node]->set_kernel_info(node->kernel_info_ptr());
@ -148,14 +129,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
for (size_t i = 2; i < inps.size(); ++i) { for (size_t i = 2; i < inps.size(); ++i) {
args.emplace_back(NewValueNode(MakeValue(0))); args.emplace_back(NewValueNode(MakeValue(0)));
} }
} else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
for (size_t i = 1; i < inps.size(); ++i) {
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
args.emplace_back(NewValueNode(MakeValue(static_cast<int>(i))));
} else {
args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv));
}
}
} else { } else {
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); }); [&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });

View File

@ -182,8 +182,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -188,23 +188,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
return args_spec_list[0]->Broaden(); return args_spec_list[0]->Broaden();
} }
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// args: Two objects of a subclass of AbstractBase
CheckArgsSize(primitive->name(), args_spec_list, 2);
auto arg_src = args_spec_list[0];
auto arg_dst = args_spec_list[1];
// control depend can not setup tuple of ops to tuple of ops dependency relation
if (arg_src->isa<AbstractTuple>() && arg_dst->isa<AbstractTuple>()) {
auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
auto dst_size = arg_src->cast<AbstractTuplePtr>()->size();
if (src_size > 1 && dst_size > 1) {
MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple";
}
}
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
}
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple. // Inputs: two tensors and a tuple.

View File

@ -149,7 +149,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}}, {prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}},
{prim::kPrimDepend, {InferImplDepend, nullptr, true}}, {prim::kPrimDepend, {InferImplDepend, nullptr, true}},
{prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}}, {prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}},
{prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}},
// Debug // Debug
{prim::kPrimDebug, {InferImplDebug, nullptr, true}}, {prim::kPrimDebug, {InferImplDebug, nullptr, true}},
// Dynamic shape testing // Dynamic shape testing

View File

@ -453,7 +453,6 @@ inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookB
inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType"); inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape"); inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print"); inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_"); inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not"); inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict"); inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");

View File

@ -1,7 +1,7 @@
/** /**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
* *
* Copyright 2019-2020 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -399,8 +399,7 @@ std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_tar
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) ||
IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) {
IsPrimitive(attr_input, prim::kPrimPartial)) {
primitive->EraseAttr(primitive_target); primitive->EraseAttr(primitive_target);
return default_target; return default_target;
} }

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@
#include "ir/dtype/tensor_type.h" #include "ir/dtype/tensor_type.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h" #include "abstract/primitive_infer_map.h"
#include "ops/control_depend.h"
namespace mindspore { namespace mindspore {
namespace ops { namespace ops {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -23,7 +23,6 @@
#include "ops/expand_dims.h" #include "ops/expand_dims.h"
#include "utils/check_convert_utils.h" #include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h" #include "abstract/primitive_infer_map.h"
#include "ops/control_depend.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -30,7 +30,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key", "identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
"InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
"stop_gradient", "Send", "UpdateState", "Load"}; "stop_gradient", "Send", "UpdateState", "Load"};
// clang-format on // clang-format on

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -27,7 +27,6 @@
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
#include "mindspore/core/ir/primitive.h" #include "mindspore/core/ir/primitive.h"
#include "ops/fusion/partial_fusion.h" #include "ops/fusion/partial_fusion.h"
#include "ops/control_depend.h"
#include "ops/depend.h" #include "ops/depend.h"
#include "ops/make_tuple.h" #include "ops/make_tuple.h"
#include "ops/quant_dtype_cast.h" #include "ops/quant_dtype_cast.h"
@ -213,8 +212,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
MS_LOG(ERROR) << "value node is invalid."; MS_LOG(ERROR) << "value node is invalid.";
return; return;
} }
if (value_node->value() != nullptr && (opt::CheckPrimitiveType(depend_node, prim::kPrimDepend) || if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
opt::CheckPrimitiveType(depend_node, prim::kPrimControlDepend))) {
has_depend = true; has_depend = true;
bool mask_out = (depend_node->inputs().size() == 3); bool mask_out = (depend_node->inputs().size() == 3);
for (size_t j = 1; j < depend_node->inputs().size(); ++j) { for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
@ -466,8 +464,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
} }
RemoveIfDepend(cnode); RemoveIfDepend(cnode);
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend || if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameTupleGetItem ||
prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) { prim->name() == mindspore::ops::kNameMakeTuple) {
continue; continue;
} }
if (prim->name() == "make_tuple") { if (prim->name() == "make_tuple") {

View File

@ -57,8 +57,8 @@ bool IsRealKernel(const AnfNodePtr &node) {
IsPrimitive(input, prim::kPrimTensorSummary) || IsPrimitive(input, prim::kPrimTensorSummary) ||
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); IsPrimitive(input, prim::kPrimPartial);
return !is_virtual_node; return !is_virtual_node;
} }

View File

@ -43,8 +43,8 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
bool IsSpecialType(const CNodePtr &cnode) { bool IsSpecialType(const CNodePtr &cnode) {
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) ||
CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) { CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) {
return true; return true;
} }

View File

@ -58,13 +58,6 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph
return lite::RET_NO_CHANGE; return lite::RET_NO_CHANGE;
} }
} }
if (CheckPrimitiveType(anf_node, prim::kPrimControlDepend)) {
if (cnode->size() != InputDoubleNum) {
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
remove_cnode_.insert(anf_node);
return lite::RET_NO_CHANGE;
}
}
bool replace_succ = manager->Replace(anf_node, cnode->input(1)); bool replace_succ = manager->Replace(anf_node, cnode->input(1));
if (!replace_succ) { if (!replace_succ) {

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -96,7 +96,6 @@ class GPT2FinetuneCell(nn.Cell):
self.get_status = P.NPUGetFloatStatus() self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus() self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False) self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32) self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual() self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
@ -132,8 +131,8 @@ class GPT2FinetuneCell(nn.Cell):
if not self.gpu_target: if not self.gpu_target:
init = self.alloc_status() init = self.alloc_status()
init = F.depend(init, loss)
clear_before_grad = self.clear_before_grad(init) clear_before_grad = self.clear_before_grad(init)
F.control_depend(loss, init)
self.depend_parameter_use(clear_before_grad, scaling_sens) self.depend_parameter_use(clear_before_grad, scaling_sens)
grads = self.grad(self.network, weights)(input_ids, grads = self.grad(self.network, weights)(input_ids,
input_mask, input_mask,
@ -145,10 +144,10 @@ class GPT2FinetuneCell(nn.Cell):
if self.reducer_flag: if self.reducer_flag:
grads = self.grad_reducer(grads) grads = self.grad_reducer(grads)
if not self.gpu_target: if not self.gpu_target:
init = F.depend(init, grads)
flag = self.get_status(init) flag = self.get_status(init)
init = F.depend(init, flag)
flag_sum = self.reduce_sum(init, (0,)) flag_sum = self.reduce_sum(init, (0,))
F.control_depend(grads, flag)
F.control_depend(flag, flag_sum)
else: else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum) flag_sum = self.addn(flag_sum)

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -74,9 +74,9 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl
/* /*
* def before(x, y, a, b): * def before(x, y, a, b):
* z = make_tuple(TransData(a), TransData(b)) * z = make_tuple(TransData(a), TransData(b))
* depend_intput = control_depend(y, z) * depend_intput = depend(y, z)
* sum = add(x, depend_intput) * sum_add = add(x, depend_intput)
* return sum * return sum_add
*/ */
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before"); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before");
@ -93,11 +93,11 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl
TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) { TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) {
/* /*
* def before(x, y, a, b): * def before(x, y, z):
* z = make_tuple(TransData(a), TransData(b)) * new_z = TransData(z)
* depend_intput = control_depend(y, z) * depend_intput = depend(y, new_z)
* sum = add(x, depend_intput) * sum_add = add(x, depend_intput)
* return sum * return sum_add
*/ */
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before"); FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before");