!4709 debug onnx model converter

Merge pull request !4709 from wangzhe/master
This commit is contained in:
mindspore-ci-bot 2020-08-19 16:24:14 +08:00 committed by Gitee
commit 9ce6b36e02
9 changed files with 37 additions and 50 deletions

View File

@ -360,6 +360,8 @@ PrimitiveC *ModelImpl::CopyPrimitive(const schema::Primitive *src_prim) {
return new DeDepthwiseConv2D(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Shape:
return new Shape(const_cast<schema::Primitive *>(src_prim));
case schema::PrimitiveType_Unsqueeze:
return new Unsqueeze(const_cast<schema::Primitive *>(src_prim));
default:
break;
}

View File

@ -37,11 +37,7 @@ int Shape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor:
}
auto in_tensor = inputs_.front();
auto out_tensor = outputs_.front();
auto ret_dtype = out_tensor->set_data_type(kNumberTypeInt32);
if (ret_dtype != in_tensor->data_type()) {
MS_LOG(ERROR) << "Set datatype fails.";
return RET_ERROR;
}
out_tensor->set_data_type(kNumberTypeInt32);
if (!GetInferFlag()) {
return RET_OK;
}

View File

@ -1,3 +1,4 @@
mtk_detect-mbv2-shortcut-400-400-simplified.onnx
mtk_emotions-d2012-75.8%.onnx
mtk_face_features_v3.onnx
ml_face_3d.onnx

View File

