diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc index 0b47a8e6362..786f825a651 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc +++ b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc @@ -89,7 +89,7 @@ int AnfImporterFromMetaGraphT::ConverterCNode() { } auto primTValue = std::make_shared(cNode->primitive.release()); // add quant parameter - if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) { + if (cNode->quantType == schema::QuantType_AwareTrainning) { primTValue->SetQuantType(cNode->quantType); for (int index : cNode->inputIndex) { primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0])); diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 07dd48268ec..f7a3905c02d 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -141,11 +141,11 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags * // flags->bitNum)); // break; // } - // case mindspore::schema::QuantType_PostTraining: { - // MS_LOG(INFO) << "create PostTrainningQuantizer!"; - // mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); - // break; - // } + case mindspore::schema::QuantType_PostTraining: { + MS_LOG(INFO) << "create PostTrainningQuantizer!"; + mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); + break; + } case mindspore::schema::QuantType_QUANT_NONE: MS_LOG(INFO) << "Not do quantization for model!"; break; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 26d5fce2294..e03cbe79e91 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ 
b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -308,21 +308,24 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, bool per_channel) { - if (per_channel) { - // per channel - auto dims = weightPtr->tensor_shape(); - if (dims.size() < 1) { - MS_LOG(ERROR) << "weight dims size error"; - return RET_ERROR; - } - // todo(x) + auto dims = weightPtr->tensor_shape(); + if (dims.size() != 4) { + MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer."; + per_channel = false; + } else { uint32_t channels = dims[3]; if (channels == 0) { - MS_LOG(ERROR) << "channels error 0"; + MS_LOG(ERROR) << "channels is 0"; return RET_ERROR; } + } + if (per_channel) { + // notice: + // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D + // if TransWeightFormat is done before PostTrainingQuantization, the DepthwiseConv2D's weight is CHWK size_t shapeSize = weightPtr->tensor_shape_size(); + auto channels = dims[3]; size_t oneFilterSize = shapeSize / channels; auto *rawDatas = reinterpret_cast(weightPtr->tensor_addr()); if (rawDatas == nullptr) { @@ -330,17 +333,17 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ return RET_ERROR; } + float min = FLT_MAX; + float max = FLT_MIN; weightPtr->quant_param().clear(); vector qDatas(shapeSize); + for (uint32_t i = 0; i < channels; i++) { - float min = 0; - float max = 0; // find min and max for (uint32_t j = 0; j < oneFilterSize; j++) { - min = std::min(min, rawDatas[j + i * oneFilterSize]); - max = std::max(max, rawDatas[j + i * oneFilterSize]); + min = std::min(min, rawDatas[i + j * oneFilterSize]); + max = std::max(max, rawDatas[i + j * oneFilterSize]); } - std::unique_ptr quantParam = std::unique_ptr(new AnfQuantParam); STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, 
quant_min, bitNum); if (status != RET_OK) { @@ -349,11 +352,10 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_ } // update data and datatype for (uint32_t j = 0; j < oneFilterSize; j++) { - float rawData = rawDatas[j + i * oneFilterSize]; + float rawData = rawDatas[i + j * oneFilterSize]; auto qData = QuantizeData(rawData, quantParam.get(), quant_max, quant_min); - qDatas[j + i * oneFilterSize] = qData; + qDatas[i + j * oneFilterSize] = qData; } - weightPtr->set_quant_param(quantParam); } auto ret =