From 17cf494f24942c24b896eb67d68c0a96e647f9a5 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Wed, 23 Sep 2020 17:54:21 +0800 Subject: [PATCH] set ms file permission and output all unsupported nodename --- .../lite/tools/anf_exporter/anf_exporter.h | 4 +- .../anf_importer/import_from_protobuf.cc | 34 ++++++++----- .../tools/anf_importer/import_from_protobuf.h | 5 +- mindspore/lite/tools/common/storage.cc | 8 ++- .../lite/tools/converter/anf_transform.h | 2 +- mindspore/lite/tools/converter/converter.cc | 1 + mindspore/lite/tools/converter/converter.h | 2 +- .../{return_code.h => converter_context.h} | 31 ++++++++--- mindspore/lite/tools/converter/model_parser.h | 4 +- .../parser/caffe/caffe_model_parser.cc | 43 ++++++++++------ .../parser/onnx/onnx_model_parser.cc | 51 +++++++++++++------ .../parser/tflite/tflite_arithmetic_parser.cc | 10 ++++ .../parser/tflite/tflite_arithmetic_parser.h | 5 ++ .../parser/tflite/tflite_model_parser.cc | 32 +++++++----- .../converter/parser/tflite/tflite_util.cc | 1 + .../lite/tools/optimizer/common/gllo_utils.h | 8 +-- .../optimizer/fusion/conv_biasadd_fusion.h | 2 +- 17 files changed, 166 insertions(+), 77 deletions(-) rename mindspore/lite/tools/converter/{return_code.h => converter_context.h} (64%) diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index b9b92ffb416..d0a5c7cadc6 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -24,7 +24,7 @@ #include "schema/inner/model_generated.h" #include "src/ops/primitive_c.h" #include "ir/func_graph.h" -#include "tools/converter/return_code.h" +#include "tools/converter/converter_context.h" namespace mindspore::lite { class AnfExporter { @@ -47,7 +47,7 @@ class AnfExporter { const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); void SetGraphInputIndex(const std::unique_ptr &meta_graphT); int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, - schema::CNodeT *return_node); + schema::CNodeT *return_node); bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); int ConvertQuantParam(const std::unique_ptr &meta_graph, const std::shared_ptr primitive, const std::unique_ptr &dst_node); diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index 3352aaf022b..1194e062ba3 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -202,7 +202,7 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, - const onnx::ValueInfoProto &value_proto) { + const onnx::ValueInfoProto &value_proto) { if (node == nullptr) { return RET_NULL_PTR; } @@ -273,7 +273,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node } int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto) { + const onnx::GraphProto &importProto) { if (outputFuncGraph == nullptr) { return RET_NULL_PTR; } @@ -557,6 +557,7 @@ std::unordered_map AnfImporterFromProt CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, const schema::QuantType &quantType) { + static bool interrupt = false; if (outputFuncGraph == nullptr) { MS_LOG(ERROR) << "output funcgraph is nullptr"; return nullptr; @@ -600,13 +601,17 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out inputs.push_back(anfnode_build_map_[input_name]); } auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType); - if (primitivec_ptr == nullptr) { - MS_LOG(ERROR) << "Create PrimitiveC return nullptr, " << prim->name(); + if (primitivec_ptr == nullptr || interrupt) { + interrupt = true; + if (primitivec_ptr == nullptr) { + NoSupportOp::GetInstance()->InsertOp(prim->name()); + } return nullptr; } inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); if (cnode_ptr == nullptr) { + interrupt = true; MS_LOG(ERROR) << "funcgraph new cnode failed"; return nullptr; } @@ -700,40 +705,43 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output } int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto, - const schema::QuantType &quantType) { + const onnx::GraphProto &importProto, + const schema::QuantType &quantType) { if (outputFuncGraph == nullptr) { MS_LOG(ERROR) << "funcgraph is nullptr"; return RET_NULL_PTR; } MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); CNodePtr cnode_ptr = nullptr; + int status = RET_OK; for (int i = 0; i < importProto.node_size(); ++i) { const onnx::NodeProto &node_proto = importProto.node(i); const std::string &node_type = node_proto.op_type(); if (node_type == kConstantValueNode) { - if (!BuildValueNodeForFuncGraph(node_proto)) { + if (status == RET_OK && !BuildValueNodeForFuncGraph(node_proto)) { MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; - return RET_ERROR; + status = RET_ERROR; } continue; } cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); if (cnode_ptr == nullptr) { MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; - return RET_NULL_PTR; + status = (status == RET_OK ? RET_NULL_PTR : status); } } - + if (status != RET_OK) { + return status; + } if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; - return RET_ERROR; + status = RET_ERROR; } - return RET_OK; + return status; } int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType) { + const schema::QuantType &quantType) { if (outputFuncGraph == nullptr) { MS_LOG(ERROR) << "fundgraph is nullptr"; return RET_NULL_PTR; diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.h b/mindspore/lite/tools/anf_importer/import_from_protobuf.h index 84977c37121..03d457bb916 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.h +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.h @@ -24,6 +24,7 @@ #include "include/errorcode.h" #include "tools/converter/parser/onnx/onnx.pb.h" +#include "tools/converter/converter_context.h" #include "tools/anf_importer/anf_importer.h" #include "abstract/abstract_value.h" @@ -47,10 +48,10 @@ class AnfImporterFromProtobuf : public AnfImporter { int AddReturnCNode() override { return RET_ERROR; }; int ParseModelConfigureInfo(const onnx::ModelProto &model_proto); int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType); + const schema::QuantType &quantType); int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType); + const schema::QuantType &quantType); int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, const schema::QuantType &quantType); diff --git a/mindspore/lite/tools/common/storage.cc b/mindspore/lite/tools/common/storage.cc index eb53054438b..9d904f4398a 100644 --- a/mindspore/lite/tools/common/storage.cc +++ b/mindspore/lite/tools/common/storage.cc @@ -15,6 +15,8 @@ */ #include "tools/common/storage.h" +#include +#include #include "flatbuffers/flatbuffers.h" #include "utils/log_adapter.h" #include "src/common/file_utils.h" @@ -31,7 +33,10 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath MS_LOG(ERROR) << "GetBufferPointer nullptr"; return RET_ERROR; } - + if (access((outputPath + ".ms").c_str(), F_OK) == 0) { + MS_LOG(WARNING) << "this file " << outputPath << ".ms has been existed"; + chmod((outputPath + ".ms").c_str(), S_IWUSR); + } std::ofstream output(outputPath + ".ms", std::ofstream::binary); if (!output.is_open()) { MS_LOG(ERROR) << "Can not open output file: " << outputPath << ".ms"; @@ -40,6 +45,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath output.write((const char *)content, size); output.close(); + chmod((outputPath + ".ms").c_str(), S_IRUSR); return RET_OK; } diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 597c9a69dd1..86ec6c66e70 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -23,7 +23,7 @@ #include "tools/converter/converter_flags.h" #include "ir/anf.h" #include "tools/converter/quantizer/quantizer.h" -#include "tools/converter/return_code.h" +#include "tools/converter/converter_context.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 58cf4536f0b..76f3ace21ab 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -152,6 +152,7 @@ int RunConverter(int argc, const char **argv) { return RET_INPUT_PARAM_INVALID; } } + NoSupportOp::GetInstance()->PrintOps(); status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); if (fb_graph == nullptr) { MS_LOG(ERROR) << "Convert model return nullptr"; diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index f8ef7632a14..d59461f8834 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -25,7 +25,7 @@ #include "tools/anf_importer/anf_importer.h" #include "tools/converter/converter_flags.h" #include "tools/converter/anf_transform.h" -#include "tools/converter/return_code.h" +#include "tools/converter/converter_context.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/return_code.h b/mindspore/lite/tools/converter/converter_context.h similarity index 64% rename from mindspore/lite/tools/converter/return_code.h rename to mindspore/lite/tools/converter/converter_context.h index be5df87c76a..23a54b02bc5 100644 --- a/mindspore/lite/tools/converter/return_code.h +++ b/mindspore/lite/tools/converter/converter_context.h @@ -17,13 +17,16 @@ #ifndef LITE_RETURN_CODE_H #define LITE_RETURN_CODE_H +#include +#include #include "include/errorcode.h" +#include "utils/log_adapter.h" namespace mindspore { namespace lite { class ReturnCode { public: - ~ReturnCode() {} + ~ReturnCode() = default; static ReturnCode *GetSingleReturnCode() { static ReturnCode returnCode; return &returnCode; @@ -33,15 +36,31 @@ class ReturnCode { statusCode = status; } } - STATUS GetReturnCode() const { - return statusCode; - } + STATUS GetReturnCode() const { return statusCode; } + private: ReturnCode() { statusCode = RET_OK; } int statusCode; }; + +class NoSupportOp { + public: + ~NoSupportOp() = default; + static NoSupportOp *GetInstance() { + static NoSupportOp noSupportOp; + return &noSupportOp; + } + void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); } + void PrintOps() const { + for (auto &op_name : noSupportOps) { + MS_LOG(ERROR) << "The op " << op_name << " hasn't been supported"; + } + } + + private: + NoSupportOp() { noSupportOps.clear(); } + std::set noSupportOps; +}; } // namespace lite } // namespace mindspore - #endif // LITE_RETURN_CODE_H - diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index d29a19c1ded..a5532da640a 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -22,7 +22,7 @@ #include "schema/inner/model_generated.h" #include "tools/anf_importer/import_from_meta_graphT.h" #include "ir/anf.h" -#include "tools/converter/return_code.h" +#include "tools/converter/converter_context.h" namespace mindspore::lite { using namespace schema; @@ -40,7 +40,7 @@ class ModelParser { return nullptr; } auto func_graph = this->Fb2Anf(meta_graph); - delete(meta_graph); + delete (meta_graph); return func_graph; } diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 5c1eac363a9..7c4aff82b0c 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -84,6 +84,9 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co if (status != RET_OK) { MS_LOG(ERROR) << "ParseLayer failed " << status; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + for (auto &tensor : tensorCache.GetCachedTensor()) { + delete tensor; + } return nullptr; } @@ -179,6 +182,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, T STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, TensorCache *tensorCache, schema::MetaGraphT *subGraphDef, const QuantType &quantType) { + static bool interrupt = false; + int status = RET_OK; for (int i = 0; i < proto.layer_size(); i++) { auto layer = proto.layer(i); @@ -222,38 +227,46 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff } continue; } - auto status = SetOpInputIdx(layer, op.get(), tensorCache); - if (status != RET_OK) { - MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!"; - return status; - } auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str()); - if (nodeParser == nullptr) { - MS_LOG(ERROR) << "Don't support type " << layer.type() << ". for caffe op " << layer.name(); - return RET_NULL_PTR; + if (nodeParser == nullptr || interrupt) { + interrupt = true; + if (nodeParser == nullptr) { + NoSupportOp::GetInstance()->InsertOp(layer.type()); + status = (status == RET_OK ? RET_NOT_FIND_OP : status); + } + continue; } std::vector weightVec; - status = nodeParser->Parse(layer, layerP, op.get(), &weightVec); - if (status != RET_OK) { + auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); + if (status_node != RET_OK) { + interrupt = true; MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; - return status; + status = (status == RET_OK ? RET_NOT_FIND_OP : status); + continue; } + status_node = SetOpInputIdx(layer, op.get(), tensorCache); + if (status_node != RET_OK) { + MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!"; + status = (status == RET_OK ? status_node : status); + } SetWeightTensor(weightVec, op.get(), tensorCache); - status = SetOpOutputIdx(layer, op.get(), tensorCache); - if (status != RET_OK) { + status_node = SetOpOutputIdx(layer, op.get(), tensorCache); + if (status_node != RET_OK) { + interrupt = true; MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!"; - return status; + status = (status == RET_OK ? RET_NOT_FIND_OP : status); + continue; } // op->fmkType = FmkType_CAFFE; subGraphDef->nodes.emplace_back(move(op)); } } - return RET_OK; + return status; } STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index baa41f99215..a367b67605e 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -249,6 +249,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache, const QuantType &quantType) { // change op_type() to name(), that is unique + static bool interrupt = false; dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); dst_op->quantType = quantType; // dst_op->fmkType = FmkType_ONNX; @@ -256,15 +257,25 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, << onnx_node.input_size(); // get the real op type SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); - auto status = ParseOnnxNodeAttr(onnx_graph, onnx_node, onnx_node.op_type(), dst_op); + auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type()); + if (node_parser == nullptr || interrupt) { + interrupt = true; + if (node_parser == nullptr) { + NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); + } + return RET_NOT_FIND_OP; + } + auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op); if (status != RET_OK) { - MS_LOG(ERROR) << "parser onnx node attr failed"; + interrupt = true; + MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; return status; } // set op input index std::vector node_inputs; (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { + interrupt = true; MS_LOG(ERROR) << "SetOpInputIndex failed"; return RET_ERROR; } @@ -273,6 +284,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { + interrupt = true; MS_LOG(ERROR) << "SetOpOutputIndex failed"; return RET_ERROR; } @@ -340,8 +352,7 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co const string &onnx_op_type, schema::CNodeT *dst_op) { auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); if (node_parser == nullptr) { - MS_LOG(ERROR) << "not find " << onnx_op_type << ", node parser is nullptr"; - return RET_NULL_PTR; + return RET_NOT_FIND_OP; } return node_parser->Parse(onnx_graph, onnx_node, dst_op); } @@ -503,32 +514,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con } // init op node input/output tensor, and dst_op attr for (const auto &onnx_node : onnx_graph.node()) { + int status_node = RET_OK; if (onnx_node.op_type() == "Constant") { continue; } if (onnx_node.op_type() == "Gemm") { - ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); + if (status == RET_OK) { + ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); + } continue; } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { - status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); - if (status != RET_OK) { - MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + if (status == RET_OK) { + status_node = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); + if (status_node != RET_OK) { + MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status_node; + status = (status == RET_OK ? status_node : status); + } } continue; } std::unique_ptr dst_op = std::make_unique(); std::unique_ptr dst_tensor = std::make_unique(); - status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType); - if (status != RET_OK) { - MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType); + if (status_node != RET_OK) { + status = (status == RET_OK ? status_node : status); + continue; } dst_graph->nodes.emplace_back(std::move(dst_op)); } + if (status != RET_OK) { + ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); + for (auto &tensor : tensor_cache.GetCachedTensor()) { + delete tensor; + } + return nullptr; + } SetAllTensors(tensor_cache, dst_graph.get()); dst_graph->name = GetModelName(modelFile); return dst_graph.release(); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc index 34b2c719058..b10169c9b06 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.cc @@ -300,6 +300,15 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr } op->primitive->value.type = schema::PrimitiveType_Floor; op->primitive->value.value = attr.release(); + } else if (std::strcmp(node_name, "Neg") == 0) { + MS_LOG(DEBUG) << "parse TfliteNegParser"; + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + op->primitive->value.type = schema::PrimitiveType_Neg; + op->primitive->value.value = attr.release(); } AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), @@ -415,6 +424,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser()); TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); +TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser()); TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h index b25e57b1a17..53676f121fe 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_arithmetic_parser.h @@ -157,6 +157,11 @@ class TfliteFloorParser : public TfliteSingleInputOpParser { TfliteFloorParser() : TfliteSingleInputOpParser() {} }; +class TfliteNegParser : public TfliteSingleInputOpParser { + public: + TfliteNegParser() : TfliteSingleInputOpParser() {} +}; + class TfliteCompareOpParser : public TfliteNodeParser { public: TfliteCompareOpParser() : TfliteNodeParser("node_name") {} diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 82e4eb5d06f..d9228a3ef04 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -98,6 +98,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit const std::unique_ptr &tflite_subgraph, const QuantType &quant_type, schema::MetaGraphT *sub_graph) { int idx = 0; + int status = RET_OK; for (const auto &tflite_op : tflite_subgraph->operators) { auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; auto op_type = GetMSOpType(tflite_op_type); @@ -114,21 +115,24 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); if (node_parser == nullptr) { - MS_LOG(ERROR) << "cannot find node parser, opType: " << op_type.c_str(); - return RET_NOT_FIND_OP; - } - int status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, - &tensorsFormat, &tensorsIdMap); - if (status != RET_OK) { - MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; - return status; + NoSupportOp::GetInstance()->InsertOp(op_type); + status = (status == RET_OK ? RET_NOT_FIND_OP : status); + continue; } + if (status == RET_OK) { + status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId, + &tensorsFormat, &tensorsIdMap); + if (status != RET_OK) { + MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; + continue; + } - sub_graph->nodes.emplace_back(op.release()); - opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); - tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); + sub_graph->nodes.emplace_back(op.release()); + opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); + tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); + } } - return RET_OK; + return status; } STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr &tflite_subgraph, @@ -162,8 +166,8 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr if (isConst) { int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); if (status != RET_OK) { - MS_LOG(ERROR) << "obtain const tensor failed"; - return status; + MS_LOG(ERROR) << "obtain const tensor failed"; + return status; } } // set tensor attr diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc index 4d06d00a251..3c1d4fe4fe9 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc @@ -118,6 +118,7 @@ std::map tfMsOpTypeMap{ {tflite::BuiltinOperator_UNPACK, "Unstack"}, {tflite::BuiltinOperator_CUSTOM, "Custom"}, {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"}, + {tflite::BuiltinOperator_NEG, "Neg"}, }; std::map tfMsActivationFunctionMap{ diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 1e6a2fbaf16..ebd0f8cf0c7 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -26,7 +26,7 @@ #include "backend/optimizer/common/pattern_engine.h" #include "schema/inner/model_generated.h" #include "src/param_value_lite.h" -#include "tools/converter/return_code.h" +#include "tools/converter/converter_context.h" using PrimitiveCPtr = std::shared_ptr; using mindspore::lite::RET_ERROR; @@ -73,7 +73,7 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); -ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); +ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); enum kTransFilterType { kKCHW2HWCK, // 0 @@ -105,11 +105,11 @@ STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW); -template +template static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW); -template +template static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type); STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h index 7f95c4cc109..f3f1b01a21f 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.h @@ -18,7 +18,7 @@ #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ #include "backend/optimizer/common/optimizer.h" -#include "tools/converter/return_code.h" +#include "tools/converter/converter_context.h" namespace mindspore { namespace opt {