forked from mindspore-Ecosystem/mindspore
!12925 [MS][Maskrcnn][310infer]Maskrcnn 310 infer failed in halfway
From: @lanzhineng Reviewed-by: Signed-off-by:
This commit is contained in:
commit
2b2964e6dd
|
@ -53,6 +53,37 @@ using Constant = ge::op::Constant;
|
|||
using Assign = ge::op::Assign;
|
||||
using Data = ge::op::Data;
|
||||
|
||||
namespace {
|
||||
std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
||||
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
||||
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
||||
std::vector<AnfNodePtr> vecs;
|
||||
if (node == nullptr) {
|
||||
return vecs;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto &inputs = cnode->inputs();
|
||||
// Check if free variables used.
|
||||
for (const auto &input : inputs) {
|
||||
auto input_fg = GetValueNode<FuncGraphPtr>(input);
|
||||
if (input_fg) {
|
||||
for (auto &fv : input_fg->free_variables_nodes()) {
|
||||
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
|
||||
vecs.push_back(fv);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
|
||||
}
|
||||
return vecs;
|
||||
};
|
||||
|
||||
return TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// ---------------implement of DfGraphConvertor-------------
|
||||
PrimType GetCNodeFuncType(const CNodePtr cnode) {
|
||||
if (cnode->inputs().empty()) {
|
||||
|
@ -214,7 +245,7 @@ void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfN
|
|||
|
||||
void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input) {
|
||||
DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
|
||||
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return());
|
||||
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
||||
|
||||
for (auto &it : nodes) {
|
||||
if (it->isa<ValueNode>()) {
|
||||
|
@ -549,7 +580,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
|
|||
|
||||
// Convert all anf node to Operator
|
||||
MS_LOG(DEBUG) << "convert all node";
|
||||
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return());
|
||||
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
||||
for (auto &it : nodes) {
|
||||
(void)Convert(it);
|
||||
if (this->error_ != 0) {
|
||||
|
@ -811,7 +842,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|||
}
|
||||
|
||||
// Case node set input.
|
||||
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
||||
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
||||
for (auto &it : nodes) {
|
||||
if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
|
||||
auto node = it->cast<CNodePtr>();
|
||||
|
@ -825,7 +856,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|||
|
||||
// set up dependencies
|
||||
MS_LOG(DEBUG) << "set up dependencies";
|
||||
nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
||||
nodes = GetOrderedCNodes(anf_graph_);
|
||||
for (auto &it : nodes) {
|
||||
SetNodeInput(it);
|
||||
SetOpControlInput(it);
|
||||
|
@ -1195,6 +1226,51 @@ void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr
|
|||
}
|
||||
MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString();
|
||||
}
|
||||
AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) {
|
||||
if (input == nullptr || node == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr pred = input;
|
||||
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
|
||||
pred = pred->cast<CNodePtr>()->input(1);
|
||||
}
|
||||
|
||||
// skip input of UMonad, IOMonad
|
||||
if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// skip input of the None, UpdateState
|
||||
if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
|
||||
pred = ParseLoadInput(pred->cast<CNodePtr>());
|
||||
}
|
||||
|
||||
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
||||
std::string c_name = GetCNodeTargetFuncName(node);
|
||||
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
||||
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
||||
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
||||
auto op_itor = op_cache_.find(pred.get());
|
||||
if (op_itor == op_cache_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
|
||||
}
|
||||
if (op_itor->second != nullptr &&
|
||||
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
|
||||
vars_.find(name) != vars_.end()) {
|
||||
auto variable = std::make_shared<Variable>(name);
|
||||
auto desc = vars_[name]->GetOutputDesc("y");
|
||||
(void)variable->update_output_desc_y(desc);
|
||||
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
|
||||
op_itor->second = variable; // replace parameter with variable
|
||||
vars_[name] = variable;
|
||||
}
|
||||
}
|
||||
return pred;
|
||||
}
|
||||
|
||||
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
||||
OperatorPtr src = Convert(node);
|
||||
|
@ -1213,45 +1289,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
|||
} else {
|
||||
pred = inputs[i];
|
||||
}
|
||||
|
||||
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
|
||||
pred = pred->cast<CNodePtr>()->input(1);
|
||||
}
|
||||
|
||||
// skip input of UMonad, IOMonad
|
||||
if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
|
||||
pred = GetRealInputNode(node, pred);
|
||||
if (pred == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// skip input of the None, Load, UpdateState
|
||||
if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
|
||||
pred = ParseLoadInput(pred->cast<CNodePtr>());
|
||||
}
|
||||
|
||||
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
||||
std::string c_name = GetCNodeTargetFuncName(node);
|
||||
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
||||
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
||||
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
||||
auto op_itor = op_cache_.find(pred.get());
|
||||
if (op_itor == op_cache_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
|
||||
}
|
||||
if (op_itor->second != nullptr &&
|
||||
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
|
||||
vars_.find(name) != vars_.end()) {
|
||||
auto variable = std::make_shared<Variable>(name);
|
||||
auto desc = vars_[name]->GetOutputDesc("y");
|
||||
(void)variable->update_output_desc_y(desc);
|
||||
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
|
||||
op_itor->second = variable; // replace parameter with variable
|
||||
vars_[name] = variable;
|
||||
}
|
||||
}
|
||||
int index = SizeToInt(i);
|
||||
// find in out_hadnle_cache_ first
|
||||
auto it = out_handle_cache_.find(pred.get());
|
||||
|
|
|
@ -185,6 +185,7 @@ class DfGraphConvertor {
|
|||
void SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, const OperatorPtr &src,
|
||||
int index);
|
||||
void UpdateTupleOutCache(void);
|
||||
AnfNodePtr GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input);
|
||||
|
||||
std::shared_ptr<AnfGraph> anf_graph_{nullptr};
|
||||
std::shared_ptr<DfGraph> df_graph_{nullptr};
|
||||
|
|
Loading…
Reference in New Issue