set ms file permission and output all unsupported nodename

This commit is contained in:
xuanyue 2020-09-23 17:54:21 +08:00
parent 48a99330d4
commit 17cf494f24
17 changed files with 166 additions and 77 deletions

View File

@ -24,7 +24,7 @@
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "tools/converter/return_code.h" #include "tools/converter/converter_context.h"
namespace mindspore::lite { namespace mindspore::lite {
class AnfExporter { class AnfExporter {
@ -47,7 +47,7 @@ class AnfExporter {
const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, int SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
schema::CNodeT *return_node); schema::CNodeT *return_node);
bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type);
int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
const std::shared_ptr<PrimitiveC> primitive, const std::unique_ptr<schema::CNodeT> &dst_node); const std::shared_ptr<PrimitiveC> primitive, const std::unique_ptr<schema::CNodeT> &dst_node);

View File

@ -202,7 +202,7 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64)
PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64)
int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node,
const onnx::ValueInfoProto &value_proto) { const onnx::ValueInfoProto &value_proto) {
if (node == nullptr) { if (node == nullptr) {
return RET_NULL_PTR; return RET_NULL_PTR;
} }
@ -273,7 +273,7 @@ int AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node
} }
int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, int AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto) { const onnx::GraphProto &importProto) {
if (outputFuncGraph == nullptr) { if (outputFuncGraph == nullptr) {
return RET_NULL_PTR; return RET_NULL_PTR;
} }
@ -557,6 +557,7 @@ std::unordered_map<std::string, abstract::AbstractTensorPtr> AnfImporterFromProt
CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::NodeProto &node_proto, const onnx::NodeProto &node_proto,
const schema::QuantType &quantType) { const schema::QuantType &quantType) {
static bool interrupt = false;
if (outputFuncGraph == nullptr) { if (outputFuncGraph == nullptr) {
MS_LOG(ERROR) << "output funcgraph is nullptr"; MS_LOG(ERROR) << "output funcgraph is nullptr";
return nullptr; return nullptr;
@ -600,13 +601,17 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
inputs.push_back(anfnode_build_map_[input_name]); inputs.push_back(anfnode_build_map_[input_name]);
} }
auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType); auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType);
if (primitivec_ptr == nullptr) { if (primitivec_ptr == nullptr || interrupt) {
MS_LOG(ERROR) << "Create PrimitiveC return nullptr, " << prim->name(); interrupt = true;
if (primitivec_ptr == nullptr) {
NoSupportOp::GetInstance()->InsertOp(prim->name());
}
return nullptr; return nullptr;
} }
inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr)); inputs.insert(inputs.begin(), NewValueNode(primitivec_ptr));
CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs);
if (cnode_ptr == nullptr) { if (cnode_ptr == nullptr) {
interrupt = true;
MS_LOG(ERROR) << "funcgraph new cnode failed"; MS_LOG(ERROR) << "funcgraph new cnode failed";
return nullptr; return nullptr;
} }
@ -700,40 +705,43 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &output
} }
int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph,
const onnx::GraphProto &importProto, const onnx::GraphProto &importProto,
const schema::QuantType &quantType) { const schema::QuantType &quantType) {
if (outputFuncGraph == nullptr) { if (outputFuncGraph == nullptr) {
MS_LOG(ERROR) << "funcgraph is nullptr"; MS_LOG(ERROR) << "funcgraph is nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr; CNodePtr cnode_ptr = nullptr;
int status = RET_OK;
for (int i = 0; i < importProto.node_size(); ++i) { for (int i = 0; i < importProto.node_size(); ++i) {
const onnx::NodeProto &node_proto = importProto.node(i); const onnx::NodeProto &node_proto = importProto.node(i);
const std::string &node_type = node_proto.op_type(); const std::string &node_type = node_proto.op_type();
if (node_type == kConstantValueNode) { if (node_type == kConstantValueNode) {
if (!BuildValueNodeForFuncGraph(node_proto)) { if (status == RET_OK && !BuildValueNodeForFuncGraph(node_proto)) {
MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i; MS_LOG(ERROR) << "Build ValueNode for funcgraph fail at index: : " << i;
return RET_ERROR; status = RET_ERROR;
} }
continue; continue;
} }
cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType); cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto, quantType);
if (cnode_ptr == nullptr) { if (cnode_ptr == nullptr) {
MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i;
return RET_NULL_PTR; status = (status == RET_OK ? RET_NULL_PTR : status);
} }
} }
if (status != RET_OK) {
return status;
}
if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) { if (!BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr)) {
MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed"; MS_LOG(ERROR) << "Build ReturnNode for funcgraph failed";
return RET_ERROR; status = RET_ERROR;
} }
return RET_OK; return status;
} }
int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, int AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType) { const schema::QuantType &quantType) {
if (outputFuncGraph == nullptr) { if (outputFuncGraph == nullptr) {
MS_LOG(ERROR) << "fundgraph is nullptr"; MS_LOG(ERROR) << "fundgraph is nullptr";
return RET_NULL_PTR; return RET_NULL_PTR;

View File

@ -24,6 +24,7 @@
#include "include/errorcode.h" #include "include/errorcode.h"
#include "tools/converter/parser/onnx/onnx.pb.h" #include "tools/converter/parser/onnx/onnx.pb.h"
#include "tools/converter/converter_context.h"
#include "tools/anf_importer/anf_importer.h" #include "tools/anf_importer/anf_importer.h"
#include "abstract/abstract_value.h" #include "abstract/abstract_value.h"
@ -47,10 +48,10 @@ class AnfImporterFromProtobuf : public AnfImporter {
int AddReturnCNode() override { return RET_ERROR; }; int AddReturnCNode() override { return RET_ERROR; };
int ParseModelConfigureInfo(const onnx::ModelProto &model_proto); int ParseModelConfigureInfo(const onnx::ModelProto &model_proto);
int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, int BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType); const schema::QuantType &quantType);
int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); int ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto);
int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, int ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto,
const schema::QuantType &quantType); const schema::QuantType &quantType);
int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); int BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto);
CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto,
const schema::QuantType &quantType); const schema::QuantType &quantType);

