From d4671497e91e865fbfe0e8fd7bfc971c364276c7 Mon Sep 17 00:00:00 2001 From: yeyunpeng Date: Mon, 10 Aug 2020 12:41:49 +0800 Subject: [PATCH] fix op multi output problem --- mindspore/lite/schema/model.fbs | 4 +- mindspore/lite/schema/ops.fbs | 6 + .../src/common/anf_exporter/anf_exporter.cc | 130 ++++++++++-------- .../src/common/anf_exporter/anf_exporter.h | 5 +- .../anf_importer/import_from_meta_graphT.cc | 45 ++++-- 5 files changed, 119 insertions(+), 71 deletions(-) diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index cb1bad7031d..6d00ff472e6 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -189,7 +189,9 @@ union PrimitiveType { ActivationGrad, PriorBox, SpaceToBatchND, - TopKV2 + TopKV2, + Return, + MakeTuple } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index a7fd9a5bf58..27b3b942d57 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -864,3 +864,9 @@ table TopKV2 { sorted : bool = true; } + +table MakeTuple { +} + +table Return { +} \ No newline at end of file diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index d9653a9e0d4..db50b02825a 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -81,8 +81,7 @@ bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { return false; } ValueNodePtr valueNode = utils::cast(indexNode); - mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = - GetValue(valueNode->value()); + mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = GetValue(valueNode->value()); } else { inputs.emplace_back(cnode->input(i)); } @@ -114,16 +113,34 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { auto metaGraphT = std::make_unique(); for (const auto &cnode : cnodes) { auto primitive = GetValueNode(cnode->input(0)); - if (primitive != nullptr && - RemoveNodeInAnfExporter.count(primitive->name()) != 0) { - continue; + if (primitive != nullptr) { + if (RemoveNodeInAnfExporter.count(primitive->name()) != 0) { + continue; + } + } else { + auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primT = primitiveT_value->GetPrimitiveT(); + if (primT->value.type == schema::PrimitiveType_TupleGetItem || + primT->value.type == schema::PrimitiveType_MakeTuple) { + continue; + } } mapRemoveGetItem_.clear(); RemoveIfMakeTuple(cnode); RemoveIfTupleGetItem(cnode); - if (primitive != nullptr && primitive->name() == prim::kPrimReturn->name()) { - AddOutPutIfReturn(metaGraphT, cnode); - continue; + + if (primitive != nullptr) { + if (primitive->name() == prim::kPrimReturn->name()) { + AddOutPutIfReturn(metaGraphT, cnode); + continue; + } + } else { + auto primitiveT_value = GetValueNode>(cnode->input(0)); + auto primT = primitiveT_value->GetPrimitiveT(); + if (primT->value.type == schema::PrimitiveType_Return) { + AddOutPutIfReturn(metaGraphT, cnode); + continue; + } } auto node = std::make_unique(); @@ -134,27 +151,24 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { primitive = GetValueNode(cnode->input(0)); MS_ASSERT(primitive != nullptr); std::string opType = primitive->name(); - auto nodeParser = - AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); if (nodeParser == nullptr) { MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; return nullptr; } std::vector outputs; if (utils::isa(cnode->abstract())) { - auto abstract_cnode = - utils::cast(cnode->abstract()); + auto abstract_cnode = utils::cast(cnode->abstract()); outputs.resize(abstract_cnode->size()); } nodeParser->Parse(cnode, node.get(), &outputs); SetOpInputNode(cnode, metaGraphT.get(), node.get()); - SetOpOutputNode(outputs, metaGraphT.get(), node.get()); + SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); metaGraphT->nodes.emplace_back(std::move(node)); continue; } - auto primitiveT_value = - GetValueNode>(cnode->input(0)); + auto primitiveT_value = GetValueNode>(cnode->input(0)); if (primitiveT_value == nullptr) { MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; return nullptr; @@ -166,11 +180,10 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { return nullptr; } - node->primitive = - std::unique_ptr(primitiveT_value->GetPrimitiveT()); + node->primitive = std::unique_ptr(primitiveT_value->GetPrimitiveT()); std::vector outputs; SetOpInputNode(cnode, metaGraphT.get(), node.get()); - SetOpOutputNode(outputs, metaGraphT.get(), node.get()); + SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); // add quant param node->quantType = primitiveT_value->GetQuantType(); @@ -244,9 +257,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { return metaGraphT.release(); } -void AnfExporter::SetOpInputNode(const CNodePtr &cnode, - schema::MetaGraphT *meta_graph, - schema::CNodeT *fbNode) { +void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) { MS_ASSERT(nullptr != meta_graph); MS_ASSERT(nullptr != fbNode); if (cnode->inputs().size() <= 1) { @@ -281,38 +292,30 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, auto paramTensor = std::make_unique(); auto abstractBase = paramNode->abstract(); if (abstractBase == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " - << paramNode->name(); + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); MS_ASSERT(false); return; } if (!utils::isa(abstractBase)) { - MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " - << paramNode->name(); + MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); MS_ASSERT(false); return; } - auto abstractTensor = - utils::cast(abstractBase); + auto abstractTensor = utils::cast(abstractBase); auto typePtr = abstractTensor->element()->GetTypeTrack(); MS_ASSERT(typePtr != nullptr); paramTensor->dataType = typePtr->type_id(); if (!utils::isa(abstractTensor->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " - << paramNode->name(); + MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); MS_ASSERT(false); return; } - paramTensor->dims = - utils::cast(abstractTensor->BuildShape()) - ->shape(); - auto paramValue = - std::dynamic_pointer_cast(paramNode->default_param()); + paramTensor->dims = utils::cast(abstractTensor->BuildShape())->shape(); + auto paramValue = std::dynamic_pointer_cast(paramNode->default_param()); if (paramValue != nullptr) { paramTensor->nodeType = schema::NodeType_ValueNode; paramTensor->data.resize(paramValue->tensor_size()); - memcpy(paramTensor->data.data(), paramValue->tensor_addr(), - paramValue->tensor_size()); + memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); for (auto &ite : paramValue->quant_param()) { auto quantPar = std::make_unique(); quantPar->scale = ite->scale; @@ -326,8 +329,7 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, paramTensor->dataType = paramValue->tensor_type(); } } - nodeIdMap[paramNode->fullname_with_scope()] = - meta_graph->allTensors.size(); + nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); meta_graph->allTensors.emplace_back(std::move(paramTensor)); } else if (inputNode->isa()) { @@ -336,19 +338,15 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, auto value = valueNode->value(); if (value->isa()) { auto valueAbstract = valueNode->abstract(); - auto abstractTensor = - utils::cast(valueAbstract); + auto abstractTensor = utils::cast(valueAbstract); auto typePtr = abstractTensor->element()->GetTypeTrack(); paramTensor->dataType = typePtr->type_id(); - paramTensor->dims = - utils::cast(abstractTensor->BuildShape()) - ->shape(); + paramTensor->dims = utils::cast(abstractTensor->BuildShape())->shape(); paramTensor->nodeType = schema::NodeType_ValueNode; auto data = value->cast(); paramTensor->data.resize(data->Size()); memcpy(paramTensor->data.data(), data->Data(), data->Size()); - nodeIdMap[valueNode->fullname_with_scope()] = - meta_graph->allTensors.size(); + nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); meta_graph->allTensors.emplace_back(std::move(paramTensor)); } else if (value->isa()) { @@ -376,30 +374,44 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, } } -void AnfExporter::SetOpOutputNode( - const std::vector &outputTensors, - schema::MetaGraphT *graph, schema::CNodeT *cnode) { +void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector &outputTensors, + schema::MetaGraphT *graph, schema::CNodeT *fbnode) { MS_ASSERT(nullptr != graph); - MS_ASSERT(nullptr != cnode); - std::string cnodeName = cnode->name; + MS_ASSERT(nullptr != fbnode); + std::string cnodeName = fbnode->name; if (!outputTensors.empty()) { int i = 0; for (auto outputTensor : outputTensors) { std::string name = cnodeName + "_o:" + std::to_string(i); - auto msTensor = new schema::TensorT(); - msTensor->nodeType = schema::NodeType_Parameter; nodeIdMap[name] = graph->allTensors.size(); - cnode->outputIndex.emplace_back(graph->allTensors.size()); - graph->allTensors.emplace_back(msTensor); + fbnode->outputIndex.emplace_back(graph->allTensors.size()); + graph->allTensors.emplace_back(outputTensor); i++; } return; } - auto msTensor = new schema::TensorT(); - msTensor->nodeType = schema::NodeType_Parameter; - cnode->outputIndex.emplace_back(graph->allTensors.size()); - nodeIdMap[cnodeName] = graph->allTensors.size(); - graph->allTensors.emplace_back(msTensor); + + if (utils::isa(cnode->abstract())) { + auto tuple = std::reinterpret_pointer_cast(cnode->abstract()); + for (int i = 0; i < tuple->size(); i++) { + auto msTensor = new schema::TensorT(); + msTensor->nodeType = schema::NodeType_Parameter; + fbnode->outputIndex.emplace_back(graph->allTensors.size()); + if (tuple->size() == 1) { + nodeIdMap[cnodeName] = graph->allTensors.size(); + } else { + std::string name = cnodeName + "_o:" + std::to_string(i); + nodeIdMap[name] = graph->allTensors.size(); + } + graph->allTensors.emplace_back(msTensor); + } + } else { + auto msTensor = new schema::TensorT(); + msTensor->nodeType = schema::NodeType_Parameter; + fbnode->outputIndex.emplace_back(graph->allTensors.size()); + nodeIdMap[cnodeName] = graph->allTensors.size(); + graph->allTensors.emplace_back(msTensor); + } } schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) { diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.h b/mindspore/lite/src/common/anf_exporter/anf_exporter.h index 8cb04e9d72d..c4f5d7a3989 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.h +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.h @@ -32,8 +32,8 @@ class AnfExporter { AnfExporter() = default; virtual ~AnfExporter() = default; schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); - void SetOpOutputNode(const std::vector &outputTensors, schema::MetaGraphT *graph, - schema::CNodeT *cnode); + void SetOpOutputNode(const CNodePtr &cnode, const std::vector &outputTensors, + schema::MetaGraphT *graph, schema::CNodeT *fbnode); void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode); void RemoveIfMakeTuple(const CNodePtr &cnode); bool RemoveIfTupleGetItem(const CNodePtr &cnode); @@ -47,4 +47,3 @@ class AnfExporter { schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ - diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc index c470d6a6e30..703e0b0715f 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc @@ -71,11 +71,11 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { for (size_t i = 0; i < meta_graph_->nodes.size(); i++) { auto &cNode = meta_graph_->nodes.at(i); MS_EXCEPTION_IF_NULL(cNode); - auto tensor_id = cNode->outputIndex.front(); - if (nullptr != GetNode(tensor_id)) { - continue; - } + bool flag = false; + if (cNode->outputIndex.size() > 1) { + flag = true; + } auto primTValue = std::make_shared(cNode->primitive.release()); cNode->primitive = nullptr; auto value_node = NewValueNode(primTValue); @@ -90,9 +90,39 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { // todo: CheckInputNodeType, the first node should be op; op_inputs.push_back(node); } - auto cnode = func_graph_->NewCNode(op_inputs); - cnode->set_fullname_with_scope(cNode->name); - AddNode(tensor_id, cnode); + + auto new_cnode = func_graph_->NewCNode(op_inputs); + new_cnode->set_fullname_with_scope(cNode->name); + + std::vector out_tensor_ids = cNode->outputIndex; + + AbstractBasePtrList ptr_list; + int total = 0; + for (auto out_tensor_id : out_tensor_ids) { + if (nullptr != GetNode(out_tensor_id)) { + ptr_list.push_back(GetNode(out_tensor_id)->abstract()); + continue; + } + std::vector shape; + auto &tensor = meta_graph_->allTensors.at(out_tensor_id); + for (int &dim : tensor->dims) { + shape.push_back(dim); + } + auto type_id = static_cast(tensor->dataType); + auto type_ptr = TypeIdToType(type_id); + auto abstract_tensor = std::make_shared(type_ptr, shape); + auto getItemPrim = NewValueNode(prim::kPrimTupleGetItem); + if (flag) { + auto getItemIndex = NewValueNode(MakeValue(total++)); + std::vector inputs{getItemPrim, new_cnode, getItemIndex}; + CNodePtr new_item_cnode = func_graph_->NewCNode(inputs); + AddNode(out_tensor_id, new_item_cnode); + } else { + AddNode(out_tensor_id, new_cnode); + } + ptr_list.push_back(std::move(abstract_tensor)); + } + new_cnode->set_abstract(std::make_shared(ptr_list)); } return RET_OK; } @@ -120,4 +150,3 @@ void AnfImporterFromMetaGraphT::AddReturnCNode() { FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } } // namespace mindspore::lite -