From bc8b9245a25e96d4eb2d8c7df531b794adf1dcd8 Mon Sep 17 00:00:00 2001 From: wsc Date: Mon, 14 Sep 2020 10:26:01 +0800 Subject: [PATCH] fix bug of operator quantize --- mindspore/lite/src/ops/primitive_c.h | 3 +- .../lite/src/runtime/kernel/arm/fp32/cast.cc | 1 + .../parser/tflite/tflite_quantize_parser.cc | 32 ++++++++++++------- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 80728033c52..cda2659d04e 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -39,7 +39,8 @@ constexpr uint32_t kDoubleNum = 2; constexpr uint32_t kMultiNum = 3; constexpr uint32_t kDimension_4d = 4; -const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32, kNumberTypeFloat16}; +const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, + kNumberTypeFloat32, kNumberTypeFloat16}; #ifdef PRIMITIVE_WRITEABLE using TensorPtr = std::shared_ptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index ac45e4a306d..7a780a44fbc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -147,6 +147,7 @@ kernel::LiteKernel *CpuCastFp32KernelCreator(const std::vector & REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Cast, CpuCastFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeUInt8, PrimitiveType_Cast, CpuCastFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Cast, CpuCastFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Cast, CpuCastFp32KernelCreator) #ifndef ENABLE_ARM64 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Cast, CpuCastFp32KernelCreator) diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc index b0aff032360..67fd6e5d1f4 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_quantize_parser.cc @@ -36,27 +36,37 @@ STATUS TfliteQuantizeParser::Parse(const std::unique_ptr &tfl return RET_NULL_PTR; } - std::unique_ptr attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; - } - const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]]; if (in_tensor == nullptr) { MS_LOG(ERROR) << "input tensor is null"; return RET_NULL_PTR; } - attr->srcT = GetTfliteDataType(in_tensor->type); const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]]; if (out_tensor == nullptr) { MS_LOG(ERROR) << "output tensor is null"; return RET_NULL_PTR; } - attr->dstT = GetTfliteDataType(out_tensor->type); - - op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; - op->primitive->value.value = attr.release(); + if (GetTfliteDataType(in_tensor->type) != kNumberTypeInt8) { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + attr->srcT = GetTfliteDataType(in_tensor->type); + attr->dstT = GetTfliteDataType(out_tensor->type); + op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast; + op->primitive->value.value = attr.release(); + } else { + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + attr->srcT = GetTfliteDataType(in_tensor->type); + attr->dstT = GetTfliteDataType(out_tensor->type); + op->primitive->value.type = schema::PrimitiveType_Cast; + op->primitive->value.value = attr.release(); + } AddOpInput(op, tensors_id, tensors_format, tensors_id_map, tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);