View File

@ -15,6 +15,8 @@
*/ */
#include "tools/common/storage.h" #include "tools/common/storage.h"
#include <sys/stat.h>
#include <unistd.h>
#include "flatbuffers/flatbuffers.h" #include "flatbuffers/flatbuffers.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "src/common/file_utils.h" #include "src/common/file_utils.h"
@ -31,7 +33,10 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath
MS_LOG(ERROR) << "GetBufferPointer nullptr"; MS_LOG(ERROR) << "GetBufferPointer nullptr";
return RET_ERROR; return RET_ERROR;
} }
if (access((outputPath + ".ms").c_str(), F_OK) == 0) {
MS_LOG(WARNING) << "this file " << outputPath << ".ms has been existed";
chmod((outputPath + ".ms").c_str(), S_IWUSR);
}
std::ofstream output(outputPath + ".ms", std::ofstream::binary); std::ofstream output(outputPath + ".ms", std::ofstream::binary);
if (!output.is_open()) { if (!output.is_open()) {
MS_LOG(ERROR) << "Can not open output file: " << outputPath << ".ms"; MS_LOG(ERROR) << "Can not open output file: " << outputPath << ".ms";
@ -40,6 +45,7 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath
output.write((const char *)content, size); output.write((const char *)content, size);
output.close(); output.close();
chmod((outputPath + ".ms").c_str(), S_IRUSR);
return RET_OK; return RET_OK;
} }

View File

@ -23,7 +23,7 @@
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "tools/converter/quantizer/quantizer.h" #include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/return_code.h" #include "tools/converter/converter_context.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

View File

@ -152,6 +152,7 @@ int RunConverter(int argc, const char **argv) {
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
} }
NoSupportOp::GetInstance()->PrintOps();
status = ReturnCode::GetSingleReturnCode()->GetReturnCode(); status = ReturnCode::GetSingleReturnCode()->GetReturnCode();
if (fb_graph == nullptr) { if (fb_graph == nullptr) {
MS_LOG(ERROR) << "Convert model return nullptr"; MS_LOG(ERROR) << "Convert model return nullptr";

View File

@ -25,7 +25,7 @@
#include "tools/anf_importer/anf_importer.h" #include "tools/anf_importer/anf_importer.h"
#include "tools/converter/converter_flags.h" #include "tools/converter/converter_flags.h"
#include "tools/converter/anf_transform.h" #include "tools/converter/anf_transform.h"
#include "tools/converter/return_code.h" #include "tools/converter/converter_context.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {

View File

@ -17,13 +17,16 @@
#ifndef LITE_RETURN_CODE_H #ifndef LITE_RETURN_CODE_H
#define LITE_RETURN_CODE_H #define LITE_RETURN_CODE_H
#include <string>
#include <set>
#include "include/errorcode.h" #include "include/errorcode.h"
#include "utils/log_adapter.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class ReturnCode { class ReturnCode {
public: public:
~ReturnCode() {} ~ReturnCode() = default;
static ReturnCode *GetSingleReturnCode() { static ReturnCode *GetSingleReturnCode() {
static ReturnCode returnCode; static ReturnCode returnCode;
return &returnCode; return &returnCode;
@ -33,15 +36,31 @@ class ReturnCode {
statusCode = status; statusCode = status;
} }
} }
STATUS GetReturnCode() const { STATUS GetReturnCode() const { return statusCode; }
return statusCode;
}
private: private:
ReturnCode() { statusCode = RET_OK; } ReturnCode() { statusCode = RET_OK; }
int statusCode; int statusCode;
}; };
class NoSupportOp {
public:
~NoSupportOp() = default;
static NoSupportOp *GetInstance() {
static NoSupportOp noSupportOp;
return &noSupportOp;
}
void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); }
void PrintOps() const {
for (auto &op_name : noSupportOps) {
MS_LOG(ERROR) << "The op " << op_name << " hasn't been supported";
}
}
private:
NoSupportOp() { noSupportOps.clear(); }
std::set<std::string> noSupportOps;
};
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // LITE_RETURN_CODE_H #endif // LITE_RETURN_CODE_H

View File

@ -22,7 +22,7 @@
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "tools/anf_importer/import_from_meta_graphT.h" #include "tools/anf_importer/import_from_meta_graphT.h"
#include "ir/anf.h" #include "ir/anf.h"
#include "tools/converter/return_code.h" #include "tools/converter/converter_context.h"
namespace mindspore::lite { namespace mindspore::lite {
using namespace schema; using namespace schema;
@ -40,7 +40,7 @@ class ModelParser {
return nullptr; return nullptr;
} }
auto func_graph = this->Fb2Anf(meta_graph); auto func_graph = this->Fb2Anf(meta_graph);
delete(meta_graph); delete (meta_graph);
return func_graph; return func_graph;
} }

View File

@ -84,6 +84,9 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "ParseLayer failed " << status; MS_LOG(ERROR) << "ParseLayer failed " << status;
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
for (auto &tensor : tensorCache.GetCachedTensor()) {
delete tensor;
}
return nullptr; return nullptr;
} }
@ -179,6 +182,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, T
STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight,
TensorCache *tensorCache, schema::MetaGraphT *subGraphDef, TensorCache *tensorCache, schema::MetaGraphT *subGraphDef,
const QuantType &quantType) { const QuantType &quantType) {
static bool interrupt = false;
int status = RET_OK;
for (int i = 0; i < proto.layer_size(); i++) { for (int i = 0; i < proto.layer_size(); i++) {
auto layer = proto.layer(i); auto layer = proto.layer(i);
@ -222,38 +227,46 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
} }
continue; continue;
} }
auto status = SetOpInputIdx(layer, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!";
return status;
}
auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str()); auto nodeParser = CaffeNodeParserRegistry::GetInstance()->GetNodeParser(layer.type().c_str());
if (nodeParser == nullptr) { if (nodeParser == nullptr || interrupt) {
MS_LOG(ERROR) << "Don't support type " << layer.type() << ". for caffe op " << layer.name(); interrupt = true;
return RET_NULL_PTR; if (nodeParser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(layer.type());
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
}
continue;
} }
std::vector<schema::TensorT *> weightVec; std::vector<schema::TensorT *> weightVec;
status = nodeParser->Parse(layer, layerP, op.get(), &weightVec); auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec);
if (status != RET_OK) { if (status_node != RET_OK) {
interrupt = true;
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
return status; status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
} }
status_node = SetOpInputIdx(layer, op.get(), tensorCache);
if (status_node != RET_OK) {
MS_LOG(ERROR) << "Set Op " << layer.name() << " Input Index Failed!";
status = (status == RET_OK ? status_node : status);
}
SetWeightTensor(weightVec, op.get(), tensorCache); SetWeightTensor(weightVec, op.get(), tensorCache);
status = SetOpOutputIdx(layer, op.get(), tensorCache); status_node = SetOpOutputIdx(layer, op.get(), tensorCache);
if (status != RET_OK) { if (status_node != RET_OK) {
interrupt = true;
MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!"; MS_LOG(ERROR) << "Set Op " << layer.name() << " Output Index Failed!";
return status; status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
} }
// op->fmkType = FmkType_CAFFE; // op->fmkType = FmkType_CAFFE;
subGraphDef->nodes.emplace_back(move(op)); subGraphDef->nodes.emplace_back(move(op));
} }
} }
return RET_OK; return status;
} }
STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) {

View File

@ -249,6 +249,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
TensorCache *tensor_cache, const QuantType &quantType) { TensorCache *tensor_cache, const QuantType &quantType) {
// change op_type() to name(), that is unique // change op_type() to name(), that is unique
static bool interrupt = false;
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
dst_op->quantType = quantType; dst_op->quantType = quantType;
// dst_op->fmkType = FmkType_ONNX; // dst_op->fmkType = FmkType_ONNX;
@ -256,15 +257,25 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
<< onnx_node.input_size(); << onnx_node.input_size();
// get the real op type // get the real op type
SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache); SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
auto status = ParseOnnxNodeAttr(onnx_graph, onnx_node, onnx_node.op_type(), dst_op); auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
if (node_parser == nullptr || interrupt) {
interrupt = true;
if (node_parser == nullptr) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
}
return RET_NOT_FIND_OP;
}
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "parser onnx node attr failed"; interrupt = true;
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
return status; return status;
} }
// set op input index // set op input index
std::vector<string> node_inputs; std::vector<string> node_inputs;
(void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end()); (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end());
if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) { if (SetOpInputIndex(node_inputs, dst_op, onnx_node, tensor_cache)) {
interrupt = true;
MS_LOG(ERROR) << "SetOpInputIndex failed"; MS_LOG(ERROR) << "SetOpInputIndex failed";
return RET_ERROR; return RET_ERROR;
} }
@ -273,6 +284,7 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
(void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end()); (void)node_outputs.insert(node_outputs.begin(), onnx_node.output().begin(), onnx_node.output().end());
if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) { if (SetOpOutputIndex(node_outputs, dst_op, tensor_cache) != RET_OK) {
interrupt = true;
MS_LOG(ERROR) << "SetOpOutputIndex failed"; MS_LOG(ERROR) << "SetOpOutputIndex failed";
return RET_ERROR; return RET_ERROR;
} }
@ -340,8 +352,7 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, co
const string &onnx_op_type, schema::CNodeT *dst_op) { const string &onnx_op_type, schema::CNodeT *dst_op) {
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
if (node_parser == nullptr) { if (node_parser == nullptr) {
MS_LOG(ERROR) << "not find " << onnx_op_type << ", node parser is nullptr"; return RET_NOT_FIND_OP;
return RET_NULL_PTR;
} }
return node_parser->Parse(onnx_graph, onnx_node, dst_op); return node_parser->Parse(onnx_graph, onnx_node, dst_op);
} }
@ -503,32 +514,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
} }
// init op node input/output tensor, and dst_op attr // init op node input/output tensor, and dst_op attr
for (const auto &onnx_node : onnx_graph.node()) { for (const auto &onnx_node : onnx_graph.node()) {
int status_node = RET_OK;
if (onnx_node.op_type() == "Constant") { if (onnx_node.op_type() == "Constant") {
continue; continue;
} }
if (onnx_node.op_type() == "Gemm") { if (onnx_node.op_type() == "Gemm") {
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache); if (status == RET_OK) {
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache);
}
continue; continue;
} else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") { } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
status = ParseOnnxGivenFillNode(onnx_node, &tensor_cache); if (status == RET_OK) {
if (status != RET_OK) { status_node = ParseOnnxGivenFillNode(onnx_node, &tensor_cache);
MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status; if (status_node != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); MS_LOG(ERROR) << "ParseOnnxGivenFillNode failed: " << status_node;
return nullptr; status = (status == RET_OK ? status_node : status);
}
} }
continue; continue;
} }
std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>(); std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>(); std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
status = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType); status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType);
if (status != RET_OK) { if (status_node != RET_OK) {
MS_LOG(ERROR) << "parse node " << onnx_node.op_type() << " failed"; status = (status == RET_OK ? status_node : status);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); continue;
return nullptr;
} }
dst_graph->nodes.emplace_back(std::move(dst_op)); dst_graph->nodes.emplace_back(std::move(dst_op));
} }
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
for (auto &tensor : tensor_cache.GetCachedTensor()) {
delete tensor;
}
return nullptr;
}
SetAllTensors(tensor_cache, dst_graph.get()); SetAllTensors(tensor_cache, dst_graph.get());
dst_graph->name = GetModelName(modelFile); dst_graph->name = GetModelName(modelFile);
return dst_graph.release(); return dst_graph.release();

