forked from mindspore-Ecosystem/mindspore
!14612 remove ControlDepend
From: @huangbingjian Reviewed-by: Signed-off-by:
This commit is contained in:
commit
e2260a2f09
|
@ -108,7 +108,7 @@ bool InputCheck(const AnfNodePtr &node) {
|
|||
MS_LOG(INFO) << "Data->TransData->split, can not optimizer.";
|
||||
return false;
|
||||
}
|
||||
if (in_node_name == prim::kPrimControlDepend->name() || in_node_name == prim::kPrimDepend->name()) {
|
||||
if (in_node_name == prim::kPrimDepend->name()) {
|
||||
return false;
|
||||
}
|
||||
if ((AnfAlgo::HasNodeAttr("non_task", in_cnode) && AnfAlgo::GetNodeAttr<bool>(in_node, "non_task")) ||
|
||||
|
@ -131,7 +131,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
|||
return false;
|
||||
}
|
||||
for (const auto &item : outputs) {
|
||||
if (IsPrimitiveCNode(item, prim::kPrimControlDepend) || IsPrimitiveCNode(item, prim::kPrimDepend)) {
|
||||
if (IsPrimitiveCNode(item, prim::kPrimDepend)) {
|
||||
MS_LOG(INFO) << "Split has control edge, can not optimizer.";
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -168,7 +168,7 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c
|
|||
(void)manager->Replace(output, bn_training_update_v2_outputs[output_index]);
|
||||
output_index++;
|
||||
}
|
||||
// Return the new node for control depends.
|
||||
// Return the new node.
|
||||
return bn_training_update_v2;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -201,20 +201,6 @@ const AnfNodePtr DropoutAndDropoutGradUnifyMindIR::Process(const FuncGraphPtr &f
|
|||
}
|
||||
}
|
||||
}
|
||||
if (dropout_do_mask1 != nullptr) {
|
||||
// Dropout is used by ControlDepend in some situation, need to replace ControlDepend.
|
||||
auto &users = manager->node_users();
|
||||
iter = users.find(dropout_node);
|
||||
if (iter != users.end()) {
|
||||
for (auto &node_index : iter->second) {
|
||||
auto used_node = node_index.first;
|
||||
if (AnfAlgo::CheckPrimitiveType(used_node, prim::kPrimControlDepend)) {
|
||||
(void)manager->Replace(used_node, dropout_do_mask1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDropoutDoMask-backward
|
||||
if (equiv->find(grad_input_) == equiv->end()) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -426,9 +426,6 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu
|
|||
}
|
||||
auto output_info_list = iter->second;
|
||||
for (const auto &output_info : output_info_list) {
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) {
|
||||
continue;
|
||||
}
|
||||
if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() &&
|
||||
output_info.second == kDependAttachNodeIndex) {
|
||||
continue;
|
||||
|
@ -908,16 +905,12 @@ void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const C
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
// find BatchNorm's output which is a Depend or ControlDepend
|
||||
// find BatchNorm's output which is a Depend
|
||||
for (const auto &node_index : manager->node_users()[old_node]) {
|
||||
AnfNodePtr output = node_index.first;
|
||||
size_t index = IntToSize(node_index.second);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
||||
auto control_depend = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(control_depend);
|
||||
control_depend->set_input(index, new_node);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimDepend)) {
|
||||
auto depend = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend);
|
||||
depend->set_input(index, new_node);
|
||||
|
|
|
@ -210,7 +210,7 @@ bool CheckSupportDataType(const AnfNodePtr &node, const std::set<TypeId> &suppor
|
|||
// Create a new value node of func graph,not kernel graph
|
||||
ValueNodePtr MakeValueNode(const ValueNodePtr &value_node);
|
||||
|
||||
// Transfer depend or control_depend to the new node
|
||||
// Transfer depend to the new node
|
||||
void TransferDepend(const CNodePtr &old_node, const FuncGraphPtr &graph, const CNodePtr &new_node);
|
||||
|
||||
AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
|
||||
|
|
|
@ -327,7 +327,7 @@ void AtomicCleanInsertter::ProcessOriginCNode(const AnfNodePtr &composite_node,
|
|||
|
||||
void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node,
|
||||
const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index) {
|
||||
// Create depend node to hold new control depend node.
|
||||
// Create depend node to hold execution order.
|
||||
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), clean_node, composite_node};
|
||||
auto depend_cnode = main_graph->NewCNode(d_inputs);
|
||||
depend_cnode->set_abstract(clean_node->abstract());
|
||||
|
@ -501,12 +501,11 @@ bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_
|
|||
const FuncGraphManagerPtr &mng) {
|
||||
auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
|
||||
// If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce
|
||||
// node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong!
|
||||
return std::all_of(reduce_users.cbegin(), reduce_users.cend(),
|
||||
[&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
|
||||
// node and user node. If reduce is Depend node, the origin node may be wrong!
|
||||
return std::all_of(
|
||||
reduce_users.cbegin(), reduce_users.cend(), [&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
|
||||
auto &user = user_info.first;
|
||||
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) ||
|
||||
IsPrimitiveCNode(user, prim::kPrimControlDepend)) &&
|
||||
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend)) &&
|
||||
!(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
|
||||
return false;
|
||||
} else {
|
||||
|
|
|
@ -123,9 +123,9 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
bool changed = false;
|
||||
auto mng = kernel_graph->manager();
|
||||
|
||||
// depend_prior[depend] = pair(prior, controlDependNode)
|
||||
// depend_prior[depend] = pair(prior, behind)
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
|
||||
InitDependPrior(todos, &depend_prior);
|
||||
// InitDependPrior(todos, &depend_prior);
|
||||
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto node = (*iter)->cast<CNodePtr>();
|
||||
|
|
|
@ -657,76 +657,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto cnode = (*iter)->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto prior_node = cnode->input(kControlDependPriorIndex);
|
||||
auto depend_node = cnode->input(kControlDependBehindIndex);
|
||||
MS_EXCEPTION_IF_NULL(prior_node);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
std::vector<AnfNodePtr> prior_nodes = {prior_node};
|
||||
std::vector<AnfNodePtr> depend_nodes = {depend_node};
|
||||
|
||||
int depend_mode = 0;
|
||||
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
||||
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
|
||||
}
|
||||
|
||||
auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector<AnfNodePtr> {
|
||||
std::vector<AnfNodePtr> out_nodes;
|
||||
auto user_set = param->func_graph()->manager()->node_users()[param];
|
||||
for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) {
|
||||
if (iter->first != cnode) {
|
||||
out_nodes.push_back(iter->first);
|
||||
}
|
||||
}
|
||||
return out_nodes;
|
||||
};
|
||||
|
||||
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
||||
prior_nodes = GetOutputNodes(prior_node);
|
||||
}
|
||||
if (depend_node->isa<Parameter>()) {
|
||||
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> real_prior_nodes;
|
||||
std::set<AnfNodePtr> prior_visited;
|
||||
for (const auto &tmp : prior_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
}
|
||||
prior_visited.clear();
|
||||
std::vector<AnfNodePtr> real_depend_nodes;
|
||||
std::set<AnfNodePtr> depend_visited;
|
||||
for (const auto &tmp : depend_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||
}
|
||||
depend_visited.clear();
|
||||
|
||||
for (auto &prior : real_prior_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
for (auto &depend : real_depend_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
depend_prior->insert({depend, std::make_pair(prior, cnode)});
|
||||
}
|
||||
}
|
||||
real_prior_nodes.clear();
|
||||
real_depend_nodes.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
|
||||
|
|
|
@ -75,8 +75,6 @@ std::vector<PrimitivePtr> GetFusibleOpList();
|
|||
bool IsBasicFuseOp(const AnfNodePtr &node);
|
||||
bool IsFusibleOp(const AnfNodePtr &node);
|
||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -55,7 +55,7 @@ CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector<Kern
|
|||
auto item_idx = GetValue<int64_t>(value_node->value());
|
||||
pass_vector->push_back(make_pair(cnode, IntToSize(1)));
|
||||
return GetRealPrevCNode(cnode->input(1), LongToSize(item_idx), pass_vector);
|
||||
} else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
|
||||
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
|
||||
pass_vector->push_back(make_pair(cnode, IntToSize(1)));
|
||||
return GetRealPrevCNode(cnode->input(1), 0, pass_vector);
|
||||
} else if (IsPrimitive(input0, prim::kPrimUpdateState)) {
|
||||
|
@ -92,8 +92,7 @@ const AnfNodePtr ProcessMatchedNodes(const FuncGraphPtr &func_graph, const CNode
|
|||
auto pass_size = pass_vector->size();
|
||||
for (size_t idx = 1; idx <= pass_size - 1; ++idx) {
|
||||
auto nd = (*pass_vector)[idx].first;
|
||||
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend) ||
|
||||
AnfAlgo::CheckPrimitiveType(nd, prim::kPrimControlDepend)) {
|
||||
if (AnfAlgo::CheckPrimitiveType(nd, prim::kPrimDepend)) {
|
||||
has_depend_node = true;
|
||||
}
|
||||
if (users[nd].size() >= 2) {
|
||||
|
|
|
@ -248,7 +248,7 @@ class AnfRuntimeAlgorithm {
|
|||
static void InferShape(const CNodePtr &node);
|
||||
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
||||
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
||||
// Find control_depend real input nodes.
|
||||
// Find real input nodes.
|
||||
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
||||
std::set<AnfNodePtr> *visited);
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -534,14 +534,17 @@ void AscendControlParser::InsertDependToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|||
return_node->set_input(kFirstDataInputIndex, depend_node);
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> first_node,
|
||||
NotNull<AnfNodePtr> second_node) {
|
||||
MS_LOG(INFO) << "Insert control depend at the end of graph, the first node is " << first_node->DebugString()
|
||||
<< ", the second node is " << second_node->DebugString();
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimControlDepend->name())),
|
||||
first_node, second_node};
|
||||
auto control_depend = kg->NewCNode(inputs);
|
||||
InsertDependToGraph(kg, NOT_NULL(control_depend));
|
||||
void AscendControlParser::InsertControlDependToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> prior_node,
|
||||
NotNull<AnfNodePtr> behind_node) {
|
||||
MS_LOG(INFO) << "Insert control dependence at the end of graph, the prior node is " << prior_node->DebugString()
|
||||
<< ", the behind node is " << behind_node->DebugString();
|
||||
auto manager = kg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), behind_node, prior_node};
|
||||
auto depend_cnode = kg->NewCNode(inputs);
|
||||
if (!manager->Replace(behind_node, depend_cnode)) {
|
||||
MS_LOG(EXCEPTION) << behind_node->DebugString() << ", replace node failed.";
|
||||
}
|
||||
}
|
||||
|
||||
void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -422,7 +422,7 @@ void KernelGraph::CheckLoop() {
|
|||
none_zero_nodes[it.first] = it.second;
|
||||
}
|
||||
}
|
||||
// if don't consider control depend and loop exit,a exception will be throw
|
||||
// if don't consider loop exit,a exception will be throw
|
||||
if (!none_zero_nodes.empty()) {
|
||||
MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes);
|
||||
MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
|
||||
|
@ -815,61 +815,10 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
|
|||
return output_nodes;
|
||||
}
|
||||
|
||||
// update the depend relations of control depend
|
||||
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
|
||||
for (const auto &node : depends) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
|
||||
MS_LOG(EXCEPTION) << node->DebugString() << " is not a control depend";
|
||||
}
|
||||
auto prior_node = cnode->input(kControlDependPriorIndex);
|
||||
auto depend_node = cnode->input(kControlDependBehindIndex);
|
||||
MS_EXCEPTION_IF_NULL(prior_node);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
std::vector<AnfNodePtr> prior_nodes = {prior_node};
|
||||
std::vector<AnfNodePtr> depend_nodes = {depend_node};
|
||||
int depend_mode = 0;
|
||||
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
||||
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
|
||||
}
|
||||
MS_LOG(DEBUG) << "Prior node[" << prior_node->DebugString() << "], depend node[" << depend_node->DebugString()
|
||||
<< "], depend_mode :" << depend_mode << ".";
|
||||
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
||||
prior_nodes = GetOutputNodes(prior_node);
|
||||
}
|
||||
if (depend_node->isa<Parameter>()) {
|
||||
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> real_prior_nodes;
|
||||
std::set<AnfNodePtr> prior_visited;
|
||||
for (const auto &tmp : prior_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
}
|
||||
std::vector<AnfNodePtr> real_depend_nodes;
|
||||
std::set<AnfNodePtr> depend_visited;
|
||||
for (const auto &tmp : depend_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||
}
|
||||
UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes);
|
||||
}
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes,
|
||||
const std::vector<AnfNodePtr> &real_depend_nodes) {
|
||||
for (auto &first_node : real_prior_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(first_node, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
for (auto &second_node : real_depend_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(second_node, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(first_node);
|
||||
MS_EXCEPTION_IF_NULL(second_node);
|
||||
MS_LOG(DEBUG) << "Add first node:" << first_node->DebugString() << ",second node:" << second_node->DebugString();
|
||||
|
@ -878,35 +827,6 @@ void KernelGraph::UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real
|
|||
}
|
||||
}
|
||||
|
||||
bool KernelGraph::HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(que);
|
||||
MS_EXCEPTION_IF_NULL(visited_nodes);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimControlDepend)) {
|
||||
return false;
|
||||
}
|
||||
// set the control depend visited but don't push it into the que
|
||||
if (visited_nodes->find(node) != visited_nodes->end()) {
|
||||
return true;
|
||||
}
|
||||
(void)visited_nodes->insert(cnode);
|
||||
// add a 0 depend num to keep the link relations to prepare for finding zero output nodes
|
||||
auto prior_node = cnode->input(kControlDependPriorIndex);
|
||||
auto depend_node = cnode->input(kControlDependBehindIndex);
|
||||
for (const auto &input : cnode->inputs()) {
|
||||
AddDependEdge(node, input, 0);
|
||||
}
|
||||
PushNoVisitedNode(depend_node, que, visited_nodes);
|
||||
PushNoVisitedNode(prior_node, que, visited_nodes);
|
||||
return true;
|
||||
}
|
||||
|
||||
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(seed_nodes);
|
||||
node_output_edges_.clear();
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -286,15 +286,11 @@ class KernelGraph : public FuncGraph {
|
|||
std::unordered_set<AnfNodePtr> *visited_nodes, bool comm_first = true);
|
||||
// update node edge list
|
||||
void UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes);
|
||||
// add node depend edge by data edge or control depend
|
||||
// add node depend edge by data edge
|
||||
void AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num);
|
||||
void UpdateNodeInputOutputEdges(const std::vector<AnfNodePtr> &real_prior_nodes,
|
||||
const std::vector<AnfNodePtr> &real_depend_nodes);
|
||||
// handle control depend
|
||||
std::vector<AnfNodePtr> GetOutputNodes(const AnfNodePtr &node);
|
||||
bool HandleControlDependNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
|
||||
std::unordered_set<AnfNodePtr> *visited_nodes);
|
||||
void UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends);
|
||||
AnfNodePtr TransValueNodeTuple(const AbstractBasePtr abstract, const ValuePtr &value);
|
||||
AnfNodePtr TransParameterTuple(const AbstractBasePtr &abstract);
|
||||
AnfNodePtr TransCNodeTuple(const CNodePtr &node);
|
||||
|
|
|
@ -223,11 +223,9 @@ BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &gra
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
VectorRef ret;
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) {
|
||||
auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node);
|
||||
ret.push_back(out);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
// if is graph return nothing ,the function should return a null anylist
|
||||
|
@ -386,22 +384,6 @@ bool ExistSummaryNode(const KernelGraph *graph) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &node_inputs = cnode->inputs();
|
||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
|
||||
std::map<AnfNodePtr, size_t> *parameter_index) {
|
||||
size_t index = 0;
|
||||
|
@ -692,9 +674,6 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
|
|||
AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (IgnoreCreateParameterForMakeTuple(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
|
||||
auto parameters = AnfAlgo::GetAllOutput(new_parameter);
|
||||
std::vector<AnfNodePtr> pre_graph_out = {node};
|
||||
|
@ -1872,9 +1851,6 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr
|
|||
auto &users = front_func_graph_manager->node_users()[front_node];
|
||||
std::vector<AnfNodePtr> result;
|
||||
for (auto &user : users) {
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) {
|
||||
auto depend_cnode = user.first->cast<CNodePtr>();
|
||||
if (depend_cnode == nullptr) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -84,7 +84,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend, prim::kPrimLoad};
|
||||
std::vector<PrimitivePtr> adapter_convert_ops = {prim::kPrimDepend, prim::kPrimLoad};
|
||||
for (auto &item : adapter_convert_ops) {
|
||||
if (IsPrimitiveCNode(node, item)) {
|
||||
return true;
|
||||
|
@ -243,8 +243,7 @@ CNodePtr MergeNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, int64_t sw
|
|||
return merge_op;
|
||||
}
|
||||
|
||||
// construct a depend node with merge output node, merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
|
||||
// control_depend(output_node, square_op)
|
||||
// merge(square_op(switch(ctrl_data)), switch(opposite_ctrl_data))
|
||||
AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond, const AnfNodePtr &output_node,
|
||||
int64_t switch_idx) {
|
||||
tensor::TensorPtr const_data = GetConstData();
|
||||
|
@ -259,54 +258,21 @@ AnfNodePtr GenerateSwitchDependNode(const FuncGraphPtr &graph, const AnfNodePtr
|
|||
SetSquareOp(switch_idx, square_op);
|
||||
}
|
||||
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
AnfNodePtrList inputs = {NewValueNode(prim::kPrimDepend), square_op, output_node};
|
||||
auto depend_cnode = graph->NewCNode(inputs);
|
||||
if (!manager->Replace(square_op, depend_cnode)) {
|
||||
MS_LOG(EXCEPTION) << square_op->DebugString() << ", replace node failed.";
|
||||
}
|
||||
|
||||
CNodePtr merge_op = GetMergeOp(switch_idx);
|
||||
if (merge_op == nullptr) {
|
||||
merge_op = MergeNode(graph, cond, switch_idx, const_data, square_op);
|
||||
SetMergeOp(switch_idx, merge_op);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), output_node, square_op};
|
||||
auto control_depend_op = graph->NewCNode(control_depend_nodes);
|
||||
|
||||
std::vector<AnfNodePtr> depend_nodes{NewValueNode(prim::kPrimDepend), merge_op, control_depend_op};
|
||||
auto depend_op = graph->NewCNode(depend_nodes);
|
||||
|
||||
return depend_op;
|
||||
}
|
||||
|
||||
// construct a merge output and add dependency with the netoutput node from control_depend
|
||||
// we need to reserve the control_depend node, besides the generated merge node and control_depend node
|
||||
CNodePtr GenerateSwitchControlDependNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
||||
const AnfNodePtr &ctrl_dep_node, const AnfNodePtr &ctrl_depend_dst,
|
||||
int64_t switch_idx) {
|
||||
auto PrimMerge = prim::GetPythonOps("merge", "mindspore.ops.functional")->cast<PrimitivePtr>();
|
||||
auto PrimSquare = prim::GetPythonOps("square", "mindspore.ops.functional")->cast<PrimitivePtr>();
|
||||
std::vector<int64_t> shp = {1};
|
||||
tensor::TensorPtr const_data = std::make_shared<tensor::Tensor>(kInt64->type_id(), shp);
|
||||
auto *val = static_cast<int64_t *>(const_data->data_c());
|
||||
*val = 0;
|
||||
// for the control_depend netoutput node , add two const data to merge the flow ,one for depended node with same
|
||||
// switch the other use the opposite
|
||||
auto ctrl_data = NewValueNode(const_data);
|
||||
auto oppsite_ctrl_data = NewValueNode(const_data);
|
||||
auto ctrl_node = GenerateSwitchNode(graph, cond, ctrl_data, switch_idx);
|
||||
auto opposite_ctrl_node = GenerateSwitchNode(graph, cond, oppsite_ctrl_data, 1 - switch_idx);
|
||||
|
||||
std::vector<AnfNodePtr> square_nodes{NewValueNode(PrimSquare), ctrl_node};
|
||||
auto square_op = graph->NewCNode(square_nodes);
|
||||
|
||||
std::vector<AnfNodePtr> merge_nodes;
|
||||
merge_nodes.push_back(NewValueNode(PrimMerge));
|
||||
std::vector<AnfNodePtr> make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), square_op, opposite_ctrl_node};
|
||||
merge_nodes.push_back(graph->NewCNode(make_tuple_nodes));
|
||||
auto merge_output = graph->NewCNode(merge_nodes);
|
||||
|
||||
std::vector<AnfNodePtr> control_depend_nodes{NewValueNode(prim::kPrimControlDepend), ctrl_depend_dst, square_op};
|
||||
auto cond_dep_output = graph->NewCNode(control_depend_nodes);
|
||||
|
||||
std::vector<AnfNodePtr> depended_make_tuple_nodes{NewValueNode(prim::kPrimMakeTuple), ctrl_dep_node, merge_output,
|
||||
cond_dep_output};
|
||||
return graph->NewCNode(depended_make_tuple_nodes);
|
||||
return merge_op;
|
||||
}
|
||||
|
||||
// generate switch nodes for true graph node inputs
|
||||
|
@ -321,26 +287,12 @@ AnfNodePtr GenerateSwitchDependFalseNode(const FuncGraphPtr &graph, const AnfNod
|
|||
return GenerateSwitchDependNode(graph, cond, data, 0);
|
||||
}
|
||||
|
||||
// generate switch nodes for true graph node inputs
|
||||
CNodePtr GenerateSwitchControlDependTrueNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
||||
const AnfNodePtr &con_input, const AnfNodePtr &output) {
|
||||
// for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
|
||||
return GenerateSwitchControlDependNode(graph, cond, con_input, output, 1);
|
||||
}
|
||||
|
||||
// generate switch nodes for false graph node inputs
|
||||
CNodePtr GenerateSwitchControlDependFalseNode(const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
||||
const AnfNodePtr &con_input, const AnfNodePtr &output) {
|
||||
// for switch op ,the output is a tuple ,0-th is false_branch, 1-th is true branch
|
||||
return GenerateSwitchControlDependNode(graph, cond, con_input, output, 0);
|
||||
}
|
||||
|
||||
// to judge if the node used in ControlDepend is a net output node
|
||||
// to judge if the node used in Depend is a net output node
|
||||
bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node) {
|
||||
auto uses = manager->node_users()[node];
|
||||
bool is_output_node = true;
|
||||
for (auto &item : uses) {
|
||||
if (IsPrimitiveCNode(item.first, prim::kPrimControlDepend) || IsPrimitiveCNode(item.first, prim::kPrimDepend)) {
|
||||
if (IsPrimitiveCNode(item.first, prim::kPrimDepend)) {
|
||||
continue;
|
||||
}
|
||||
is_output_node = false;
|
||||
|
@ -353,8 +305,7 @@ bool IsNetOutputNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node)
|
|||
void GenerateReplNodeForDependMakeTuple(
|
||||
const AnfNodePtr &depended_node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
||||
const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
|
||||
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func,
|
||||
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
|
||||
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
|
||||
MS_EXCEPTION_IF_NULL(graph->manager());
|
||||
|
||||
auto make_tuple_inputs = depended_node->cast<CNodePtr>()->inputs();
|
||||
|
@ -368,26 +319,6 @@ void GenerateReplNodeForDependMakeTuple(
|
|||
new_make_tuple_nodes.push_back(depended_tuple_input_node);
|
||||
continue;
|
||||
}
|
||||
if (IsPrimitiveCNode(depended_tuple_input_node->cast<CNodePtr>(), prim::kPrimControlDepend)) {
|
||||
// only when the control depend input is not square op (the op to use as merge output)
|
||||
auto control_inputs = depended_tuple_input_node->cast<CNodePtr>()->inputs();
|
||||
if (control_inputs.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size();
|
||||
}
|
||||
// control inputs: primitive, src, dst
|
||||
auto dst_node = control_inputs[2];
|
||||
if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
|
||||
auto gen_node = gen_ctl_depd_func(graph, cond, make_tuple_inputs[idx], dst_node);
|
||||
MS_EXCEPTION_IF_NULL(gen_node);
|
||||
auto tuple_inputs = gen_node->inputs();
|
||||
// add depended tuple inputs to new_make_tuple directly
|
||||
for (size_t i = 1; i < tuple_inputs.size(); i++) {
|
||||
new_make_tuple_nodes.push_back(tuple_inputs[i]);
|
||||
}
|
||||
}
|
||||
replace_make_tuple = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (graph->manager()->node_users()[depended_tuple_input_node].size() == 1) {
|
||||
auto gen_node = generate_func(graph, cond, depended_tuple_input_node);
|
||||
|
@ -408,8 +339,7 @@ void GenerateReplNodeForDependMakeTuple(
|
|||
void GenerateRepDepend(
|
||||
const CNodePtr &node, const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
||||
const std::shared_ptr<std::unordered_map<AnfNodePtr, AnfNodePtr>> &repl_node,
|
||||
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func,
|
||||
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
|
||||
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &generate_func) {
|
||||
auto inputs = node->inputs();
|
||||
if (inputs.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "Inputs should be [depend, actual_value, depended_node].";
|
||||
|
@ -422,19 +352,7 @@ void GenerateRepDepend(
|
|||
new_depened_inputs.push_back(inputs[1]);
|
||||
// depended node should be make_tuple or a single depended node
|
||||
if (IsPrimitiveCNode(depended_node, prim::kPrimMakeTuple)) {
|
||||
GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func, gen_ctl_depd_func);
|
||||
} else if (IsPrimitiveCNode(depended_node, prim::kPrimControlDepend)) {
|
||||
// only when the control depend input is not square op (the op to use as merge output)
|
||||
auto control_inputs = depended_node->cast<CNodePtr>()->inputs();
|
||||
// control inputs: primitive, src, dst
|
||||
if (control_inputs.size() != 3) {
|
||||
MS_LOG(EXCEPTION) << "controldepend input size != 3, got " << control_inputs.size();
|
||||
}
|
||||
auto dst_node = control_inputs[2];
|
||||
if (!IsPrimitiveCNode(dst_node, prim::kPrimSquare) && IsNetOutputNode(graph->manager(), dst_node)) {
|
||||
auto gen_node = gen_ctl_depd_func(graph, cond, depended_node, dst_node);
|
||||
(*repl_node)[depended_node] = gen_node;
|
||||
}
|
||||
GenerateReplNodeForDependMakeTuple(depended_node, graph, cond, repl_node, generate_func);
|
||||
} else {
|
||||
// Check if there is only single user for depend_node.
|
||||
if (graph->manager()->node_users()[depended_node].size() == 1) {
|
||||
|
@ -448,11 +366,9 @@ void GenerateRepDepend(
|
|||
|
||||
// generate depend node for netoutput node, to resolve the stream synchronize problem of ge
|
||||
// traverse all nodes of depend node, find the graph output node , generaete a merge node of (square, const)
|
||||
// and add control_depend of graph output node and square node.
|
||||
FuncGraphPtr TransformGraphDependNode(
|
||||
const FuncGraphPtr &graph, const AnfNodePtr &cond,
|
||||
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func,
|
||||
const std::function<CNodePtr(FuncGraphPtr, AnfNodePtr, AnfNodePtr, AnfNodePtr)> &gen_ctl_depd_func) {
|
||||
const std::function<AnfNodePtr(FuncGraphPtr graph, AnfNodePtr cond, AnfNodePtr data)> &gen_depend_func) {
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
||||
|
@ -478,7 +394,7 @@ FuncGraphPtr TransformGraphDependNode(
|
|||
if (IsPrimitiveCNode(depended_node, prim::kPrimDepend)) {
|
||||
continue;
|
||||
}
|
||||
GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func, gen_ctl_depd_func);
|
||||
GenerateRepDepend(cnode, graph, cond, repl_node, gen_depend_func);
|
||||
}
|
||||
}
|
||||
ResetSharedOp();
|
||||
|
@ -494,12 +410,12 @@ FuncGraphPtr TransformGraphDependNode(
|
|||
|
||||
FuncGraphPtr TransformGraphCondTrueBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
|
||||
(void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchTrueNode);
|
||||
return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode, GenerateSwitchControlDependTrueNode);
|
||||
return TransformGraphDependNode(graph, cond, GenerateSwitchDependTrueNode);
|
||||
}
|
||||
|
||||
FuncGraphPtr TransformGraphCondFalseBranchNodes(const FuncGraphPtr &graph, const AnfNodePtr &cond) {
|
||||
(void)TransformGraphCondBranchNodes(graph, cond, GenerateSwitchFalseNode);
|
||||
return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode, GenerateSwitchControlDependFalseNode);
|
||||
return TransformGraphDependNode(graph, cond, GenerateSwitchDependFalseNode);
|
||||
}
|
||||
|
||||
// judge if the true and false graph output is compatible(they shall have same tuple size)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -218,11 +218,11 @@ bool HasForwardOutput(const FuncGraphManagerPtr &mng, const AnfNodePtr &node) {
|
|||
if (output_set_iter == node_users.end()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto &node_index_set : output_set_iter->second) {
|
||||
if (!IsBpropNode(node_index_set.first) && !IsPrimitiveCNode(node_index_set.first, prim::kPrimControlDepend)) {
|
||||
|
||||
if (std::any_of(output_set_iter->second.begin(), output_set_iter->second.end(),
|
||||
[](const auto &node_index_set) { return !IsBpropNode(node_index_set.first); })) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -367,7 +367,6 @@ constexpr char HISTOGRAMSUMMARY[] = "HistogramSummary";
|
|||
constexpr char DEBUG[] = "Debug";
|
||||
constexpr char BROADCASTGRADIENTARGS[] = "BroadcastGradientArgs";
|
||||
constexpr char INVERTPERMUTATION[] = "InvertPermutation";
|
||||
constexpr char CONTROLDEPEND[] = "ControlDepend";
|
||||
constexpr char DOT[] = "dot";
|
||||
constexpr char IM2COL[] = "im2col";
|
||||
constexpr char COL2IM[] = "col2im";
|
||||
|
|
|
@ -259,11 +259,9 @@ BaseRef CreateOutputTensors(const AnfNodePtr &output_node, const KernelGraphPtr
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
VectorRef ret;
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(cnode->input(i), prim::kPrimControlDepend)) {
|
||||
auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors);
|
||||
ret.push_back(out);
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -1044,7 +1044,7 @@ bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) {
|
|||
OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) {
|
||||
auto op = Convert(GetRealOpNode(node));
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert control depend node to operator failed, " << node->ToString();
|
||||
MS_LOG(ERROR) << "Convert real op node to operator failed, " << node->ToString();
|
||||
error_ = FAILED;
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -1170,13 +1170,13 @@ void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) {
|
|||
|
||||
void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) {
|
||||
AutoMonadSetControlInput(node);
|
||||
if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) {
|
||||
if (control_edge_cache_.find(node.get()) == control_edge_cache_.end()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<ControlEdge> control_edges = control_depend_cache_[node.get()];
|
||||
std::vector<ControlEdge> control_edges = control_edge_cache_[node.get()];
|
||||
if ((control_edges.empty())) {
|
||||
MS_LOG(ERROR) << "Get control depend node's src or dest operator failed";
|
||||
MS_LOG(ERROR) << "Get control edge node's src or dest operator failed";
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -1600,7 +1600,7 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no
|
|||
for (size_t index = 1; index < node_inputs.size(); index++) {
|
||||
auto op = Convert(GetRealOpNode(node_inputs[index]));
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert control depend node to operator failed";
|
||||
MS_LOG(ERROR) << "Convert real op node to operator failed";
|
||||
error_ = FAILED;
|
||||
return std::vector<OperatorPtr>({});
|
||||
}
|
||||
|
@ -1611,194 +1611,13 @@ std::vector<OperatorPtr> DfGraphConvertor::ConvertDependNode(const AnfNodePtr no
|
|||
|
||||
auto op = Convert(GetRealOpNode(node));
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "Convert control depend node to operator failed";
|
||||
MS_LOG(ERROR) << "Convert real op node to operator failed";
|
||||
error_ = FAILED;
|
||||
return std::vector<OperatorPtr>({});
|
||||
}
|
||||
return std::vector<OperatorPtr>({op});
|
||||
}
|
||||
|
||||
// get the anf node list for depend
|
||||
std::vector<AnfNodePtr> DfGraphConvertor::GetDependNodes(const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> nodes;
|
||||
// for make tuple, should control depend on the tuple items
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
auto node_inputs = node->cast<CNodePtr>()->inputs();
|
||||
for (size_t index = 1; index < node_inputs.size(); index++) {
|
||||
nodes.push_back(GetRealOpNode(node_inputs[index]));
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
|
||||
// for parameter ,find the apply that used the parameter as the control depended node
|
||||
if (node->isa<Parameter>()) {
|
||||
auto uses = node->func_graph()->manager()->node_users()[node];
|
||||
for (auto &use : uses) {
|
||||
auto use_node = use.first;
|
||||
if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
|
||||
nodes.push_back(GetRealOpNode(use_node));
|
||||
}
|
||||
}
|
||||
return nodes;
|
||||
}
|
||||
nodes.push_back(GetRealOpNode(node));
|
||||
return nodes;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node) {
|
||||
#ifdef DRAW_GE_GRAPH
|
||||
auto src_depend_nodes = GetDependNodes(src_node);
|
||||
auto dst_depend_nodes = GetDependNodes(dest_node);
|
||||
if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() > 1) {
|
||||
for (auto &item : dst_depend_nodes) {
|
||||
compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[item.get()]
|
||||
<< "[style=\"dotted\"]" << endl;
|
||||
}
|
||||
} else if (src_depend_nodes.size() > 1 && dst_depend_nodes.size() == 1) {
|
||||
for (auto &item : src_depend_nodes) {
|
||||
compute_sout_ << op_draw_name_[item.get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()]
|
||||
<< "[style=\"dotted\"]" << endl;
|
||||
}
|
||||
} else if (src_depend_nodes.size() == 1 && dst_depend_nodes.size() == 1) {
|
||||
compute_sout_ << op_draw_name_[src_depend_nodes[0].get()] << " -> " << op_draw_name_[dst_depend_nodes[0].get()]
|
||||
<< "[style=\"dotted\"]" << endl;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void DfGraphConvertor::GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node,
|
||||
const AnfNodePtr &dest_node,
|
||||
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
|
||||
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) {
|
||||
if (src_node->isa<Parameter>()) {
|
||||
auto uses = node->func_graph()->manager()->node_users()[src_node];
|
||||
for (auto &use : uses) {
|
||||
auto use_node = use.first;
|
||||
if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) &&
|
||||
(!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) {
|
||||
auto converted_list = ConvertDependNode(use_node);
|
||||
src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (dest_node->isa<Parameter>()) {
|
||||
auto uses = node->func_graph()->manager()->node_users()[dest_node];
|
||||
for (auto &use : uses) {
|
||||
auto use_node = use.first;
|
||||
if ((use_node->isa<CNode>()) && (!IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) &&
|
||||
(!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple))) {
|
||||
auto converted_list = ConvertDependNode(use_node);
|
||||
dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool DfGraphConvertor::GetControlDependList(const CNodePtr &node,
|
||||
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
|
||||
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list) {
|
||||
const int CONTROL_DEPEND_INDEX = 0;
|
||||
const int SRC_NODE_INDEX = 1;
|
||||
const int DEST_NODE_INDEX = 2;
|
||||
const int DEPEND_MODE_NORMAL_USE = 0;
|
||||
const int DEPEND_MODE_ON_PARAMETER_USE = 1;
|
||||
|
||||
auto node_inputs = node->inputs();
|
||||
if (node_inputs.size() <= DEST_NODE_INDEX) {
|
||||
MS_LOG(WARNING) << "Control depend node input size error";
|
||||
return false;
|
||||
}
|
||||
auto src_node = node_inputs[SRC_NODE_INDEX];
|
||||
auto dest_node = node_inputs[DEST_NODE_INDEX];
|
||||
if ((src_node == nullptr) || (dest_node == nullptr)) {
|
||||
MS_LOG(ERROR) << "Control depend node miss src or dest node";
|
||||
error_ = FAILED;
|
||||
return false;
|
||||
}
|
||||
AnfNodePtr fn = node_inputs[CONTROL_DEPEND_INDEX];
|
||||
PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(fn);
|
||||
ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
|
||||
int depend_mode = DEPEND_MODE_NORMAL_USE;
|
||||
if (mode_ptr != nullptr) {
|
||||
auto mode_int = mode_ptr->cast<Int64ImmPtr>();
|
||||
MS_EXCEPTION_IF_NULL(mode_int);
|
||||
depend_mode = mode_int->value();
|
||||
MS_LOG(DEBUG) << "depend_mode = " << depend_mode;
|
||||
}
|
||||
if (depend_mode == DEPEND_MODE_ON_PARAMETER_USE) {
|
||||
GetDependOnParameterUse(node, src_node, dest_node, src_ops_list, dst_ops_list);
|
||||
}
|
||||
|
||||
if (src_node->isa<CNode>()) {
|
||||
auto converted_list = ConvertDependNode(src_node);
|
||||
src_ops_list->insert(src_ops_list->end(), converted_list.begin(), converted_list.end());
|
||||
}
|
||||
|
||||
if (dest_node->isa<CNode>()) {
|
||||
auto converted_list = ConvertDependNode(dest_node);
|
||||
dst_ops_list->insert(dst_ops_list->end(), converted_list.begin(), converted_list.end());
|
||||
}
|
||||
if (src_ops_list->empty() || dst_ops_list->empty()) {
|
||||
MS_LOG(DEBUG) << "Control depend node's src or dest node is not a CNode, ignore it";
|
||||
error_ = SUCCESS;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
|
||||
const int SRC_NODE_INDEX = 1;
|
||||
const int DEST_NODE_INDEX = 2;
|
||||
if (control_depend_cache_.find(node.get()) != control_depend_cache_.end()) {
|
||||
return;
|
||||
}
|
||||
auto node_inputs = node->inputs();
|
||||
if (node_inputs.size() <= DEST_NODE_INDEX) {
|
||||
MS_LOG(WARNING) << "Control depend node input size error";
|
||||
return;
|
||||
}
|
||||
auto src_node = node_inputs[SRC_NODE_INDEX];
|
||||
auto dest_node = node_inputs[DEST_NODE_INDEX];
|
||||
if ((src_node == nullptr) || (dest_node == nullptr)) {
|
||||
MS_LOG(ERROR) << "Control depend node miss src or dest node";
|
||||
error_ = FAILED;
|
||||
return;
|
||||
}
|
||||
std::shared_ptr<std::vector<OperatorPtr>> src_ops_list = std::make_shared<std::vector<OperatorPtr>>();
|
||||
std::shared_ptr<std::vector<OperatorPtr>> dst_ops_list = std::make_shared<std::vector<OperatorPtr>>();
|
||||
if (!GetControlDependList(node, src_ops_list, dst_ops_list)) {
|
||||
MS_LOG(ERROR) << "Get depend list failed";
|
||||
error_ = FAILED;
|
||||
return;
|
||||
}
|
||||
std::vector<ControlEdge> control_edges;
|
||||
if (src_ops_list->size() == 1 && dst_ops_list->size() > 1) {
|
||||
(void)std::transform(dst_ops_list->begin(), dst_ops_list->end(), std::back_inserter(control_edges),
|
||||
[src_ops_list](const OperatorPtr &op) -> ControlEdge {
|
||||
return {(*src_ops_list)[0], op};
|
||||
});
|
||||
} else if (src_ops_list->size() > 1 && dst_ops_list->size() == 1) {
|
||||
(void)std::transform(src_ops_list->begin(), src_ops_list->end(), std::back_inserter(control_edges),
|
||||
[dst_ops_list](const OperatorPtr &op) -> ControlEdge {
|
||||
return {op, (*dst_ops_list)[0]};
|
||||
});
|
||||
} else if (src_ops_list->size() == 1 && dst_ops_list->size() == 1) {
|
||||
control_edges.push_back({(*src_ops_list)[0], (*dst_ops_list)[0]});
|
||||
} else if (src_ops_list->empty() || dst_ops_list->empty()) {
|
||||
MS_LOG(DEBUG) << "Depend list of src or dst is empty, ignore it";
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Convert control depend node to operator failed, depend src:" << src_ops_list->size()
|
||||
<< " -> dst:" << dst_ops_list->size();
|
||||
error_ = FAILED;
|
||||
return;
|
||||
}
|
||||
control_depend_cache_[node.get()] = control_edges;
|
||||
|
||||
#ifdef DRAW_GE_GRAPH
|
||||
DrawControlDepend(src_node, dest_node);
|
||||
#endif
|
||||
}
|
||||
|
||||
bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
|
||||
// ignore apply node of return
|
||||
if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() ||
|
||||
|
@ -1818,12 +1637,6 @@ bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node)
|
|||
return false;
|
||||
}
|
||||
|
||||
// ControlDepend
|
||||
if (name == prim::kPrimControlDepend->name()) {
|
||||
ConvertControlDependNode(node);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -145,19 +145,11 @@ class DfGraphConvertor {
|
|||
OperatorPtr ConvertCNode(CNodePtr node);
|
||||
std::vector<OperatorPtr> ConvertDependNode(AnfNodePtr node);
|
||||
AnfNodePtr GetRealOpNode(AnfNodePtr node);
|
||||
std::vector<AnfNodePtr> GetDependNodes(const AnfNodePtr &node);
|
||||
OperatorPtr ConvertParameter(AnfNodePtr node);
|
||||
Status TryConvertValueNodeToMultiConst(const ValueNodePtr node);
|
||||
OperatorPtr ConvertValueNode(ValueNodePtr node);
|
||||
void GetCaseNodeInput(const CNodePtr node, const CNodePtr input_node);
|
||||
void ConvertTupleGetItem(const CNodePtr node);
|
||||
void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node,
|
||||
const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
|
||||
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list);
|
||||
bool GetControlDependList(const CNodePtr &node, const std::shared_ptr<std::vector<OperatorPtr>> &src_ops_list,
|
||||
const std::shared_ptr<std::vector<OperatorPtr>> &dst_ops_list);
|
||||
void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node);
|
||||
void ConvertControlDependNode(const CNodePtr node);
|
||||
void ConvertMakeTuple(const CNodePtr node);
|
||||
bool CheckCNode(const std::string &name, const CNodePtr node);
|
||||
void TraceOutput(AnfNodePtr node);
|
||||
|
@ -195,7 +187,7 @@ class DfGraphConvertor {
|
|||
std::shared_ptr<DfGraph> broadcast_graph_{nullptr};
|
||||
std::unordered_map<AnfNode *, DfGraph> branches_map_;
|
||||
std::unordered_map<AnfNode *, OperatorPtr> op_cache_;
|
||||
std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_depend_cache_;
|
||||
std::unordered_map<AnfNode *, std::vector<ControlEdge>> control_edge_cache_;
|
||||
std::unordered_map<AnfNodePtr, std::set<AnfNodePtr>> monad_control_edge_cache_;
|
||||
/* record "tuple_getitem"<->"out_handler" mapping */
|
||||
std::unordered_map<AnfNode *, OutHandler> out_handle_cache_;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -51,88 +51,6 @@ std::string GetOtherTarget(const std::vector<AnfNodePtr> &nodes) {
|
|||
}
|
||||
return "";
|
||||
}
|
||||
bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
|
||||
std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(prior_node);
|
||||
MS_EXCEPTION_IF_NULL(behind_node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto &node_users = manager->node_users();
|
||||
if (prior_node->isa<Parameter>()) {
|
||||
for (auto &user : node_users[prior_node]) {
|
||||
auto cnode = user.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
prior_nodes->emplace_back(cnode);
|
||||
}
|
||||
}
|
||||
} else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
|
||||
prior_nodes->emplace_back(prior_node);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
if (behind_node->isa<Parameter>()) {
|
||||
for (auto &user : node_users[behind_node]) {
|
||||
auto cnode = user.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
depend_nodes->emplace_back(cnode);
|
||||
}
|
||||
}
|
||||
} else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
|
||||
depend_nodes->emplace_back(behind_node);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
|
||||
std::map<AnfNodePtr, size_t> *nodes_ref) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto input_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
auto prior_node = input_cnode->input(kControlDependPriorIndex);
|
||||
auto depend_node = input_cnode->input(kControlDependBehindIndex);
|
||||
MS_EXCEPTION_IF_NULL(prior_node);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
auto prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(prim_ptr);
|
||||
ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
|
||||
int64_t depend_mode = 0;
|
||||
if (mode_ptr != nullptr) {
|
||||
depend_mode = GetValue<int64_t>(mode_ptr);
|
||||
}
|
||||
if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
|
||||
return;
|
||||
}
|
||||
std::vector<AnfNodePtr> prior_nodes;
|
||||
std::vector<AnfNodePtr> behind_nodes;
|
||||
if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
|
||||
return;
|
||||
}
|
||||
for (auto &first_node : prior_nodes) {
|
||||
for (auto &second_node : behind_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(first_node);
|
||||
MS_EXCEPTION_IF_NULL(second_node);
|
||||
auto iter = control_edges->find(second_node);
|
||||
if (iter == control_edges->end()) {
|
||||
(void)control_edges->insert(
|
||||
std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
|
||||
} else {
|
||||
iter->second.emplace_back(first_node);
|
||||
}
|
||||
auto ref_iter = nodes_ref->find(first_node);
|
||||
if (ref_iter != nodes_ref->end()) {
|
||||
ref_iter->second++;
|
||||
} else {
|
||||
(void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
|
||||
std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
|
||||
|
@ -149,9 +67,6 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (auto &input : cnode->inputs()) {
|
||||
if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
|
||||
AddControlEdge(graph, input, control_edges, nodes_ref);
|
||||
}
|
||||
auto iter = nodes_ref->find(input);
|
||||
if (iter != nodes_ref->end()) {
|
||||
iter->second++;
|
||||
|
@ -479,12 +394,10 @@ void AddSegmentDependency(const FuncGraphPtr &graph, const std::string &default_
|
|||
node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
|
||||
}
|
||||
GraphSegmentPtr node_segment{nullptr};
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
auto node_iter = node_to_segment.find(node);
|
||||
if (node_iter != node_to_segment.end()) {
|
||||
node_segment = node_iter->second;
|
||||
}
|
||||
}
|
||||
for (auto &input : node_inputs) {
|
||||
if (node_segment != nullptr && !node_segment->is_cut_ && input->isa<CNode>()) {
|
||||
GraphSegmentPtr input_segment{nullptr};
|
||||
|
@ -615,18 +528,14 @@ void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std::
|
|||
std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment,
|
||||
const std::set<AnfNodePtr> &dynamic_nodes_set) {
|
||||
SplitDynamicNodesHelper helper;
|
||||
bool is_last_node_dynamic = false;
|
||||
for (auto &node : segment_nodes) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
|
||||
helper.AddNode(node, is_last_node_dynamic);
|
||||
continue;
|
||||
}
|
||||
auto &inputs = cnode->inputs();
|
||||
bool has_dynamic_shape = dynamic_nodes_set.find(node) != dynamic_nodes_set.end();
|
||||
bool depend_common_node = false;
|
||||
bool depend_dynamic_node = false;
|
||||
bool is_last_node_dynamic = false;
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (dynamic_nodes_set.find(inputs[i]) != dynamic_nodes_set.end()) {
|
||||
has_dynamic_shape = true;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -87,26 +87,7 @@ AnfNodePtr RefSubGraphNode(const FuncGraphPtr &fg, const AnfNodePtr &node, AnfNo
|
|||
if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
|
||||
eqv[node] = node;
|
||||
} else if (eqv.find(node) == eqv.end()) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimControlDepend)) {
|
||||
eqv[node] = NewValueNode(MakeValue(0));
|
||||
return eqv[node];
|
||||
}
|
||||
bool ignore_make_tuple = false;
|
||||
if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
|
||||
ignore_make_tuple = true;
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &node_inputs = cnode->inputs();
|
||||
for (size_t i = 1; i < node_inputs.size(); ++i) {
|
||||
if (!IsPrimitiveCNode(node_inputs[i], prim::kPrimControlDepend)) {
|
||||
ignore_make_tuple = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!ignore_make_tuple) {
|
||||
inputs.push_back(node);
|
||||
}
|
||||
eqv[node] = fg->add_parameter();
|
||||
eqv[node]->set_abstract(node->abstract());
|
||||
eqv[node]->set_kernel_info(node->kernel_info_ptr());
|
||||
|
@ -148,14 +129,6 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
|
|||
for (size_t i = 2; i < inps.size(); ++i) {
|
||||
args.emplace_back(NewValueNode(MakeValue(0)));
|
||||
}
|
||||
} else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
|
||||
for (size_t i = 1; i < inps.size(); ++i) {
|
||||
if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
|
||||
args.emplace_back(NewValueNode(MakeValue(static_cast<int>(i))));
|
||||
} else {
|
||||
args.emplace_back(RefSubGraphNode(fg, inps[i], &inputs, &eqv));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
(void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args),
|
||||
[&fg, &inputs, &eqv](const AnfNodePtr &a) { return RefSubGraphNode(fg, a, &inputs, &eqv); });
|
||||
|
|
|
@ -182,8 +182,6 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -188,23 +188,6 @@ AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitiveP
|
|||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// args: Two objects of a subclass of AbstractBase
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 2);
|
||||
auto arg_src = args_spec_list[0];
|
||||
auto arg_dst = args_spec_list[1];
|
||||
// control depend can not setup tuple of ops to tuple of ops dependency relation
|
||||
if (arg_src->isa<AbstractTuple>() && arg_dst->isa<AbstractTuple>()) {
|
||||
auto src_size = arg_src->cast<AbstractTuplePtr>()->size();
|
||||
auto dst_size = arg_src->cast<AbstractTuplePtr>()->size();
|
||||
if (src_size > 1 && dst_size > 1) {
|
||||
MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple";
|
||||
}
|
||||
}
|
||||
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
|
|
|
@ -149,7 +149,6 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimStateSetItem, {InferImplStateSetItem, nullptr, true}},
|
||||
{prim::kPrimDepend, {InferImplDepend, nullptr, true}},
|
||||
{prim::kPrimUpdateState, {InferImplUpdateState, nullptr, true}},
|
||||
{prim::kPrimControlDepend, {InferImplControlDepend, nullptr, true}},
|
||||
// Debug
|
||||
{prim::kPrimDebug, {InferImplDebug, nullptr, true}},
|
||||
// Dynamic shape testing
|
||||
|
|
|
@ -453,7 +453,6 @@ inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookB
|
|||
inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
|
||||
inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
|
||||
inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
|
||||
inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
|
||||
inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
|
||||
inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
|
||||
inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -399,8 +399,7 @@ std::string GetAttrTarget(const PrimitivePtr &primitive, const ValuePtr &att_tar
|
|||
if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
|
||||
IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
|
||||
IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) ||
|
||||
IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) ||
|
||||
IsPrimitive(attr_input, prim::kPrimPartial)) {
|
||||
IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) {
|
||||
primitive->EraseAttr(primitive_target);
|
||||
return default_target;
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,7 +23,6 @@
|
|||
#include "ir/dtype/tensor_type.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/control_depend.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,7 +23,6 @@
|
|||
#include "ops/expand_dims.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "ops/control_depend.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,7 +30,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem,
|
|||
"identity", "partial", "env_setitem", "env_getitem", "env_add", "MakeRefKey", "make_ref", "get_ref_key",
|
||||
"get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary",
|
||||
"ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs",
|
||||
"InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed",
|
||||
"stop_gradient", "Send", "UpdateState", "Load"};
|
||||
// clang-format on
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -27,7 +27,6 @@
|
|||
#include "abstract/abstract_value.h"
|
||||
#include "mindspore/core/ir/primitive.h"
|
||||
#include "ops/fusion/partial_fusion.h"
|
||||
#include "ops/control_depend.h"
|
||||
#include "ops/depend.h"
|
||||
#include "ops/make_tuple.h"
|
||||
#include "ops/quant_dtype_cast.h"
|
||||
|
@ -213,8 +212,7 @@ void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
|
|||
MS_LOG(ERROR) << "value node is invalid.";
|
||||
return;
|
||||
}
|
||||
if (value_node->value() != nullptr && (opt::CheckPrimitiveType(depend_node, prim::kPrimDepend) ||
|
||||
opt::CheckPrimitiveType(depend_node, prim::kPrimControlDepend))) {
|
||||
if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
|
||||
has_depend = true;
|
||||
bool mask_out = (depend_node->inputs().size() == 3);
|
||||
for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
|
||||
|
@ -466,8 +464,8 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr<sc
|
|||
}
|
||||
|
||||
RemoveIfDepend(cnode);
|
||||
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameControlDepend ||
|
||||
prim->name() == mindspore::ops::kNameTupleGetItem || prim->name() == mindspore::ops::kNameMakeTuple) {
|
||||
if (prim->name() == mindspore::ops::kNameDepend || prim->name() == mindspore::ops::kNameTupleGetItem ||
|
||||
prim->name() == mindspore::ops::kNameMakeTuple) {
|
||||
continue;
|
||||
}
|
||||
if (prim->name() == "make_tuple") {
|
||||
|
|
|
@ -57,8 +57,8 @@ bool IsRealKernel(const AnfNodePtr &node) {
|
|||
IsPrimitive(input, prim::kPrimTensorSummary) ||
|
||||
IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
|
||||
IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
|
||||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
|
||||
IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial);
|
||||
IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimReturn) ||
|
||||
IsPrimitive(input, prim::kPrimPartial);
|
||||
return !is_virtual_node;
|
||||
}
|
||||
|
||||
|
|
|
@ -43,8 +43,8 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
|
|||
|
||||
bool IsSpecialType(const CNodePtr &cnode) {
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
|
||||
CheckPrimitiveType(cnode, prim::kPrimControlDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) ||
|
||||
CheckPrimitiveType(cnode, prim::kPrimReturn) || CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
|
||||
CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) ||
|
||||
CheckPrimitiveType(cnode, std::make_shared<Primitive>("While")) ||
|
||||
CheckPrimitiveType(cnode, std::make_shared<Primitive>("If"))) {
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -58,13 +58,6 @@ int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraph
|
|||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
}
|
||||
if (CheckPrimitiveType(anf_node, prim::kPrimControlDepend)) {
|
||||
if (cnode->size() != InputDoubleNum) {
|
||||
MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
|
||||
remove_cnode_.insert(anf_node);
|
||||
return lite::RET_NO_CHANGE;
|
||||
}
|
||||
}
|
||||
|
||||
bool replace_succ = manager->Replace(anf_node, cnode->input(1));
|
||||
if (!replace_succ) {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -96,7 +96,6 @@ class GPT2FinetuneCell(nn.Cell):
|
|||
self.get_status = P.NPUGetFloatStatus()
|
||||
self.clear_before_grad = P.NPUClearFloatStatus()
|
||||
self.reduce_sum = P.ReduceSum(keep_dims=False)
|
||||
self.depend_parameter_use = P.ControlDepend(depend_mode=1)
|
||||
self.base = Tensor(1, mstype.float32)
|
||||
self.less_equal = P.LessEqual()
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
@ -132,8 +131,8 @@ class GPT2FinetuneCell(nn.Cell):
|
|||
|
||||
if not self.gpu_target:
|
||||
init = self.alloc_status()
|
||||
init = F.depend(init, loss)
|
||||
clear_before_grad = self.clear_before_grad(init)
|
||||
F.control_depend(loss, init)
|
||||
self.depend_parameter_use(clear_before_grad, scaling_sens)
|
||||
grads = self.grad(self.network, weights)(input_ids,
|
||||
input_mask,
|
||||
|
@ -145,10 +144,10 @@ class GPT2FinetuneCell(nn.Cell):
|
|||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
if not self.gpu_target:
|
||||
init = F.depend(init, grads)
|
||||
flag = self.get_status(init)
|
||||
init = F.depend(init, flag)
|
||||
flag_sum = self.reduce_sum(init, (0,))
|
||||
F.control_depend(grads, flag)
|
||||
F.control_depend(flag, flag_sum)
|
||||
else:
|
||||
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
|
||||
flag_sum = self.addn(flag_sum)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -74,9 +74,9 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl
|
|||
/*
|
||||
* def before(x, y, a, b):
|
||||
* z = make_tuple(TransData(a), TransData(b))
|
||||
* depend_intput = control_depend(y, z)
|
||||
* sum = add(x, depend_intput)
|
||||
* return sum
|
||||
* depend_intput = depend(y, z)
|
||||
* sum_add = add(x, depend_intput)
|
||||
* return sum_add
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence_with_make_tuple", "before");
|
||||
|
||||
|
@ -93,11 +93,11 @@ TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence_with_make_tupl
|
|||
|
||||
TEST_F(TestHWOptimizeDependence, test_optimize_control_dependence) {
|
||||
/*
|
||||
* def before(x, y, a, b):
|
||||
* z = make_tuple(TransData(a), TransData(b))
|
||||
* depend_intput = control_depend(y, z)
|
||||
* sum = add(x, depend_intput)
|
||||
* return sum
|
||||
* def before(x, y, z):
|
||||
* new_z = TransData(z)
|
||||
* depend_intput = depend(y, new_z)
|
||||
* sum_add = add(x, depend_intput)
|
||||
* return sum_add
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_optimize_control_dependence", "before");
|
||||
|
||||
|
|
Loading…
Reference in New Issue