Fix a bug causing the loss of quantization parameters in quantized models.
This commit is contained in:
parent
fbf8a3bbcc
commit
5170065722
|
@ -64,8 +64,6 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
|
|||
MS_ASSERT(dst_node != nullptr);
|
||||
// add quant param
|
||||
dst_node->quantType = primitive->GetQuantType();
|
||||
if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
|
||||
|| dst_node->quantType == schema::QuantType_WeightQuant) {
|
||||
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
|
||||
// activation
|
||||
auto input_quant_params = primitive->GetInputQuantParams();
|
||||
|
@ -124,7 +122,6 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -83,12 +83,15 @@ ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<s
|
|||
auto primitiveCValue = PrimitiveC::UnPackFromSchemaPrimitiveT(cNode->primitive.release());
|
||||
cNode->primitive = nullptr;
|
||||
// add quant parameter
|
||||
if (cNode->quantType == schema::QuantType_AwareTraining) {
|
||||
if (cNode->quantType != schema::QuantType_PostTraining) {
|
||||
primitiveCValue->SetQuantType(cNode->quantType);
|
||||
for (int index : cNode->inputIndex) {
|
||||
if (meta_graph_->allTensors[index]->quantParams.size() > 0) {
|
||||
std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])};
|
||||
primitiveCValue->AddInputQuantParam(quant_params);
|
||||
} else {
|
||||
std::vector<schema::QuantParamT> empty_quant_params;
|
||||
primitiveCValue->AddInputQuantParam(empty_quant_params);
|
||||
}
|
||||
}
|
||||
for (int index : cNode->outputIndex) {
|
||||
|
|
|
@ -38,27 +38,37 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
|
|||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
|
||||
const auto &in_tensor = tflite_tensors[tflite_op->inputs[0]];
|
||||
if (in_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "input tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
const auto &out_tensor = tflite_tensors[tflite_op->outputs[0]];
|
||||
if (out_tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "output tensor is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8) {
|
||||
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
|
||||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
op->primitive->value.value = attr.release();
|
||||
op->primitive->value.type = schema::PrimitiveType_QuantDTypeCast;
|
||||
} else {
|
||||
std::unique_ptr<schema::CastT> attr = std::make_unique<schema::CastT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
attr->srcT = GetTfliteDataType(in_tensor->type);
|
||||
attr->dstT = GetTfliteDataType(out_tensor->type);
|
||||
op->primitive->value.value = attr.release();
|
||||
op->primitive->value.type = schema::PrimitiveType_Cast;
|
||||
}
|
||||
|
||||
AddOpInput(op, tensors_id, tensors_format, tensors_id_map,
|
||||
tflite_op->inputs[0], tensors_id->size(), tflite_tensors.size(), schema::Format_NHWC);
|
||||
|
|
|
@ -200,6 +200,24 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|||
FreeTensors(&input_tensors, &output_tensors);
|
||||
return nullptr;
|
||||
}
|
||||
auto inputQuantParams = lite_primitive->GetInputQuantParams();
|
||||
for (size_t m = 0; m < inputQuantParams.size(); m++) {
|
||||
for (auto inputQuantParam : inputQuantParams[m]) {
|
||||
lite::tensor::QuantArg quant_arg{};
|
||||
quant_arg.scale = inputQuantParam.scale;
|
||||
quant_arg.zeroPoint = inputQuantParam.zeroPoint;
|
||||
input_tensors[m]->AddQuantParam(quant_arg);
|
||||
}
|
||||
}
|
||||
auto outputQuantParams = lite_primitive->GetOutputQuantParams();
|
||||
for (size_t m = 0; m < outputQuantParams.size(); m++) {
|
||||
for (auto outputQuantParam : outputQuantParams[m]) {
|
||||
lite::tensor::QuantArg quant_arg{};
|
||||
quant_arg.scale = outputQuantParam.scale;
|
||||
quant_arg.zeroPoint = outputQuantParam.zeroPoint;
|
||||
output_tensors[m]->AddQuantParam(quant_arg);
|
||||
}
|
||||
}
|
||||
// here, input_tensor's format need to be transposed nhwc according to fmkType,
|
||||
// but for the time being, we only transpose the tensor with 0/1/2/3D.
|
||||
// Others should be added in future.
|
||||
|
|
Loading…
Reference in New Issue