tflite parser suppport perchannel

This commit is contained in:
cjh9368 2020-10-10 14:42:44 +08:00
parent 8df757143a
commit 61f2c7e32b
2 changed files with 39 additions and 28 deletions

View File

@ -22,6 +22,7 @@
#include "tools/anf_importer/import_from_meta_graphT.h" #include "tools/anf_importer/import_from_meta_graphT.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "tools/common/tensor_util.h"
namespace mindspore::lite { namespace mindspore::lite {
int AnfImporterFromMetaGraphT::ConverterConstTensor() { int AnfImporterFromMetaGraphT::ConverterConstTensor() {
@ -75,8 +76,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) { if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) {
primitiveCValue->SetQuantType(cNode->quantType); primitiveCValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) { for (int index : cNode->inputIndex) {
if (meta_graph_->allTensors[index]->quantParams.size() > 0) { if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddInputQuantParam(quant_params); primitiveCValue->AddInputQuantParam(quant_params);
} else { } else {
std::vector<schema::QuantParamT> empty_quant_params; std::vector<schema::QuantParamT> empty_quant_params;
@ -84,8 +89,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
} }
} }
for (int index : cNode->outputIndex) { for (int index : cNode->outputIndex) {
if (meta_graph_->allTensors[index]->quantParams.size() > 0) { if (!meta_graph_->allTensors[index]->quantParams.empty()) {
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
std::transform(
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
quant_params.begin(),
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
primitiveCValue->AddOutputQuantParam(quant_params); primitiveCValue->AddOutputQuantParam(quant_params);
} }
} }

View File

@ -64,31 +64,33 @@ STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<
void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor, void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
schema::TensorT *tensor) { schema::TensorT *tensor) {
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
}
// change quant param min to 0 to fit ms-lite ops
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
tensor->dataType = TypeId::kNumberTypeInt8;
}
if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
}
if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
}
quant_param->inited = true;
tensor->quantParams.clear(); tensor->quantParams.clear();
tensor->quantParams.emplace_back(std::move(quant_param)); for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) {
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[i];
}
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i];
}
// change quant param min to 0 to fit ms-lite ops
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
tensor->dataType = TypeId::kNumberTypeInt8;
}
if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[i];
}
if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[i];
}
quant_param->inited = true;
tensor->quantParams.emplace_back(std::move(quant_param));
}
} }
STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model, STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,