forked from mindspore-Ecosystem/mindspore
tflite parser suppport perchannel
This commit is contained in:
parent
8df757143a
commit
61f2c7e32b
|
@ -22,6 +22,7 @@
|
||||||
#include "tools/anf_importer/import_from_meta_graphT.h"
|
#include "tools/anf_importer/import_from_meta_graphT.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
|
#include "tools/common/tensor_util.h"
|
||||||
|
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
int AnfImporterFromMetaGraphT::ConverterConstTensor() {
|
int AnfImporterFromMetaGraphT::ConverterConstTensor() {
|
||||||
|
@ -75,8 +76,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
|
||||||
if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) {
|
if (cNode->quantType != schema::QuantType_PostTraining && cNode->quantType != schema::QuantType_WeightQuant) {
|
||||||
primitiveCValue->SetQuantType(cNode->quantType);
|
primitiveCValue->SetQuantType(cNode->quantType);
|
||||||
for (int index : cNode->inputIndex) {
|
for (int index : cNode->inputIndex) {
|
||||||
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
|
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
|
||||||
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
|
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
|
||||||
|
std::transform(
|
||||||
|
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
|
||||||
|
quant_params.begin(),
|
||||||
|
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
|
||||||
primitiveCValue->AddInputQuantParam(quant_params);
|
primitiveCValue->AddInputQuantParam(quant_params);
|
||||||
} else {
|
} else {
|
||||||
std::vector<schema::QuantParamT> empty_quant_params;
|
std::vector<schema::QuantParamT> empty_quant_params;
|
||||||
|
@ -84,8 +89,12 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int index : cNode->outputIndex) {
|
for (int index : cNode->outputIndex) {
|
||||||
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
|
if (!meta_graph_->allTensors[index]->quantParams.empty()) {
|
||||||
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
|
std::vector<schema::QuantParamT> quant_params(meta_graph_->allTensors[index]->quantParams.size());
|
||||||
|
std::transform(
|
||||||
|
meta_graph_->allTensors[index]->quantParams.begin(), meta_graph_->allTensors[index]->quantParams.end(),
|
||||||
|
quant_params.begin(),
|
||||||
|
[](std::unique_ptr<schema::QuantParamT> &quant_param) -> schema::QuantParamT { return *quant_param; });
|
||||||
primitiveCValue->AddOutputQuantParam(quant_params);
|
primitiveCValue->AddOutputQuantParam(quant_params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,31 +64,33 @@ STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<
|
||||||
|
|
||||||
void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
|
void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
|
||||||
schema::TensorT *tensor) {
|
schema::TensorT *tensor) {
|
||||||
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
|
|
||||||
if (!tflite_tensor->quantization->scale.empty()) {
|
|
||||||
quant_param->scale = tflite_tensor->quantization->scale[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!tflite_tensor->quantization->zero_point.empty()) {
|
|
||||||
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// change quant param min to 0 to fit ms-lite ops
|
|
||||||
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
|
|
||||||
quant_param->zeroPoint = quant_param->zeroPoint - 128;
|
|
||||||
tensor->dataType = TypeId::kNumberTypeInt8;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!tflite_tensor->quantization->min.empty()) {
|
|
||||||
quant_param->min = tflite_tensor->quantization->min[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!tflite_tensor->quantization->max.empty()) {
|
|
||||||
quant_param->max = tflite_tensor->quantization->max[0];
|
|
||||||
}
|
|
||||||
quant_param->inited = true;
|
|
||||||
tensor->quantParams.clear();
|
tensor->quantParams.clear();
|
||||||
tensor->quantParams.emplace_back(std::move(quant_param));
|
for (size_t i = 0; i < tflite_tensor->quantization->scale.size(); i++) {
|
||||||
|
std::unique_ptr<schema::QuantParamT> quant_param = std::make_unique<QuantParamT>();
|
||||||
|
if (!tflite_tensor->quantization->scale.empty()) {
|
||||||
|
quant_param->scale = tflite_tensor->quantization->scale[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tflite_tensor->quantization->zero_point.empty()) {
|
||||||
|
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// change quant param min to 0 to fit ms-lite ops
|
||||||
|
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
|
||||||
|
quant_param->zeroPoint = quant_param->zeroPoint - 128;
|
||||||
|
tensor->dataType = TypeId::kNumberTypeInt8;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tflite_tensor->quantization->min.empty()) {
|
||||||
|
quant_param->min = tflite_tensor->quantization->min[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!tflite_tensor->quantization->max.empty()) {
|
||||||
|
quant_param->max = tflite_tensor->quantization->max[i];
|
||||||
|
}
|
||||||
|
quant_param->inited = true;
|
||||||
|
tensor->quantParams.emplace_back(std::move(quant_param));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
|
||||||
|
|
Loading…
Reference in New Issue