@ -21,7 +21,7 @@
namespace mindspore {
namespace lite {
bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr) {
bool OnnxConvParser::ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DepthwiseConvParser";
if (attr == nullptr || attr->group != attr->channelIn) {
return false;
@ -55,7 +55,7 @@ bool OnnxConvParser::ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *
STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx ConvParser";
auto attr = new schema::Conv2DT();
std::unique_ptr<schema::Conv2DT> attr(new (std::nothrow) schema::Conv2DT());
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
@ -153,17 +153,15 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
} else {
attr->activationType = schema::ActivationType_NO_ACTIVATION;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr;
if (attr->group != 1) {
if (!ParseGroupConvolution(op, attr)) {
delete attr;
if (!ParseGroupConvolution(attr, op)) {
MS_LOG(ERROR) << "Convert Convolution to Depthwise failed";
return RET_ERROR;
}
} else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_Conv2D;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

View File

@ -17,6 +17,7 @@
#ifndef MS_ONNX_CONV_PARSER_H
#define MS_ONNX_CONV_PARSER_H
#include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -28,9 +29,8 @@ class OnnxConvParser : public OnnxNodeParser {
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
private:
bool ParseGroupConvolution(schema::CNodeT *op, schema::Conv2DT *attr);
bool ParseGroupConvolution(const std::unique_ptr<schema::Conv2DT> &attr, schema::CNodeT *op);
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_CONV_PARSER_H

View File

@ -21,14 +21,14 @@
namespace mindspore {
namespace lite {
bool OnnxDeConvParser::ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr) {
bool OnnxDeConvParser::ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx DeConvParser";
if (attr == nullptr || attr->group != attr->channelOut) {
return false;
}
auto deDepthwiseConv2DParam(new (std::nothrow) schema::DeDepthwiseConv2DT());
std::unique_ptr<schema::DeDepthwiseConv2DT> deDepthwiseConv2DParam(new (std::nothrow) schema::DeDepthwiseConv2DT());
if (deDepthwiseConv2DParam == nullptr) {
// MS_LOGW("new DeDepthwiseConv2DT failed");
MS_LOG(ERROR) << "new DeDepthwiseConv2DT failed";
return false;
}
deDepthwiseConv2DParam->format = attr->format;
@ -51,15 +51,14 @@ bool OnnxDeConvParser::ParseGroupDeConvolution(schema::CNodeT *op, schema::DeCon
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
delete (op->primitive->value.value);
op->primitive->value.value = deDepthwiseConv2DParam;
op->primitive->value.value = deDepthwiseConv2DParam.release();
}
return true;
}
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
auto attr = new schema::DeConv2DT();
std::unique_ptr<schema::DeConv2DT> attr(new (std::nothrow) schema::DeConv2DT());
// set opdef each attr params
for (const auto &onnx_node_attr : onnx_node.attribute()) {
if (onnx_node_attr.name() == "group") {
@ -133,23 +132,20 @@ STATUS OnnxDeConvParser::Parse(const onnx::GraphProto &onnx_graph,
attr->format = schema::Format_NCHW;
attr->hasBias = onnx_node.input().size() == 3;
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr;
}
if (attr->group != 1) {
if (!ParseGroupDeConvolution(op, attr)) {
delete attr;
// MS_LOGE("Convert DeConvolution to DeDepthwise failed");
if (!ParseGroupDeConvolution(attr, op)) {
MS_LOG(ERROR) << "Convert DeConvolution to DeDepthwise failed";
return RET_ERROR;
}
} else {
op->primitive = std::make_unique<schema::PrimitiveT>();
op->primitive->value.type = schema::PrimitiveType_DeConv2D;
op->primitive->value.value = attr.release();
}
return RET_OK;
}
OnnxNodeRegistrar g_onnxDeConvParser("ConvTranspose", new OnnxDeConvParser());
} // namespace lite
} // namespace mindspore

View File

@ -17,6 +17,7 @@
#ifndef MS_ONNX_DECONV_PARSER_H
#define MS_ONNX_DECONV_PARSER_H
#include <memory>
#include "tools/converter/parser/onnx/onnx_node_parser.h"
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
@ -28,9 +29,8 @@ class OnnxDeConvParser : public OnnxNodeParser {
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
private:
bool ParseGroupDeConvolution(schema::CNodeT *op, schema::DeConv2DT *attr);
bool ParseGroupDeConvolution(const std::unique_ptr<schema::DeConv2DT> &attr, schema::CNodeT *op);
};
} // namespace lite
} // namespace mindspore
#endif // MS_ONNX_DECONV_PARSER_H

View File

@ -151,7 +151,6 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
return RET_ERROR;
}
if (data_type == kNumberTypeInt64) {
MS_LOG(ERROR) << "INT64" << proto.name();
tensor->dataType = kNumberTypeInt32; // CopyOnnxTensorData will convert int64 to int32
}
*index = tensor_cache->AddTensor(name, tensor.release(), type);
@ -168,7 +167,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
if (status != RET_OK) {
return status;
}
MS_LOG(ERROR) << "input_value name: " << input_value.name() << ", graph input index: " << index;
MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << index;
graph->inputIndex.emplace_back(static_cast<uint32_t>(index));
}
}
@ -184,7 +183,7 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
return status;
}
graph->outputIndex.emplace_back(index);
MS_LOG(ERROR) << "output_value name: " << output_value.name() << ", graph output index: " << index;
MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
}
return RET_OK;
}
@ -399,10 +398,9 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
size_t data_count = 1;
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
MS_LOG(ERROR) << "const tensor dims " << tensor->dims.size();
size_t data_size = 0;
const void *tensor_data = nullptr;
int32_t *buffer = nullptr;
std::unique_ptr<int32_t[]> buffer;
switch (tensor->dataType) {
case kNumberTypeFloat32:
data_size = data_count * sizeof(float);
@ -422,7 +420,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
break;
case kNumberTypeInt64:
data_size = data_count * sizeof(int32_t);
buffer = new int32_t[data_count];
buffer = std::make_unique<int32_t[]>(data_count);
const int64_t *in_data;
if (onnx_const_value.int64_data_size() == 0) {
in_data = reinterpret_cast<const int64_t *>(onnx_const_value.raw_data().data());
@ -437,7 +435,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
buffer[i] = static_cast<int>(in_data[i]);
}
}
tensor_data = reinterpret_cast<void *>(buffer);
tensor_data = reinterpret_cast<void *>(buffer.get());
break;
case kNumberTypeUInt8:
case kNumberTypeInt8:
@ -453,9 +451,6 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
MS_LOG(ERROR) << "memcpy_s failed";
return RET_ERROR;
}
if (kNumberTypeInt64 == tensor->dataType) {
free(buffer);
}
return RET_OK;
}

View File

@ -19,17 +19,17 @@
namespace mindspore {
namespace lite {
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph,
const onnx::NodeProto &onnx_node,
STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx SliceParser";
std::unique_ptr<schema::SliceT> attr(new schema::SliceT());
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto& attribute_name = onnx_node_attr.name();
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "starts") {
const int size = onnx_node_attr.ints_size();
MS_LOG(ERROR) << "SLICE starts size " << size;
for (int i = 0; i < size; ++i) {
attr->begin.emplace_back(static_cast<int32_t>(onnx_node_attr.ints(i)));
attr->begin.emplace_back(static_cast<int32_t>(onnx_node_attr.ints(i)));
}
} else if (attribute_name == "ends") {
const int size = onnx_node_attr.ints_size();
@ -49,4 +49,3 @@ STATUS OnnxSliceParser::Parse(const onnx::GraphProto &onnx_graph,
OnnxNodeRegistrar g_onnxSliceParser("Slice", new OnnxSliceParser());
} // namespace lite
} // namespace mindspore