View File

@ -300,6 +300,15 @@ STATUS TfliteSingleInputOpParser::Parse(const std::unique_ptr<tflite::OperatorT>
} }
op->primitive->value.type = schema::PrimitiveType_Floor; op->primitive->value.type = schema::PrimitiveType_Floor;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
} else if (std::strcmp(node_name, "Neg") == 0) {
MS_LOG(DEBUG) << "parse TfliteNegParser";
auto attr = std::make_unique<schema::NegT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_Neg;
op->primitive->value.value = attr.release();
} }
AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(),
@ -415,6 +424,7 @@ TfliteNodeRegister g_TfliteLogParser("Log", new TfliteLogParser());
TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser()); TfliteNodeRegister g_tfliteRoundParser("Round", new TfliteRoundParser());
TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser()); TfliteNodeRegister g_TfliteCeilParser("Ceil", new TfliteCeilParser());
TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser()); TfliteNodeRegister g_tfliteFloorParser("flOOR", new TfliteFloorParser());
TfliteNodeRegister g_tfliteNegParser("Neg", new TfliteNegParser());
TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser()); TfliteNodeRegister g_tfliteEqualParser("Equal", new TfliteEqualParser());
TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser()); TfliteNodeRegister g_tfliteNotEqualParser("NotEqual", new TfliteNotEqualParser());

View File

