From b0c80c2754985267316479a66429ffd7db25660a Mon Sep 17 00:00:00 2001 From: wang_shaocong Date: Mon, 12 Oct 2020 09:41:50 +0800 Subject: [PATCH] Fix bug of onnx cast parser. --- mindspore/lite/src/ops/cast.cc | 2 +- mindspore/lite/src/ops/primitive_c.cc | 5 +++++ .../lite/tools/converter/parser/onnx/onnx_cast_parser.cc | 4 +++- .../lite/tools/converter/parser/onnx/onnx_model_parser.h | 4 ++-- .../{onnx_unsample_parser.cc => onnx_upsample_parser.cc} | 2 +- .../onnx/{onnx_unsample_parser.h => onnx_upsample_parser.h} | 0 6 files changed, 12 insertions(+), 5 deletions(-) rename mindspore/lite/tools/converter/parser/onnx/{onnx_unsample_parser.cc => onnx_upsample_parser.cc} (96%) rename mindspore/lite/tools/converter/parser/onnx/{onnx_unsample_parser.h => onnx_upsample_parser.h} (100%) diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 6e23da47327..41ff55115b6 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -95,7 +95,7 @@ int Cast::InferShape(std::vector inputs_, std::vector output return RET_OK; } - if (input->data_type() != GetSrcT()) { + if (GetSrcT() != 0 && input->data_type() != GetSrcT()) { MS_LOG(ERROR) << "input dataType is error"; return RET_INPUT_TENSOR_ERROR; } diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index ac7eb4ef69b..42fc24647b0 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -131,6 +131,7 @@ #include "src/ops/custom_predict.h" #include "src/ops/custom_normalize.h" #include "src/ops/custom_extract_features.h" +#include "src/ops/upsample.h" #ifdef PRIMITIVE_WRITEABLE #include "tools/converter/quantizer/quantize_util.h" #endif @@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new CustomNormalize(primitive); case schema::PrimitiveType_CustomExtractFeatures: return new CustomExtractFeatures(primitive); + case schema::PrimitiveType_Upsample: + return new Upsample(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: @@ -960,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { return NewPrimitiveC(primitive); case schema::PrimitiveType_CustomExtractFeatures: return NewPrimitiveC(primitive); + case schema::PrimitiveType_Upsample: + return NewPrimitiveC(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc index 3d458e66e51..d42813c6f16 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_cast_parser.cc @@ -15,6 +15,7 @@ */ #include "tools/converter/parser/onnx/onnx_cast_parser.h" +#include "tools/converter/parser/onnx/onnx_model_parser.h" #include namespace mindspore { @@ -40,7 +41,8 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "to") { - attr->dstT = static_cast(onnx_node_attr.i()); + attr->dstT = static_cast( + OnnxModelParser::GetDataTypeFromOnnx(static_cast(onnx_node_attr.i()))); } } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index abc4bccf8b8..8a8dfe52ea0 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -43,9 +43,9 @@ class OnnxModelParser : public ModelParser { schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType = QuantType_QUANT_NONE) override; - private: - TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); + static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); + private: std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc similarity index 96% rename from mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc rename to mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc index ad1be93b3b7..7e9b24c063f 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.cc @@ -15,7 +15,7 @@ */ #include -#include "tools/converter/parser/onnx/onnx_unsample_parser.h" +#include "tools/converter/parser/onnx/onnx_upsample_parser.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h similarity index 100% rename from mindspore/lite/tools/converter/parser/onnx/onnx_unsample_parser.h rename to mindspore/lite/tools/converter/parser/onnx/onnx_upsample_parser.h