forked from mindspore-Ecosystem/mindspore
!5374 clean code for graph kernel module.
Merge pull request !5374 from chenlei_autodiff/clean_code_graph_kernel
This commit is contained in:
commit
8cea881642
|
@ -73,54 +73,44 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode, bool is_before_ker
|
|||
return used_nodes;
|
||||
}
|
||||
|
||||
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) {
|
||||
void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &users,
|
||||
std::vector<CNodePtr> *control_depend_nodes, std::vector<size_t> *control_depend_use_index,
|
||||
bool *is_only_control_depend_use, AnfNodePtr *use_out) {
|
||||
for (auto &user : users) {
|
||||
auto use_node = user.first;
|
||||
if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
|
||||
*is_only_control_depend_use = false;
|
||||
continue;
|
||||
}
|
||||
if (outputs_set.count(use_node) != 0) {
|
||||
*use_out = use_node;
|
||||
}
|
||||
if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) {
|
||||
control_depend_nodes->push_back(use_node->cast<CNodePtr>());
|
||||
control_depend_use_index->push_back(user.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv) {
|
||||
AnfNodeSet outputs_set;
|
||||
for (auto out : *outputs) {
|
||||
outputs_set.insert(out);
|
||||
}
|
||||
|
||||
AnfNodePtrList vir_outputs;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
|
||||
auto fg_outputs = fg->output();
|
||||
if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) {
|
||||
auto cnode = fg_outputs->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
vir_outputs.push_back(cnode->input(i));
|
||||
}
|
||||
} else {
|
||||
vir_outputs.push_back(fg_outputs);
|
||||
}
|
||||
|
||||
if (vir_outputs.size() != outputs->size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
|
||||
}
|
||||
bool has_erase_outs = false;
|
||||
size_t index = -1;
|
||||
for (auto it = outputs->begin(); it != outputs->end();) {
|
||||
index++;
|
||||
auto out = *it;
|
||||
eqv[out] = vir_outputs[index];
|
||||
(*eqv)[out] = vir_outputs[index];
|
||||
auto users = mng->node_users()[out];
|
||||
bool is_only_control_depend_use = true;
|
||||
std::vector<size_t> control_depend_use_index;
|
||||
std::vector<CNodePtr> control_depend_nodes;
|
||||
AnfNodePtr use_out = nullptr;
|
||||
for (auto &user : users) {
|
||||
auto use_node = user.first;
|
||||
if (outputs_set.count(use_node) == 0 && !(IsPrimitiveCNode(use_node, prim::kPrimControlDepend))) {
|
||||
is_only_control_depend_use = false;
|
||||
continue;
|
||||
}
|
||||
if (outputs_set.count(use_node) != 0) {
|
||||
use_out = use_node;
|
||||
}
|
||||
|
||||
if (IsPrimitiveCNode(use_node, prim::kPrimControlDepend)) {
|
||||
control_depend_nodes.push_back(use_node->cast<CNodePtr>());
|
||||
control_depend_use_index.push_back(user.second);
|
||||
}
|
||||
}
|
||||
|
||||
SearchForDependNode(outputs_set, users, &control_depend_nodes, &control_depend_use_index,
|
||||
&is_only_control_depend_use, &use_out);
|
||||
if (is_only_control_depend_use && !control_depend_nodes.empty()) {
|
||||
MS_EXCEPTION_IF_NULL(use_out);
|
||||
it = outputs->erase(it);
|
||||
|
@ -142,8 +132,27 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
|
|||
it++;
|
||||
}
|
||||
}
|
||||
return has_erase_outs;
|
||||
}
|
||||
|
||||
if (!has_erase_outs) {
|
||||
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) {
|
||||
AnfNodePtrList vir_outputs;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
|
||||
auto fg_outputs = fg->output();
|
||||
if (IsPrimitiveCNode(fg_outputs, prim::kPrimMakeTuple)) {
|
||||
auto cnode = fg_outputs->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
vir_outputs.push_back(cnode->input(i));
|
||||
}
|
||||
} else {
|
||||
vir_outputs.push_back(fg_outputs);
|
||||
}
|
||||
|
||||
if (vir_outputs.size() != outputs->size()) {
|
||||
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
|
||||
}
|
||||
|
||||
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue