forked from mindspore-Ecosystem/mindspore
tflite parser suppport perchannel
This commit is contained in:
parent
8df757143a
commit
61f2c7e32b
|
@ -22,6 +22,7 @@
|
|||
#include "tools/anf_importer/import_from_meta_graphT.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
int AnfImporterFromMetaGraphT::ConverterConstTensor() {
|
||||
|
@ -75,8 +76,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
|
|||
if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) {
|
||||
primitiveCValue->SetQuantType(cNode->quantType);
|
||||
for (int index : cNode->inputIndex) {
|
||||
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
|
||||
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
|
||||
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
|
||||
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
|
||||
std::transform(
|
||||
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
|
||||
quant_params.begin(),
|
||||
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
|
||||
primitiveCValue->AddInputQuantParam(quant_params);
|
||||
} else {
|
||||
std::vector<schema::QuantParamT> empty_quant_params;
|
||||
|
@ -84,8 +89,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
|
|||
}
|
||||
}
|
||||
for (int index : cNode->outputIndex) {
|
||||
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
|
||||
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
|
||||
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
|
||||
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
|
||||
std::transform(
|
||||
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
|
||||
quant_params.begin(),
|
||||
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
|
||||
primitiveCValue->AddOutputQuantParam(quant_params);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,31 +64,33 @@ STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<
|
|||
|
||||
void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
|
||||
schema::TensorT *tensor) {
|
||||
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
|
||||
if (!tflite_tensor->quantization->scale.empty()) {
|
||||
quant_param->scale = tflite_tensor->quantization->scale[0];
|
||||
}
|
||||
|
||||
if (!tflite_tensor->quantization->zero_point.empty()) {
|
||||
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
|
||||
}
|
||||
|
||||
// change quant param min to 0 to fit ms-lite ops
|
||||
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
|
||||
quant_param->zeroPoint = quant_param->zeroPoint - 128;
|
||||
tensor->dataType = TypeId::kNumberTypeInt8;
|
||||
}
|
||||
|
||||
if (!tflite_tensor->quantization->min.empty()) {
|
||||
quant_param->min = tflite_tensor->quantization->min[0];
|
||||
}
|
||||
|
||||
if (!tflite_tensor->quantization->max.empty()) {
|
||||
quant_param->max = tflite_tensor->quantization->max[0];
|
||||
}
|
||||
quant_param->inited = true;
|
||||
tensor->quantParams.clear();
|
||||
tensor->quantParams.emplace_back(std::move(quant_param));
|
||||
for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) {
|
||||
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
|
||||
if (!tflite_tensor->quantization->scale.empty()) {
|
||||
quant_param->scale = tflite_tensor->quantization->scale[i];
|
||||
}
|
||||
|
||||
if (!tflite_tensor->quantization->zero_point.empty()) {
|
||||
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i];
|
||||
}
|
||||
|
||||
// change quant param min to 0 to fit ms-lite ops
|
||||
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
|
||||
quant_param->zeroPoint = quant_param->zeroPoint - 128;
|
||||
tensor->dataType = TypeId::kNumberTypeInt8;
|
||||
}
|
||||
|
||||
if (!tflite_tensor->quantization->min.empty()) {
|
||||
quant_param->min = tflite_tensor->quantization->min[i];
|
||||
}
|
||||
|
||||
if (!tflite_tensor->quantization->max.empty()) {
|
||||
quant_param->max = tflite_tensor->quantization->max[i];
|
||||
}
|
||||
quant_param->inited = true;
|
||||
tensor->quantParams.emplace_back(std::move(quant_param));
|
||||
}
|
||||
}
|
||||
|
||||
STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||
|
|
Loading…
Reference in New Issue