!6390 post training quantization: scale restrict and uniform weight shape

Merge pull request !6390 from xutianchun/quant_09
This commit is contained in:
mindspore-ci-bot 2020-09-22 21:47:57 +08:00 committed by Gitee
commit 88ded11f59
5 changed files with 85 additions and 117 deletions

View File

@ -516,8 +516,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
return RET_OK; return RET_OK;
} }
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchanel, STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c,
bool depthwise) { bool perchanel) {
// perlayer // perlayer
if (!weight->isa<Parameter>()) { if (!weight->isa<Parameter>()) {
MS_LOG(ERROR) << "not a parameter"; MS_LOG(ERROR) << "not a parameter";
@ -534,7 +534,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
return RET_ERROR; return RET_ERROR;
} }
auto status = QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, auto status = QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
perchanel, depthwise); perchanel);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status; MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status; return status;
@ -608,7 +608,6 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
quant_param.inited = true; quant_param.inited = true;
quant_params.emplace_back(quant_param); quant_params.emplace_back(quant_param);
} }
primitive_c->AddInputQuantParam(quant_params);
// quant bias data // quant bias data
int32_t *quant_datas = new (std::nothrow) int32_t[shape_size]; int32_t *quant_datas = new (std::nothrow) int32_t[shape_size];
if (quant_datas == nullptr) { if (quant_datas == nullptr) {
@ -617,15 +616,35 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
} }
float *raw_datas = static_cast<float *>(bias_param->tensor_addr()); float *raw_datas = static_cast<float *>(bias_param->tensor_addr());
double bias_scale_tmp; double bias_scale_tmp;
constexpr int32_t quanted_bias_abs_limit = 0.5 * INT32_MAX;
for (size_t i = 0; i < shape_size; i++) { for (size_t i = 0; i < shape_size; i++) {
if (bias_scales.size() == 1) { if (bias_scales.size() == 1) {
bias_scale_tmp = bias_scales[0]; bias_scale_tmp = bias_scales[0];
} else { } else {
bias_scale_tmp = bias_scales[i]; bias_scale_tmp = bias_scales[i];
} }
if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][i].scale
<< " is too small, need to update";
// update filter scale and zp
if (input_scales.size() == 1 && active_weight_quant_params[1].size() == shape_size) {
double activate_scale = input_scales[0];
double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
active_weight_quant_params[1][i].scale = filter_scale;
active_weight_quant_params[1][i].zeroPoint = 0;
primitive_c->SetInputQuantParam(active_weight_quant_params);
bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
quant_params[i].scale = bias_scale_tmp;
MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
} else {
MS_LOG(WARNING) << "unexpected input_scales size: " << input_scales.size() << " weight_scales size: "
<< active_weight_quant_params[1].size();
}
}
auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp); auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
quant_datas[i] = quant_data; quant_datas[i] = quant_data;
} }
primitive_c->AddInputQuantParam(quant_params);
auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed."; MS_LOG(ERROR) << "memcpy_s failed.";
@ -659,9 +678,9 @@ STATUS PostTrainingQuantizer::QuantNode() {
auto cnodes = funcGraph->GetOrderedCnodes(); auto cnodes = funcGraph->GetOrderedCnodes();
for (auto &cnode : cnodes) { for (auto &cnode : cnodes) {
auto cnode_name = cnode->fullname_with_scope(); auto op_name = cnode->fullname_with_scope();
if (this->calibrator_->GetInputDivergInfo()->find(cnode_name) == this->calibrator_->GetInputDivergInfo()->end()) { if (this->calibrator_->GetInputDivergInfo()->find(op_name) == this->calibrator_->GetInputDivergInfo()->end()) {
MS_LOG(INFO) << cnode_name << " can not do quant"; MS_LOG(INFO) << op_name << " can not do quant";
continue; continue;
} }
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0)); auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@ -673,8 +692,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
primitive_c->SetQuantType(schema::QuantType_QUANT_NONE); primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
continue; continue;
} }
primitive_c->ClearInputOutputQuantParam();
auto op_name = cnode->fullname_with_scope();
auto op_type = (schema::PrimitiveType)primitive_c->Type(); auto op_type = (schema::PrimitiveType)primitive_c->Type();
MS_LOG(INFO) << "OpName: " << op_name; MS_LOG(INFO) << "OpName: " << op_name;
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D && if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
@ -682,7 +700,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
for (size_t i = 1; i < cnode->inputs().size(); i++) { for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i); auto input_node = cnode->input(i);
if (!input_node->isa<mindspore::CNode>()) { if (!input_node->isa<mindspore::CNode>()) {
MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode"; MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
// get dtype // get dtype
auto abstractBase = input_node->abstract(); auto abstractBase = input_node->abstract();
if (abstractBase == nullptr) { if (abstractBase == nullptr) {
@ -696,7 +714,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
MS_LOG(DEBUG) << "this parameter do quant"; MS_LOG(DEBUG) << "this parameter do quant";
DoWeightQuant(input_node, primitive_c, false, false); DoWeightQuant(input_node, primitive_c, false);
} else { } else {
MS_LOG(DEBUG) << "this parameter no need to do quant"; MS_LOG(DEBUG) << "this parameter no need to do quant";
} }
@ -727,12 +745,11 @@ STATUS PostTrainingQuantizer::QuantNode() {
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
// do weight quant // do weight quant
auto weight = cnode->input(2); auto weight = cnode->input(2);
bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
bool perchannel = per_channel_; bool perchannel = per_channel_;
if (op_type == PrimitiveType_FullConnection) { if (op_type == PrimitiveType_FullConnection) {
perchannel = false; perchannel = false;
} }
DoWeightQuant(weight, primitive_c, perchannel, depthwise); DoWeightQuant(weight, primitive_c, perchannel);
// do bias quant // do bias quant
if (cnode->inputs().size() == 4) { if (cnode->inputs().size() == 4) {
auto bias = cnode->input(3); auto bias = cnode->input(3);

View File

@ -89,7 +89,7 @@ class PostTrainingQuantizer : public Quantizer {
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel, bool depthwise); STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel);
STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c); STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
}; };

View File

@ -109,7 +109,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
const int minLimit = quant_min; const int minLimit = quant_min;
return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
int quant_data = std::round(originData / scale + zeroPoint); auto quant_data = std::round(originData / scale + zeroPoint);
if (quant_data > maxLimit) { if (quant_data > maxLimit) {
quant_data = maxLimit; quant_data = maxLimit;
} else if (quant_data < minLimit) { } else if (quant_data < minLimit) {
@ -120,7 +120,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
} }
template <typename T> template <typename T>
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType, STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) { int quant_max, int quant_min, size_t bitNum, bool per_channel) {
auto dims = weight->tensor_shape(); auto dims = weight->tensor_shape();
if (per_channel) { if (per_channel) {
if (dims.size() != 4) { if (dims.size() != 4) {
@ -145,57 +145,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
std::vector<T> quant_datas(elem_count); std::vector<T> quant_datas(elem_count);
if (per_channel) { if (per_channel) {
// notice: // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
// at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
// if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
if (depth_wise) {
// channel at last
auto channels = dims[3];
if (channels == 0) {
MS_LOG(ERROR) << "channels is zero";
return RET_ERROR;
}
size_t one_filter_size = elem_count / channels;
for (int i = 0; i < channels; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min and max
for (size_t j = 0; j < one_filter_size; j++) {
auto index = i + j * channels;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
quant_params.emplace_back(quant_param);
// do quantization
for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = i + j * channels;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
float raw_data = raw_datas[index];
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
quant_datas[index] = quant_data;
}
}
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
weight->set_tensor_size(elem_count * sizeof(T));
} else {
// channel at first // channel at first
auto channels = dims[0]; auto channels = dims[0];
if (channels == 0) { if (channels == 0) {
@ -242,7 +192,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
return RET_ERROR; return RET_ERROR;
} }
weight->set_tensor_size(elem_count * sizeof(T)); weight->set_tensor_size(elem_count * sizeof(T));
}
} else { } else {
// per layer // per layer
float min = FLT_MAX; float min = FLT_MAX;

View File

@ -97,8 +97,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
std::vector<schema::QuantParamT> quant_params; std::vector<schema::QuantParamT> quant_params;
primitive_c->AddInputQuantParam(quant_params); primitive_c->AddInputQuantParam(quant_params);
auto status = auto status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
quant_max, quant_min, bitNum, true, false);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status; MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status; return status;
@ -160,8 +159,8 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
std::vector<schema::QuantParamT> quant_params; std::vector<schema::QuantParamT> quant_params;
primitive_c->AddInputQuantParam(quant_params); primitive_c->AddInputQuantParam(quant_params);
auto status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, auto status =
quant_max, quant_min, bitNum, true, false); QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status; MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status; return status;

View File

@ -40,6 +40,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node
MS_ASSERT(conv_cnode != nullptr); MS_ASSERT(conv_cnode != nullptr);
MS_ASSERT(param_value != nullptr); MS_ASSERT(param_value != nullptr);
switch (quant_type) { switch (quant_type) {
case schema::QuantType_PostTraining:
case QuantType_WeightQuant: case QuantType_WeightQuant:
case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW); case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW);
break; break;
@ -73,6 +74,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
} }
} }
break; break;
case QuantType_PostTraining:
case QuantType_WeightQuant: case QuantType_WeightQuant:
case QuantType_QUANT_NONE: { case QuantType_QUANT_NONE: {
// conv (K x C/group x kH x kW) group = 1 // conv (K x C/group x kH x kW) group = 1
@ -114,6 +116,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
} }
} }
break; break;
case QuantType_PostTraining:
case QuantType_WeightQuant: case QuantType_WeightQuant:
case QuantType_QUANT_NONE: { case QuantType_QUANT_NONE: {
// sum up from current ms quant models // sum up from current ms quant models