From c8899e1f5951509752d3748f7defd69068338409 Mon Sep 17 00:00:00 2001 From: lanzhineng Date: Mon, 8 Mar 2021 14:28:48 +0800 Subject: [PATCH] fix 310 maskrcnn infer failed in halway --- mindspore/ccsrc/transform/graph_ir/convert.cc | 122 ++++++++++++------ mindspore/ccsrc/transform/graph_ir/convert.h | 1 + 2 files changed, 83 insertions(+), 40 deletions(-) diff --git a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index b64ef4a89dc..6f802649f29 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -53,6 +53,37 @@ using Constant = ge::op::Constant; using Assign = ge::op::Assign; using Data = ge::op::Data; +namespace { +std::vector GetOrderedCNodes(const FuncGraphPtr fg) { + auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1); + auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector { + std::vector vecs; + if (node == nullptr) { + return vecs; + } + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + // Check if free variables used. + for (const auto &input : inputs) { + auto input_fg = GetValueNode(input); + if (input_fg) { + for (auto &fv : input_fg->free_variables_nodes()) { + if (fv->func_graph() == fg && fg->nodes().contains(fv)) { + vecs.push_back(fv); + } + } + } + } + (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + } + return vecs; + }; + + return TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph); +} +} // namespace + // ---------------implement of DfGraphConvertor------------- PrimType GetCNodeFuncType(const CNodePtr cnode) { if (cnode->inputs().empty()) { @@ -214,7 +245,7 @@ void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfN void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input) { DfGraphPtr init_graph = std::make_shared("init"); - std::vector nodes = TopoSort(anf_graph_->get_return()); + std::vector nodes = GetOrderedCNodes(anf_graph_); for (auto &it : nodes) { if (it->isa()) { @@ -549,7 +580,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { // Convert all anf node to Operator MS_LOG(DEBUG) << "convert all node"; - std::vector nodes = TopoSort(anf_graph_->get_return()); + std::vector nodes = GetOrderedCNodes(anf_graph_); for (auto &it : nodes) { (void)Convert(it); if (this->error_ != 0) { @@ -811,7 +842,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { } // Case node set input. - std::vector nodes = ::mindspore::TopoSort(anf_graph_->get_return()); + std::vector nodes = GetOrderedCNodes(anf_graph_); for (auto &it : nodes) { if (it->isa() && IsCaseNode(it->cast())) { auto node = it->cast(); @@ -825,7 +856,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { // set up dependencies MS_LOG(DEBUG) << "set up dependencies"; - nodes = ::mindspore::TopoSort(anf_graph_->get_return()); + nodes = GetOrderedCNodes(anf_graph_); for (auto &it : nodes) { SetNodeInput(it); SetOpControlInput(it); @@ -1195,6 +1226,51 @@ void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr } MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString(); } +AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) { + if (input == nullptr || node == nullptr) { + return nullptr; + } + AnfNodePtr pred = input; + while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == prim::kPrimDepend->name()) { + pred = pred->cast()->input(1); + } + + // skip input of UMonad, IOMonad + if (IsValueNode(pred) || IsValueNode(pred)) { + return nullptr; + } + + // skip input of the None, UpdateState + if (IsValueNode(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) { + return nullptr; + } + + if (IsPrimitiveCNode(pred, prim::kPrimLoad)) { + pred = ParseLoadInput(pred->cast()); + } + + // transform "Const" op to "Variable" op when the next node is "Assign" op. + std::string c_name = GetCNodeTargetFuncName(node); + auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); + if (!training_ && pos != trans_var_list.end() && pred->isa()) { + std::string name = std::static_pointer_cast(pred)->name(); + auto op_itor = op_cache_.find(pred.get()); + if (op_itor == op_cache_.end()) { + MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; + } + if (op_itor->second != nullptr && + (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && + vars_.find(name) != vars_.end()) { + auto variable = std::make_shared(name); + auto desc = vars_[name]->GetOutputDesc("y"); + (void)variable->update_output_desc_y(desc); + MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; + op_itor->second = variable; // replace parameter with variable + vars_[name] = variable; + } + } + return pred; +} void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { OperatorPtr src = Convert(node); @@ -1213,45 +1289,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node } else { pred = inputs[i]; } - - while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == prim::kPrimDepend->name()) { - pred = pred->cast()->input(1); - } - - // skip input of UMonad, IOMonad - if (IsValueNode(pred) || IsValueNode(pred)) { + pred = GetRealInputNode(node, pred); + if (pred == nullptr) { continue; } - // skip input of the None, Load, UpdateState - if (IsValueNode(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) { - continue; - } - - if (IsPrimitiveCNode(pred, prim::kPrimLoad)) { - pred = ParseLoadInput(pred->cast()); - } - - // transform "Const" op to "Variable" op when the next node is "Assign" op. - std::string c_name = GetCNodeTargetFuncName(node); - auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); - if (!training_ && pos != trans_var_list.end() && pred->isa()) { - std::string name = std::static_pointer_cast(pred)->name(); - auto op_itor = op_cache_.find(pred.get()); - if (op_itor == op_cache_.end()) { - MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << "."; - } - if (op_itor->second != nullptr && - (op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") && - vars_.find(name) != vars_.end()) { - auto variable = std::make_shared(name); - auto desc = vars_[name]->GetOutputDesc("y"); - (void)variable->update_output_desc_y(desc); - MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << "."; - op_itor->second = variable; // replace parameter with variable - vars_[name] = variable; - } - } int index = SizeToInt(i); // find in out_hadnle_cache_ first auto it = out_handle_cache_.find(pred.get()); diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h index c40efd876af..9789dd9467c 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.h +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -185,6 +185,7 @@ class DfGraphConvertor { void SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, const OperatorPtr &src, int index); void UpdateTupleOutCache(void); + AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input); std::shared_ptr anf_graph_{nullptr}; std::shared_ptr df_graph_{nullptr};