From 61f2c7e32b0cb59099b36a2ce03613b0b9ce3e97 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Sat, 10 Oct 2020 14:42:44 +0800 Subject: [PATCH] tflite parser suppport perchannel --- .../anf_importer/import_from_meta_graphT.cc | 17 +++++-- .../parser/tflite/tflite_model_parser.cc | 50 ++++++++++--------- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc index 251440101a2..43220710952 100644 --- a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc @@ -22,6 +22,7 @@ #include "tools/anf_importer/import_from_meta_graphT.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" +#include "tools/common/tensor_util.h" namespace mindspore::lite { int AnfImporterFromMetaGraphT::ConverterConstTensor() { @@ -75,8 +76,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptrquantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) { primitiveCValue->SetQuantType(cNode->quantType); for (int index : cNode->inputIndex) { - if (meta_graph_->allTensors[index]->quantParams.size() > 0) { - std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; + if (!meta_graph_->allTensors[index]->quantParams.empty()) { + std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); + std::transform( + meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), + quant_params.begin(), + [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); primitiveCValue->AddInputQuantParam(quant_params); } else { std::vector empty_quant_params; @@ -84,8 +89,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptroutputIndex) { - if (meta_graph_->allTensors[index]->quantParams.size() > 0) { - std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; + if (!meta_graph_->allTensors[index]->quantParams.empty()) { + std::vector quant_params(meta_graph_->allTensors[index]->quantParams.size()); + std::transform( + meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(), + quant_params.begin(), + [](std::unique_ptr &quant_param) -> schema::QuantParamT { return *quant_param; }); primitiveCValue->AddOutputQuantParam(quant_params); } } diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 5352209da24..72a68f984fd 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -64,31 +64,33 @@ STATUS TfliteModelParser::CopyConstTensorData(const std::vector &tflite_tensor, schema::TensorT *tensor) { - std::unique_ptr quant_param = std::make_unique(); - if (!tflite_tensor->quantization->scale.empty()) { - quant_param->scale = tflite_tensor->quantization->scale[0]; - } - - if (!tflite_tensor->quantization->zero_point.empty()) { - quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0]; - } - - // change quant param min to 0 to fit ms-lite ops - if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { - quant_param->zeroPoint = quant_param->zeroPoint - 128; - tensor->dataType = TypeId::kNumberTypeInt8; - } - - if (!tflite_tensor->quantization->min.empty()) { - quant_param->min = tflite_tensor->quantization->min[0]; - } - - if (!tflite_tensor->quantization->max.empty()) { - quant_param->max = tflite_tensor->quantization->max[0]; - } - quant_param->inited = true; tensor->quantParams.clear(); - tensor->quantParams.emplace_back(std::move(quant_param)); + for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) { + std::unique_ptr quant_param = std::make_unique(); + if (!tflite_tensor->quantization->scale.empty()) { + quant_param->scale = tflite_tensor->quantization->scale[i]; + } + + if (!tflite_tensor->quantization->zero_point.empty()) { + quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i]; + } + + // change quant param min to 0 to fit ms-lite ops + if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) { + quant_param->zeroPoint = quant_param->zeroPoint - 128; + tensor->dataType = TypeId::kNumberTypeInt8; + } + + if (!tflite_tensor->quantization->min.empty()) { + quant_param->min = tflite_tensor->quantization->min[i]; + } + + if (!tflite_tensor->quantization->max.empty()) { + quant_param->max = tflite_tensor->quantization->max[i]; + } + quant_param->inited = true; + tensor->quantParams.emplace_back(std::move(quant_param)); + } } STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflite_model,