!4430 fix post training quant

Merge pull request !4430 from xutianchun/quant_0814
This commit is contained in:
mindspore-ci-bot 2020-08-14 10:17:42 +08:00 committed by Gitee
commit 89127ccf65
3 changed files with 25 additions and 23 deletions

View File

@ -89,7 +89,7 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
}
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
// add quant parameter
if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) {
if (cNode->quantType == schema::QuantType_AwareTrainning) {
primTValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));

View File

@ -141,11 +141,11 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
// flags->bitNum));
// break;
// }
// case mindspore::schema::QuantType_PostTraining: {
// MS_LOG(INFO) << "create PostTrainningQuantizer!";
// mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
// break;
// }
case mindspore::schema::QuantType_PostTraining: {
MS_LOG(INFO) << "create PostTrainningQuantizer!";
mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
break;
}
case mindspore::schema::QuantType_QUANT_NONE:
MS_LOG(INFO) << "Not do quantization for model!";
break;

View File

@ -308,21 +308,24 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
bool per_channel) {
if (per_channel) {
// per channel
auto dims = weightPtr->tensor_shape();
if (dims.size() < 1) {
MS_LOG(ERROR) << "weight dims size error";
return RET_ERROR;
}
// todo(x)
auto dims = weightPtr->tensor_shape();
if (dims.size() != 4) {
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
per_channel = false;
} else {
uint32_t channels = dims[3];
if (channels == 0) {
MS_LOG(ERROR) << "channels error 0";
MS_LOG(ERROR) << "channels is 0";
return RET_ERROR;
}
}
if (per_channel) {
// Note:
// Currently, for tflite models, the weight format of Conv2D is KHWC, and so is that of DepthwiseConv2D.
// If TransWeightFormat is done before post-training quantization, the DepthwiseConv2D weight is CHWK.
size_t shapeSize = weightPtr->tensor_shape_size();
auto channels = dims[3];
size_t oneFilterSize = shapeSize / channels;
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
if (rawDatas == nullptr) {
@ -330,17 +333,17 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
return RET_ERROR;
}
float min = FLT_MAX;
float max = FLT_MIN;
weightPtr->quant_param().clear();
vector<int8_t> qDatas(shapeSize);
for (uint32_t i = 0; i < channels; i++) {
float min = 0;
float max = 0;
// find min and max
for (uint32_t j = 0; j < oneFilterSize; j++) {
min = std::min(min, rawDatas[j + i * oneFilterSize]);
max = std::max(max, rawDatas[j + i * oneFilterSize]);
min = std::min(min, rawDatas[i + j * oneFilterSize]);
max = std::max(max, rawDatas[i + j * oneFilterSize]);
}
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
@ -349,11 +352,10 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
}
// update data and datatype
for (uint32_t j = 0; j < oneFilterSize; j++) {
float rawData = rawDatas[j + i * oneFilterSize];
float rawData = rawDatas[i + j * oneFilterSize];
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
qDatas[j + i * oneFilterSize] = qData;
qDatas[i + j * oneFilterSize] = qData;
}
weightPtr->set_quant_param(quantParam);
}
auto ret =