!7121 [MSLITE] Fix bug of onxx cast parser.

Merge pull request !7121 from wangshaocong/bugfix_master
This commit is contained in:
mindspore-ci-bot 2020-10-12 20:26:55 +08:00 committed by Gitee
commit 79b974eb82
6 changed files with 12 additions and 5 deletions

View File

@ -95,7 +95,7 @@ int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_OK;
}
if (input->data_type() != GetSrcT()) {
if (GetSrcT() != 0 && input->data_type() != GetSrcT()) {
MS_LOG(ERROR) << "input dataType is error";
return RET_INPUT_TENSOR_ERROR;
}

View File

@ -131,6 +131,7 @@
#include "src/ops/custom_predict.h"
#include "src/ops/custom_normalize.h"
#include "src/ops/custom_extract_features.h"
#include "src/ops/upsample.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#endif
@ -692,6 +693,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new CustomNormalize(primitive);
case schema::PrimitiveType_CustomExtractFeatures:
return new CustomExtractFeatures(primitive);
case schema::PrimitiveType_Upsample:
return new Upsample(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:
@ -960,6 +963,8 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) {
return NewPrimitiveC<CustomNormalize>(primitive);
case schema::PrimitiveType_CustomExtractFeatures:
return NewPrimitiveC<CustomExtractFeatures>(primitive);
case schema::PrimitiveType_Upsample:
return NewPrimitiveC<Upsample>(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

View File

@ -15,6 +15,7 @@
*/
#include "tools/converter/parser/onnx/onnx_cast_parser.h"
#include "tools/converter/parser/onnx/onnx_model_parser.h"
#include <memory>
namespace mindspore {
@ -40,7 +41,8 @@ STATUS OnnxCastParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
for (const auto &onnx_node_attr : onnx_node.attribute()) {
const auto &attribute_name = onnx_node_attr.name();
if (attribute_name == "to") {
attr->dstT = static_cast<int32_t>(onnx_node_attr.i());
attr->dstT = static_cast<int32_t>(
OnnxModelParser::GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(onnx_node_attr.i())));
}
}

View File

@ -43,9 +43,9 @@ class OnnxModelParser : public ModelParser {
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
private:
std::vector<int32_t> GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value);
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);

View File

@ -15,7 +15,7 @@
*/
#include <memory>
#include "tools/converter/parser/onnx/onnx_unsample_parser.h"
#include "tools/converter/parser/onnx/onnx_upsample_parser.h"
namespace mindspore {
namespace lite {