@ -157,6 +157,11 @@ class TfliteFloorParser : public TfliteSingleInputOpParser {
TfliteFloorParser() : TfliteSingleInputOpParser() {} TfliteFloorParser() : TfliteSingleInputOpParser() {}
}; };
class TfliteNegParser : public TfliteSingleInputOpParser {
public:
TfliteNegParser() : TfliteSingleInputOpParser() {}
};
class TfliteCompareOpParser : public TfliteNodeParser { class TfliteCompareOpParser : public TfliteNodeParser {
public: public:
TfliteCompareOpParser() : TfliteNodeParser("node_name") {} TfliteCompareOpParser() : TfliteNodeParser("node_name") {}

View File

@ -98,6 +98,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const QuantType &quant_type, schema::MetaGraphT *sub_graph) { const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
int idx = 0; int idx = 0;
int status = RET_OK;
for (const auto &tflite_op : tflite_subgraph->operators) { for (const auto &tflite_op : tflite_subgraph->operators) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto op_type = GetMSOpType(tflite_op_type); auto op_type = GetMSOpType(tflite_op_type);
@ -114,21 +115,24 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type); auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(op_type);
if (node_parser == nullptr) { if (node_parser == nullptr) {
MS_LOG(ERROR) << "cannot find node parser, opType: " << op_type.c_str(); NoSupportOp::GetInstance()->InsertOp(op_type);
return RET_NOT_FIND_OP; status = (status == RET_OK ? RET_NOT_FIND_OP : status);
} continue;
int status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId,
&tensorsFormat, &tensorsIdMap);
if (status != RET_OK) {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
return status;
} }
if (status == RET_OK) {
status = node_parser->Parse(tflite_op, tflite_subgraph->tensors, tflite_model->buffers, op.get(), &tensorsId,
&tensorsFormat, &tensorsIdMap);
if (status != RET_OK) {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
continue;
}
sub_graph->nodes.emplace_back(op.release()); sub_graph->nodes.emplace_back(op.release());
opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get(); opMap[sub_graph->nodes.back()->name] = sub_graph->nodes.back().get();
tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get(); tfliteOpMap[tflite_op.get()] = sub_graph->nodes.back().get();
}
} }
return RET_OK; return status;
} }
STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
@ -162,8 +166,8 @@ STATUS TfliteModelParser::ConvertTensor(const std::unique_ptr<tflite::SubGraphT>
if (isConst) { if (isConst) {
int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get()); int status = CopyConstTensorData(tflite_model_buffer, tflite_tensor.get(), tensor.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "obtain const tensor failed"; MS_LOG(ERROR) << "obtain const tensor failed";
return status; return status;
} }
} }
// set tensor attr // set tensor attr

View File

@ -118,6 +118,7 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
{tflite::BuiltinOperator_UNPACK, "Unstack"}, {tflite::BuiltinOperator_UNPACK, "Unstack"},
{tflite::BuiltinOperator_CUSTOM, "Custom"}, {tflite::BuiltinOperator_CUSTOM, "Custom"},
{tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"}, {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"},
{tflite::BuiltinOperator_NEG, "Neg"},
}; };
std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{ std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{

View File

@ -26,7 +26,7 @@
#include "backend/optimizer/common/pattern_engine.h" #include "backend/optimizer/common/pattern_engine.h"
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "src/param_value_lite.h" #include "src/param_value_lite.h"
#include "tools/converter/return_code.h" #include "tools/converter/converter_context.h"
using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>; using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
@ -73,7 +73,7 @@ bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);
size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node);
enum kTransFilterType { enum kTransFilterType {
kKCHW2HWCK, // 0 kKCHW2HWCK, // 0
@ -105,11 +105,11 @@ STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type,
STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW); int32_t filterH, int32_t filterW);
template<typename T> template <typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW); int32_t filterH, int32_t filterW);
template<typename T> template <typename T>
static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type); static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type);
STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format); STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format);

View File

@ -18,7 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_ #define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BIASADD_FUSION_H_
#include "backend/optimizer/common/optimizer.h" #include "backend/optimizer/common/optimizer.h"
#include "tools/converter/return_code.h" #include "tools/converter/converter_context.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {