forked from mindspore-Ecosystem/mindspore
!6390 post training quantization: scale restrict and uniform weight shape
Merge pull request !6390 from xutianchun/quant_09
This commit is contained in: commit 88ded11f59
@@ -516,8 +516,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
   return RET_OK;
 }
 
-STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchanel,
-                                            bool depthwise) {
+STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c,
+                                            bool perchanel) {
   // perlayer
   if (!weight->isa<Parameter>()) {
     MS_LOG(ERROR) << "not a parameter";
@@ -534,7 +534,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
     return RET_ERROR;
   }
   auto status = QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num,
-                                    perchanel, depthwise);
+                                    perchanel);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -608,7 +608,6 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
     quant_param.inited = true;
     quant_params.emplace_back(quant_param);
   }
-  primitive_c->AddInputQuantParam(quant_params);
   // quant bias data
   int32_t *quant_datas = new (std::nothrow) int32_t[shape_size];
   if (quant_datas == nullptr) {
@@ -617,15 +616,35 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
   }
   float *raw_datas = static_cast<float *>(bias_param->tensor_addr());
   double bias_scale_tmp;
+  constexpr int32_t quanted_bias_abs_limit = 0.5 * INT32_MAX;
   for (size_t i = 0; i < shape_size; i++) {
     if (bias_scales.size() == 1) {
       bias_scale_tmp = bias_scales[0];
     } else {
       bias_scale_tmp = bias_scales[i];
     }
+    if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
+      MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << active_weight_quant_params[1][i].scale
+                    << " is too small, need to update";
+      // update filter scale and zp
+      if (input_scales.size() == 1 && active_weight_quant_params[1].size() == shape_size) {
+        double activate_scale = input_scales[0];
+        double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
+        active_weight_quant_params[1][i].scale = filter_scale;
+        active_weight_quant_params[1][i].zeroPoint = 0;
+        primitive_c->SetInputQuantParam(active_weight_quant_params);
+        bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
+        quant_params[i].scale = bias_scale_tmp;
+        MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
+      } else {
+        MS_LOG(WARNING) << "unexpected input_scales size: " << input_scales.size() << " weight_scales size: "
+                        << active_weight_quant_params[1].size();
+      }
+    }
     auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
     quant_datas[i] = quant_data;
   }
+  primitive_c->AddInputQuantParam(quant_params);
   auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
   if (ret != EOK) {
     MS_LOG(ERROR) << "memcpy_s failed.";
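
Note: the new DoBiasQuant branch above is the "scale restrict" part of this change. Bias values are quantized to int32 with scale = input_scale * filter_scale, so an unusually small per-channel filter scale can push round(bias / scale) past the int32 range; the branch widens that filter scale just enough to keep the quantized bias within half of INT32_MAX. A standalone sketch of the arithmetic, with made-up values and names (not code from this patch):

#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  constexpr double kQuantedBiasAbsLimit = 0.5 * INT32_MAX;
  double input_scale = 0.02;    // hypothetical activation scale
  double filter_scale = 1e-12;  // pathologically small weight scale
  double raw_bias = 3.5;

  double bias_scale = input_scale * filter_scale;
  if (std::abs(raw_bias / bias_scale) >= kQuantedBiasAbsLimit) {
    // widen the filter scale just enough that the quantized bias fits
    filter_scale = std::abs(raw_bias) / (input_scale * kQuantedBiasAbsLimit);
    bias_scale = std::abs(raw_bias) / kQuantedBiasAbsLimit;
  }
  auto quant_bias = static_cast<int32_t>(std::round(raw_bias / bias_scale));
  std::cout << "new filter scale: " << filter_scale
            << ", quantized bias: " << quant_bias << std::endl;
  return 0;
}

With these inputs the quantized bias lands at roughly INT32_MAX / 2 instead of overflowing, which matches the limit the patch enforces.
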
@@ -659,9 +678,9 @@ STATUS PostTrainingQuantizer::QuantNode() {
 
   auto cnodes = funcGraph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
-    auto cnode_name = cnode->fullname_with_scope();
-    if (this->calibrator_->GetInputDivergInfo()->find(cnode_name) == this->calibrator_->GetInputDivergInfo()->end()) {
-      MS_LOG(INFO) << cnode_name << " can not do quant";
+    auto op_name = cnode->fullname_with_scope();
+    if (this->calibrator_->GetInputDivergInfo()->find(op_name) == this->calibrator_->GetInputDivergInfo()->end()) {
+      MS_LOG(INFO) << op_name << " can not do quant";
       continue;
     }
     auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@@ -673,8 +692,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
       primitive_c->SetQuantType(schema::QuantType_QUANT_NONE);
       continue;
     }
-    primitive_c->ClearInputOutputQuantParam();
-    auto op_name = cnode->fullname_with_scope();
     auto op_type = (schema::PrimitiveType)primitive_c->Type();
     MS_LOG(INFO) << "OpName: " << op_name;
     if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
@@ -682,7 +700,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
       for (size_t i = 1; i < cnode->inputs().size(); i++) {
         auto input_node = cnode->input(i);
         if (!input_node->isa<mindspore::CNode>()) {
-          MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode";
+          MS_LOG(DEBUG) << "node: " << op_name << " input " << i << " not a cnode";
           // get dtype
           auto abstractBase = input_node->abstract();
           if (abstractBase == nullptr) {
@@ -696,7 +714,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
           auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
           if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
             MS_LOG(DEBUG) << "this parameter do quant";
-            DoWeightQuant(input_node, primitive_c, false, false);
+            DoWeightQuant(input_node, primitive_c, false);
           } else {
             MS_LOG(DEBUG) << "this parameter no need to do quant";
           }
@@ -727,12 +745,11 @@ STATUS PostTrainingQuantizer::QuantNode() {
       DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
       // do weight quant
       auto weight = cnode->input(2);
-      bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
       bool perchannel = per_channel_;
       if (op_type == PrimitiveType_FullConnection) {
         perchannel = false;
       }
-      DoWeightQuant(weight, primitive_c, perchannel, depthwise);
+      DoWeightQuant(weight, primitive_c, perchannel);
       // do bias quant
       if (cnode->inputs().size() == 4) {
         auto bias = cnode->input(3);
@@ -89,7 +89,7 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
   STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
 
-  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel, bool depthwise);
+  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel);
 
   STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveC> primitive_c);
 };
@@ -109,7 +109,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
   const int minLimit = quant_min;
 
   return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
-    int quant_data = std::round(originData / scale + zeroPoint);
+    auto quant_data = std::round(originData / scale + zeroPoint);
     if (quant_data > maxLimit) {
       quant_data = maxLimit;
     } else if (quant_data < minLimit) {
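
Note: the commit does not state the rationale for int -> auto here, but a plausible reading is that std::round returns a floating-point value, so with auto the clamp against maxLimit/minLimit runs before any narrowing conversion; with the old int declaration an out-of-range result overflowed in the conversion before it could be clamped. A tiny illustration under that assumption (hypothetical values):

#include <cmath>
#include <iostream>

int main() {
  double scale = 1e-12;  // extreme scale to force an out-of-range result
  float origin = 1.0f;
  const int zero_point = 0, max_limit = 127;

  auto quant_data = std::round(origin / scale + zero_point);  // double, ~1e12
  if (quant_data > max_limit) {
    quant_data = max_limit;  // clamp happens in double, before narrowing
  }
  std::cout << static_cast<int>(quant_data) << std::endl;  // prints 127
  return 0;
}
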
@@ -120,7 +120,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
 }
 template <typename T>
 STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
-                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
+                   int quant_max, int quant_min, size_t bitNum, bool per_channel) {
   auto dims = weight->tensor_shape();
   if (per_channel) {
     if (dims.size() != 4) {
@@ -145,57 +145,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
   std::vector<T> quant_datas(elem_count);
 
   if (per_channel) {
-    // notice:
-    // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
-    // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
-    if (depth_wise) {
-      // channel at last
-      auto channels = dims[3];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is zero";
-        return RET_ERROR;
-      }
-      size_t one_filter_size = elem_count / channels;
-
-      for (int i = 0; i < channels; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-        // find min and max
-        for (size_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          min = std::min(min, raw_datas[index]);
-          max = std::max(max, raw_datas[index]);
-        }
-        schema::QuantParamT quant_param;
-        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-        if (status != RET_OK) {
-          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-          return status;
-        }
-        quant_params.emplace_back(quant_param);
-        // do quantization
-        for (uint32_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          float raw_data = raw_datas[index];
-          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
-          quant_datas[index] = quant_data;
-        }
-      }
-      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
-      if (ret != EOK) {
-        MS_LOG(ERROR) << "memcpy error: " << ret;
-        return RET_ERROR;
-      }
-      weight->set_tensor_size(elem_count * sizeof(T));
-    } else {
+    // notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
     // channel at first
     auto channels = dims[0];
     if (channels == 0) {
@@ -242,7 +192,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
       return RET_ERROR;
     }
     weight->set_tensor_size(elem_count * sizeof(T));
-    }
   } else {
     // per layer
     float min = FLT_MAX;
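
Note: this is the "uniform weight shape" part of the change. With Conv2D and DepthwiseConv2D weights now assumed to share the KHWC layout, QuantFilter no longer needs a depth_wise flag or the deleted channel-last indexing branch; per-channel parameters are always computed along the leading (output-channel) dimension. A minimal sketch of that surviving channel-first pass, with hypothetical names and simplified affine parameters standing in for CalQuantizationParams:

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> QuantPerChannelFirst(const std::vector<float> &raw, size_t channels,
                                         int quant_max = 127, int quant_min = -128) {
  std::vector<int8_t> out(raw.size());
  size_t one_filter_size = raw.size() / channels;  // elements per output channel
  for (size_t c = 0; c < channels; ++c) {
    float min = FLT_MAX;
    float max = -FLT_MAX;
    // find min and max over this channel's contiguous block
    for (size_t j = 0; j < one_filter_size; ++j) {
      size_t index = c * one_filter_size + j;
      min = std::min(min, raw[index]);
      max = std::max(max, raw[index]);
    }
    double scale = (max > min) ? (max - min) / (quant_max - quant_min) : 1.0;
    int zp = quant_min - static_cast<int>(std::round(min / scale));
    // quantize and clamp each element of the channel
    for (size_t j = 0; j < one_filter_size; ++j) {
      size_t index = c * one_filter_size + j;
      int q = static_cast<int>(std::round(raw[index] / scale)) + zp;
      out[index] = static_cast<int8_t>(std::min(quant_max, std::max(quant_min, q)));
    }
  }
  return out;
}

Because the channel dimension leads, each channel's elements are contiguous (index = c * one_filter_size + j), whereas the removed depthwise branch had to stride with index = i + j * channels.
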
@@ -97,8 +97,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
     std::vector<schema::QuantParamT> quant_params;
     primitive_c->AddInputQuantParam(quant_params);
     auto status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant,
-                          quant_max, quant_min, bitNum, true, false);
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
       return status;
@@ -160,8 +159,8 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
 
     std::vector<schema::QuantParamT> quant_params;
     primitive_c->AddInputQuantParam(quant_params);
-    auto status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant,
-                                      quant_max, quant_min, bitNum, true, false);
+    auto status =
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
       return status;
@@ -40,6 +40,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node
   MS_ASSERT(conv_cnode != nullptr);
   MS_ASSERT(param_value != nullptr);
   switch (quant_type) {
+    case schema::QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW);
       break;
@@ -73,6 +74,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
         }
       }
       break;
+    case QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
       // conv (K x C/group x kH x kW) group = 1
@@ -114,6 +116,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
         }
      }
       break;
+    case QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
       // sum up from current ms quant models