forked from mindspore-Ecosystem/mindspore
!4430 fix post training quant
Merge pull request !4430 from xutianchun/quant_0814
This commit is contained in:
commit
89127ccf65
|
@ -89,7 +89,7 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
|
|||
}
|
||||
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
|
||||
// add quant parameter
|
||||
if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) {
|
||||
if (cNode->quantType == schema::QuantType_AwareTrainning) {
|
||||
primTValue->SetQuantType(cNode->quantType);
|
||||
for (int index : cNode->inputIndex) {
|
||||
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
|
||||
|
|
|
@ -141,11 +141,11 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
|
|||
// flags->bitNum));
|
||||
// break;
|
||||
// }
|
||||
// case mindspore::schema::QuantType_PostTraining: {
|
||||
// MS_LOG(INFO) << "create PostTrainningQuantizer!";
|
||||
// mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
|
||||
// break;
|
||||
// }
|
||||
case mindspore::schema::QuantType_PostTraining: {
|
||||
MS_LOG(INFO) << "create PostTrainningQuantizer!";
|
||||
mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
|
||||
break;
|
||||
}
|
||||
case mindspore::schema::QuantType_QUANT_NONE:
|
||||
MS_LOG(INFO) << "Not do quantization for model!";
|
||||
break;
|
||||
|
|
|
@ -308,21 +308,24 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
|
|||
|
||||
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum,
|
||||
bool per_channel) {
|
||||
if (per_channel) {
|
||||
// per channel
|
||||
auto dims = weightPtr->tensor_shape();
|
||||
if (dims.size() < 1) {
|
||||
MS_LOG(ERROR) << "weight dims size error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// todo(x)
|
||||
if (dims.size() != 4) {
|
||||
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
|
||||
per_channel = false;
|
||||
} else {
|
||||
uint32_t channels = dims[3];
|
||||
if (channels == 0) {
|
||||
MS_LOG(ERROR) << "channels error 0";
|
||||
MS_LOG(ERROR) << "channels is 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
if (per_channel) {
|
||||
// notice:
|
||||
// at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
|
||||
// if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
|
||||
size_t shapeSize = weightPtr->tensor_shape_size();
|
||||
auto channels = dims[3];
|
||||
size_t oneFilterSize = shapeSize / channels;
|
||||
auto *rawDatas = reinterpret_cast<const float *>(weightPtr->tensor_addr());
|
||||
if (rawDatas == nullptr) {
|
||||
|
@ -330,17 +333,17 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
|
|||
return RET_ERROR;
|
||||
}
|
||||
|
||||
float min = FLT_MAX;
|
||||
float max = FLT_MIN;
|
||||
weightPtr->quant_param().clear();
|
||||
vector<int8_t> qDatas(shapeSize);
|
||||
|
||||
for (uint32_t i = 0; i < channels; i++) {
|
||||
float min = 0;
|
||||
float max = 0;
|
||||
// find min and max
|
||||
for (uint32_t j = 0; j < oneFilterSize; j++) {
|
||||
min = std::min(min, rawDatas[j + i * oneFilterSize]);
|
||||
max = std::max(max, rawDatas[j + i * oneFilterSize]);
|
||||
min = std::min(min, rawDatas[i + j * oneFilterSize]);
|
||||
max = std::max(max, rawDatas[i + j * oneFilterSize]);
|
||||
}
|
||||
|
||||
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
|
||||
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
|
||||
if (status != RET_OK) {
|
||||
|
@ -349,11 +352,10 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
|
|||
}
|
||||
// update data and datatype
|
||||
for (uint32_t j = 0; j < oneFilterSize; j++) {
|
||||
float rawData = rawDatas[j + i * oneFilterSize];
|
||||
float rawData = rawDatas[i + j * oneFilterSize];
|
||||
auto qData = QuantizeData<int8_t>(rawData, quantParam.get(), quant_max, quant_min);
|
||||
qDatas[j + i * oneFilterSize] = qData;
|
||||
qDatas[i + j * oneFilterSize] = qData;
|
||||
}
|
||||
|
||||
weightPtr->set_quant_param(quantParam);
|
||||
}
|
||||
auto ret =
|
||||
|
|
Loading…
Reference in New Issue