forked from mindspore-Ecosystem/mindspore
!6390 post training quantization: scale restrict and uniform weight shape
Merge pull request !6390 from xutianchun/quant_09
This commit is contained in: commit 88ded11f59
@@ -516,8 +516,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
   return RET_OK;
 }
 
-STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchanel,
-                                            bool depthwise) {
+STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c,
+                                            bool perchanel) {
   // perlayer
   if (!weight->isa<Parameter>()) {
     MS_LOG(ERROR) << "not a parameter";
@@ -534,7 +534,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
     return RET_ERROR;
   }
   auto status = QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
-                                    perchanel, depthwise);
+                                    perchanel);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -608,7 +608,6 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
     quant_param.inited = true;
     quant_params.emplace_back(quant_param);
   }
-  primitive_c->AddInputQuantParam(quant_params);
   // quant bias data
   int32_t *quant_datas = new (std::nothrow) int32_t[shape_size];
   if (quant_datas == nullptr) {
@@ -617,15 +616,35 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
   }
   float *raw_datas = static_cast<float *>(bias_param->tensor_addr());
   double bias_scale_tmp;
+  constexpr int32_t quanted_bias_abs_limit = 0.5 * INT32_MAX;
   for (size_t i = 0; i < shape_size; i++) {
     if (bias_scales.size() == 1) {
       bias_scale_tmp = bias_scales[0];
     } else {
       bias_scale_tmp = bias_scales[i];
     }
+    if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
+      MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][i].scale
+                    << " is too small, need to update";
+      // update filter scale and zp
+      if (input_scales.size() == 1 && active_weight_quant_params[1].size() == shape_size) {
+        double activate_scale = input_scales[0];
+        double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
+        active_weight_quant_params[1][i].scale = filter_scale;
+        active_weight_quant_params[1][i].zeroPoint = 0;
+        primitive_c->SetInputQuantParam(active_weight_quant_params);
+        bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
+        quant_params[i].scale = bias_scale_tmp;
+        MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
+      } else {
+        MS_LOG(WARNING) << "unexpected input_scales size: " << input_scales.size() << " weight_scales size: "
+                        << active_weight_quant_params[1].size();
+      }
+    }
     auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
     quant_datas[i] = quant_data;
   }
+  primitive_c->AddInputQuantParam(quant_params);
   auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
   if (ret != EOK) {
     MS_LOG(ERROR) << "memcpy_s failed.";
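Editor's note on the scale restriction above: a conv bias is quantized into int32 with scale S_bias = S_input * S_filter. If a per-channel filter scale is tiny, round(bias / S_bias) overflows, so the new code caps |quantized bias| at half of INT32_MAX by enlarging that channel's filter scale (zero point forced to 0). A minimal standalone sketch of the arithmetic, with hypothetical names rather than the MindSpore API:

// Sketch only: reproduces the overflow guard from DoBiasQuant above.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const double limit = 0.5 * std::numeric_limits<int32_t>::max();  // quanted_bias_abs_limit
  double act_scale = 1e-4;      // activation (input) scale
  double filter_scale = 1e-9;   // pathologically small per-channel weight scale
  double raw_bias = 3.0;        // float bias value for this channel

  double bias_scale = act_scale * filter_scale;
  if (std::abs(raw_bias / bias_scale) >= limit) {
    // Enlarge the filter scale just enough that |round(bias / bias_scale)| == limit.
    filter_scale = std::abs(raw_bias) / (act_scale * limit);
    bias_scale = std::abs(raw_bias) / limit;  // equals act_scale * filter_scale
  }
  auto quant_bias = static_cast<int32_t>(std::round(raw_bias / bias_scale));
  std::printf("new filter scale: %g, quantized bias: %d\n", filter_scale, quant_bias);
  return 0;
}

The trade-off: that channel's weights get a coarser scale (less precision) in exchange for a bias that is representable in int32.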
@@ -659,9 +678,9 @@ STATUS PostTrainingQuantizer::QuantNode() {
 
   auto cnodes = funcGraph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
-    auto cnode_name = cnode->fullname_with_scope();
-    if (this->calibrator_->GetInputDivergInfo()->find(cnode_name) == this->calibrator_->GetInputDivergInfo()->end()) {
-      MS_LOG(INFO) << cnode_name << " can not do quant";
+    auto op_name = cnode->fullname_with_scope();
+    if (this->calibrator_->GetInputDivergInfo()->find(op_name) == this->calibrator_->GetInputDivergInfo()->end()) {
+      MS_LOG(INFO) << op_name << " can not do quant";
       continue;
     }
     auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@@ -673,8 +692,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
       primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
       continue;
     }
     primitive_c->ClearInputOutputQuantParam();
-    auto op_name = cnode->fullname_with_scope();
 
     auto op_type = (schema::PrimitiveType)primitive_c->Type();
     MS_LOG(INFO) << "OpName: " << op_name;
     if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
@@ -682,7 +700,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
     for (size_t i = 1; i < cnode->inputs().size(); i++) {
       auto input_node = cnode->input(i);
       if (!input_node->isa<mindspore::CNode>()) {
-        MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode";
+        MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
         // get dtype
         auto abstractBase = input_node->abstract();
         if (abstractBase == nullptr) {
@@ -696,7 +714,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
         auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
         if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
           MS_LOG(DEBUG) << "this parameter do quant";
-          DoWeightQuant(input_node, primitive_c, false, false);
+          DoWeightQuant(input_node, primitive_c, false);
         } else {
           MS_LOG(DEBUG) << "this parameter no need to do quant";
         }
@@ -727,12 +745,11 @@ STATUS PostTrainingQuantizer::QuantNode() {
       DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
       // do weight quant
       auto weight = cnode->input(2);
-      bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
       bool perchannel = per_channel_;
       if (op_type == PrimitiveType_FullConnection) {
         perchannel = false;
       }
-      DoWeightQuant(weight, primitive_c, perchannel, depthwise);
+      DoWeightQuant(weight, primitive_c, perchannel);
       // do bias quant
       if (cnode->inputs().size() == 4) {
        auto bias = cnode->input(3);
@@ -89,7 +89,7 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
   STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
 
-  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel, bool depthwise);
+  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel);
 
   STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
 };
@@ -109,7 +109,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
   const int minLimit = quant_min;
 
   return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
-    int quant_data = std::round(originData / scale + zeroPoint);
+    auto quant_data = std::round(originData / scale + zeroPoint);
    if (quant_data > maxLimit) {
      quant_data = maxLimit;
    } else if (quant_data < minLimit) {
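A plausible reading of the int-to-auto change above (my inference, not stated in the commit): std::round returns double, and narrowing it to int before the limit checks can overflow when originData / scale lands far outside the int range; keeping the value as a double lets the clamp run first. A small illustration under that assumption:

// Sketch only: shows why clamping before narrowing matters.
#include <cmath>
#include <cstdio>

int main() {
  double scale = 1e-12;  // tiny scale makes the quotient huge
  float originData = 1.0f;
  int zeroPoint = 0;
  const int maxLimit = 127, minLimit = -128;

  // old: int quant_data = std::round(...);  // double -> int overflows before the clamp
  auto quant_data = std::round(originData / scale + zeroPoint);  // stays a double
  if (quant_data > maxLimit) {
    quant_data = maxLimit;
  } else if (quant_data < minLimit) {
    quant_data = minLimit;
  }
  std::printf("clamped: %d\n", static_cast<int>(quant_data));  // prints 127
  return 0;
}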
@@ -120,7 +120,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
 }
 template <typename T>
 STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
-                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
+                   int quant_max, int quant_min, size_t bitNum, bool per_channel) {
   auto dims = weight->tensor_shape();
   if (per_channel) {
     if (dims.size() != 4) {
@@ -145,57 +145,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
   std::vector<T> quant_datas(elem_count);
 
   if (per_channel) {
-    // notice:
-    // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
-    // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
-    if (depth_wise) {
-      // channel at last
-      auto channels = dims[3];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is zero";
-        return RET_ERROR;
-      }
-      size_t one_filter_size = elem_count / channels;
-
-      for (int i = 0; i < channels; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-        // find min and max
-        for (size_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          min = std::min(min, raw_datas[index]);
-          max = std::max(max, raw_datas[index]);
-        }
-        schema::QuantParamT quant_param;
-        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-        if (status != RET_OK) {
-          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-          return status;
-        }
-        quant_params.emplace_back(quant_param);
-        // do quantization
-        for (uint32_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          float raw_data = raw_datas[index];
-          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
-          quant_datas[index] = quant_data;
-        }
-      }
-      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
-      if (ret != EOK) {
-        MS_LOG(ERROR) << "memcpy error: " << ret;
-        return RET_ERROR;
-      }
-      weight->set_tensor_size(elem_count * sizeof(T));
-    } else {
-      // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
-      // channel at first
-      auto channels = dims[0];
-      if (channels == 0) {
+    // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
+    // channel at first
+    auto channels = dims[0];
+    if (channels == 0) {
@@ -242,7 +192,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
       return RET_ERROR;
     }
     weight->set_tensor_size(elem_count * sizeof(T));
-    }
   } else {
     // per layer
     float min = FLT_MAX;
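With the depth_wise branch deleted, QuantFilter handles every per-channel weight through one layout: output channel at dims[0], each filter a contiguous block of elem_count / channels values (the "uniform weight shape" in the commit title). A simplified standalone sketch of that loop, assuming symmetric scales for brevity; the real code derives parameters via CalQuantizationParams and QuantizeData, and the function name here is hypothetical:

// Sketch only: per-channel symmetric int8 quantization, channel-first layout.
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> QuantPerChannel(const std::vector<float> &raw, size_t channels,
                                    std::vector<float> *scales) {
  const int quant_max = 127, quant_min = -127;
  size_t one_filter_size = raw.size() / channels;
  std::vector<int8_t> out(raw.size());
  for (size_t i = 0; i < channels; i++) {
    float min = FLT_MAX, max = -FLT_MAX;
    // find min and max of this channel's filter
    for (size_t j = 0; j < one_filter_size; j++) {
      size_t index = i * one_filter_size + j;  // filters are contiguous blocks
      min = std::min(min, raw[index]);
      max = std::max(max, raw[index]);
    }
    float scale = std::max(std::abs(min), std::abs(max)) / quant_max;
    if (scale == 0.0f) scale = 1.0f;  // all-zero filter: any scale works
    scales->push_back(scale);
    // quantize this channel's filter with its own scale
    for (size_t j = 0; j < one_filter_size; j++) {
      size_t index = i * one_filter_size + j;
      int q = static_cast<int>(std::round(raw[index] / scale));
      out[index] = static_cast<int8_t>(std::clamp(q, quant_min, quant_max));
    }
  }
  return out;
}

The design choice: rather than supporting two index formulas (i + j * channels for channel-last depthwise weights, i * one_filter_size + j for channel-first), the format pass below normalizes the weight layout first so one formula suffices.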
@@ -97,8 +97,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
     std::vector<schema::QuantParamT> quant_params;
     primitive_c->AddInputQuantParam(quant_params);
     auto status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant,
-                          quant_max, quant_min, bitNum, true, false);
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
       return status;
@@ -160,8 +159,8 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
 
     std::vector<schema::QuantParamT> quant_params;
     primitive_c->AddInputQuantParam(quant_params);
-    auto status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant,
-                                      quant_max, quant_min, bitNum, true, false);
+    auto status =
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
       return status;
@@ -40,6 +40,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node
   MS_ASSERT(conv_cnode != nullptr);
   MS_ASSERT(param_value != nullptr);
   switch (quant_type) {
+    case schema::QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW);
       break;
@@ -73,6 +74,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
       }
     }
       break;
+    case QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
       // conv (K x C/group x kH x kW) group = 1
@@ -114,6 +116,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
       }
     }
       break;
+    case QuantType_PostTraining:
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE: {
      // sum up from current ms quant models
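The three hard-code hunks share one pattern: QuantType_PostTraining is added as an extra case label, so post-training-quantized models fall through to the same weight-format hard-coding as weight-quant and non-quant models, instead of getting a branch of their own. Schematically (illustrative enum and function, not the pass's real signature):

// Sketch only: the case-label fallthrough used in the hunks above.
#include <cstdio>

enum QuantType { QuantType_QUANT_NONE, QuantType_WeightQuant, QuantType_PostTraining, QuantType_AwareTraining };

const char *HardCodedFormat(QuantType quant_type) {
  switch (quant_type) {
    case QuantType_PostTraining:  // new label: shares the handling below
    case QuantType_WeightQuant:
    case QuantType_QUANT_NONE:
      return "KCHW";  // e.g. the CAFFE rule; the ONNX and MS rules differ per framework
    default:
      return "unsupported";
  }
}

int main() {
  std::printf("%s\n", HardCodedFormat(QuantType_PostTraining));  // prints KCHW
  return 0;
}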