PerChannel Quantization

This commit is contained in:
xutianchun 2020-08-18 21:24:16 +08:00
parent ec1cf059a7
commit 45268c5289
4 changed files with 51 additions and 13 deletions

View File

@ -134,7 +134,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
for (auto input_quant_param : input_quant_params[i]) {
std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
std::make_unique<schema::QuantParamT>(input_quant_param);
MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
<< " zp: " << input_quant_param_ptr->zeroPoint;
tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
}
@ -152,7 +152,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
if (tensor_output->quantParams.empty()) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(output_quant_param);
MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
<< " zp: " << output_quant_param_ptr->zeroPoint;
tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
}

View File

@ -536,7 +536,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
}
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value,
bool depthwise) {
bool perchanel, bool depthwise) {
// const vector<int> dims = filter->dims;
// perlayer
if (!weight->isa<Parameter>()) {
@ -544,9 +544,17 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
return RET_PARAM_INVALID;
}
auto parameter = std::dynamic_pointer_cast<Parameter>(weight);
if (parameter == nullptr) {
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter";
return RET_ERROR;
}
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
if (paramValue == nullptr) {
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_ERROR;
}
auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
per_channel_, depthwise);
perchanel, depthwise);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
@ -690,11 +698,29 @@ STATUS PostTrainingQuantizer::QuantNode() {
auto op_name = cnode->fullname_with_scope();
auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
MS_LOG(INFO) << "OpName: " << op_name;
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) {
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
op_type != PrimitiveType_FullConnection) {
for (size_t i = 1; i < cnode->inputs().size(); i++) {
auto input_node = cnode->input(i);
if (!input_node->isa<mindspore::CNode>()) {
MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode";
MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode";
// get dtype
auto abstractBase = input_node->abstract();
if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << input_node->fullname_with_scope();
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
MS_LOG(DEBUG) << "this parameter do quant";
DoWeightQuant(input_node, primitiveT_value, false, false);
} else {
MS_LOG(DEBUG) << "this parameter no need to do quant";
}
continue;
}
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
@ -704,9 +730,16 @@ STATUS PostTrainingQuantizer::QuantNode() {
<< " PrimitiveTValue is null";
continue;
}
if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) {
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
primitiveT_value->AddInputQuantParam(quant_param);
}
} else {
// do input quant
double scale = input_scale[cnode];
int32_t zp = input_zero_point[cnode];
DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value);
}
}
} else {
// do input quant
@ -715,8 +748,12 @@ STATUS PostTrainingQuantizer::QuantNode() {
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
// do weight quant
auto weight = cnode->input(2);
bool depthwise = op_type == PrimitiveType_DeDepthwiseConv2D;
DoWeightQuant(weight, primitiveT_value, depthwise);
bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
bool perchannel = per_channel_;
if (op_type == PrimitiveType_FullConnection) {
perchannel = false;
}
DoWeightQuant(weight, primitiveT_value, perchannel, depthwise);
// do bias quant
if (cnode->inputs().size() == 4) {
auto bias = cnode->input(3);

View File

@ -60,7 +60,7 @@ struct ConfigParam {
class PostTrainingQuantizer : public Quantizer {
public:
PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
bool per_channel = false);
bool per_channel = true);
STATUS DoQuantize(FuncGraphPtr funcGraph) override;
@ -96,7 +96,8 @@ class PostTrainingQuantizer : public Quantizer {
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, bool depthwise);
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, bool perchannel,
bool depthwise);
STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveTValue> primitiveT_value);
};

View File

@ -100,7 +100,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
schema::PrimitiveType_Reshape, /*schema::PrimitiveType_FullConnection,*/
schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection,
schema::PrimitiveType_MatMul,
schema::PrimitiveType_Activation};
return IsContain(uint8OpList, type);