forked from mindspore-Ecosystem/mindspore
add onnx parser and adjust the way of printing unsupport op
This commit is contained in:
parent
5a83415f07
commit
357b597b4f
|
@ -58,7 +58,8 @@ enum ActivationType : byte {
|
|||
THRESHOLDRELU = 14,
|
||||
LINEAR = 15,
|
||||
HARD_TANH = 16,
|
||||
UNKNOW = 17
|
||||
SIGN = 17,
|
||||
UNKNOW = 18
|
||||
}
|
||||
enum ActivationGradType : byte {
|
||||
NO_ACTIVATION = 0,
|
||||
|
|
|
@ -595,10 +595,14 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
|
|||
for (int i = 0; i < node_proto.input_size(); ++i) {
|
||||
const std::string &input_name = node_proto.input(i);
|
||||
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
|
||||
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
||||
return nullptr;
|
||||
if (!interrupt) {
|
||||
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
|
||||
interrupt = true;
|
||||
}
|
||||
inputs.push_back(nullptr);
|
||||
} else {
|
||||
inputs.push_back(anfnode_build_map_[input_name]);
|
||||
}
|
||||
inputs.push_back(anfnode_build_map_[input_name]);
|
||||
}
|
||||
auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType);
|
||||
if (primitivec_ptr == nullptr || interrupt) {
|
||||
|
@ -714,6 +718,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
|
|||
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
|
||||
CNodePtr cnode_ptr = nullptr;
|
||||
int status = RET_OK;
|
||||
NoSupportOp::GetInstance()->SetFmkType("MINDIR");
|
||||
for (int i = 0; i < importProto.node_size(); ++i) {
|
||||
const onnx::NodeProto &node_proto = importProto.node(i);
|
||||
const std::string &node_type = node_proto.op_type();
|
||||
|
|
|
@ -34,7 +34,6 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath
|
|||
return RET_ERROR;
|
||||
}
|
||||
if (access((outputPath + ".ms").c_str(), F_OK) == 0) {
|
||||
MS_LOG(WARNING) << "this file " << outputPath << ".ms has been existed";
|
||||
chmod((outputPath + ".ms").c_str(), S_IWUSR);
|
||||
}
|
||||
std::ofstream output(outputPath + ".ms", std::ofstream::binary);
|
||||
|
|
|
@ -65,7 +65,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
|||
FuncGraphPtr graph = nullptr;
|
||||
if (flag->fmk == converter::FmkType_MS) {
|
||||
MS_ASSERT(nullptr != modelImporter);
|
||||
modelImporter->Import(flag->quantType);
|
||||
int status = modelImporter->Import(flag->quantType);
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
graph = modelImporter->GetResult();
|
||||
|
|
|
@ -50,16 +50,23 @@ class NoSupportOp {
|
|||
static NoSupportOp noSupportOp;
|
||||
return &noSupportOp;
|
||||
}
|
||||
void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; }
|
||||
void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); }
|
||||
void PrintOps() const {
|
||||
for (auto &op_name : noSupportOps) {
|
||||
MS_LOG(ERROR) << "The op " << op_name << " hasn't been supported";
|
||||
if (!noSupportOps.empty()) {
|
||||
MS_LOG(ERROR) << "===========================================";
|
||||
MS_LOG(ERROR) << "UNSUPPORT OP LIST:";
|
||||
for (auto &op_name : noSupportOps) {
|
||||
MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name;
|
||||
}
|
||||
MS_LOG(ERROR) << "===========================================";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
NoSupportOp() { noSupportOps.clear(); }
|
||||
std::set<std::string> noSupportOps;
|
||||
std::string fmkType;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -80,6 +80,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
NoSupportOp::GetInstance()->SetFmkType("CAFFE");
|
||||
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParseLayer failed " << status;
|
||||
|
@ -242,7 +243,11 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
|
|||
auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec);
|
||||
if (status_node != RET_OK) {
|
||||
interrupt = true;
|
||||
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
|
||||
if (status_node == RET_NOT_SUPPORT) {
|
||||
NoSupportOp::GetInstance()->InsertOp(layer.type());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
|
||||
}
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -559,6 +559,29 @@ STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
|
||||
MS_LOG(DEBUG) << "onnx TanhParser";
|
||||
if (op == nullptr) {
|
||||
MS_LOG(ERROR) << "op is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
op->primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (op->primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "op->primitive is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->type = schema::ActivationType_SIGN;
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
|
||||
OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
|
||||
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
|
||||
|
@ -584,5 +607,6 @@ OnnxNodeRegistrar g_onnxTanParser("Tan", new OnnxTanParser());
|
|||
OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser());
|
||||
OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser());
|
||||
OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser());
|
||||
OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -165,6 +165,12 @@ class OnnxTanhParser : public OnnxNodeParser {
|
|||
OnnxTanhParser() : OnnxNodeParser("Tanh") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
|
||||
class OnnxSignParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxSignParser() : OnnxNodeParser("Sign") {}
|
||||
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H
|
||||
|
|
|
@ -47,12 +47,18 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->type = schema::ActivationType_RELU6;
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Activation;
|
||||
op->primitive->value.value = attr.release();
|
||||
} else {
|
||||
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
|
||||
return RET_ERROR;
|
||||
std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->max = max;
|
||||
attr->min = min;
|
||||
op->primitive->value.type = schema::PrimitiveType_Clip;
|
||||
op->primitive->value.value = attr.release();
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
|
|
@ -271,7 +271,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
|||
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
|
||||
if (status != RET_OK) {
|
||||
interrupt = true;
|
||||
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
|
||||
if (status == RET_NOT_SUPPORT) {
|
||||
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
|
||||
} else {
|
||||
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
|
||||
}
|
||||
return status;
|
||||
}
|
||||
// set op input index
|
||||
|
@ -514,6 +518,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
|
|||
return nullptr;
|
||||
}
|
||||
// init op node input/output tensor, and dst_op attr
|
||||
NoSupportOp::GetInstance()->SetFmkType("ONNX");
|
||||
for (const auto &onnx_node : onnx_graph.node()) {
|
||||
int status_node = RET_OK;
|
||||
if (onnx_node.op_type() == "Constant") {
|
||||
|
|
|
@ -96,6 +96,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
|
||||
int idx = 0;
|
||||
int status = RET_OK;
|
||||
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
|
||||
for (const auto &tflite_op : tflite_subgraph->operators) {
|
||||
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
|
||||
auto op_type = GetMSOpType(tflite_op_type);
|
||||
|
@ -119,7 +120,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
|
|||
if (status == RET_OK) {
|
||||
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get());
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
|
||||
if (status == RET_NOT_SUPPORT) {
|
||||
NoSupportOp::GetInstance()->InsertOp(op_type);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue