PerChannel Quantization

parent ec1cf059a7
commit 45268c5289
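This change enables per-channel quantization of weights in the post-training quantizer: instead of one scale for a whole weight tensor, each output channel gets its own scale, which typically preserves accuracy better when channel magnitudes differ. The sketch below is illustrative only; the layout, names, and symmetric int8 scheme are assumptions, not the converter's actual QuantFilter.

// Minimal sketch, assuming a [channels][elems_per_channel] weight layout and
// symmetric int8 quantization; this is not the converter's QuantFilter.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> QuantizeWeightSketch(const std::vector<float> &weights, int channels,
                                         bool per_channel, std::vector<float> *scales) {
  const int quant_max = 127;
  const int num_scales = per_channel ? channels : 1;
  const int elems_per_channel = static_cast<int>(weights.size()) / channels;
  scales->assign(num_scales, 1.0f);
  // per-channel: one scale per output channel; per-layer: one scale for the whole tensor
  for (int c = 0; c < num_scales; ++c) {
    const size_t begin = per_channel ? c * elems_per_channel : 0;
    const size_t end = per_channel ? begin + elems_per_channel : weights.size();
    float max_abs = 0.0f;
    for (size_t i = begin; i < end; ++i) {
      max_abs = std::max(max_abs, std::fabs(weights[i]));
    }
    (*scales)[c] = max_abs > 0.0f ? max_abs / quant_max : 1.0f;
  }
  std::vector<int8_t> quantized(weights.size());
  for (size_t i = 0; i < weights.size(); ++i) {
    const int c = per_channel ? static_cast<int>(i / elems_per_channel) : 0;
    const int q = static_cast<int>(std::round(weights[i] / (*scales)[c]));
    quantized[i] = static_cast<int8_t>(std::min(127, std::max(-128, q)));
  }
  return quantized;
}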
@@ -134,7 +134,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
     for (auto input_quant_param : input_quant_params[i]) {
       std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
         std::make_unique<schema::QuantParamT>(input_quant_param);
-      MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
+      MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
                     << " zp: " << input_quant_param_ptr->zeroPoint;
       tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
     }
@@ -152,7 +152,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
     if (tensor_output->quantParams.empty()) {
       std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
         std::make_unique<schema::QuantParamT>(output_quant_param);
-      MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
+      MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
                     << " zp: " << output_quant_param_ptr->zeroPoint;
       tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
     }
@@ -536,7 +536,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
 }
 
 STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value,
-                                            bool depthwise) {
+                                            bool perchanel, bool depthwise) {
   // const vector<int> dims = filter->dims;
   // perlayer
   if (!weight->isa<Parameter>()) {
@@ -544,9 +544,17 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
     return RET_PARAM_INVALID;
   }
   auto parameter = std::dynamic_pointer_cast<Parameter>(weight);
+  if (parameter == nullptr) {
+    MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter";
+    return RET_ERROR;
+  }
   ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
+  if (paramValue == nullptr) {
+    MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
+    return RET_ERROR;
+  }
   auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
-                            per_channel_, depthwise);
+                            perchanel, depthwise);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
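For reference, the quant_max/quant_min bounds passed to QuantFilter are the usual signed range derived from bit_num; the relation below is an assumption for illustration and is not taken from this diff.

// Assumed relation between bit_num and the clamp range handed to QuantFilter
// (bit_num = 8 gives [-128, 127]); illustrative only.
const int bit_num = 8;
const int quant_max = (1 << (bit_num - 1)) - 1;  // 127
const int quant_min = -(1 << (bit_num - 1));     // -128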
@@ -690,11 +698,29 @@ STATUS PostTrainingQuantizer::QuantNode() {
     auto op_name = cnode->fullname_with_scope();
     auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
     MS_LOG(INFO) << "OpName: " << op_name;
-    if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) {
+    if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
+        op_type != PrimitiveType_FullConnection) {
       for (size_t i = 1; i < cnode->inputs().size(); i++) {
         auto input_node = cnode->input(i);
         if (!input_node->isa<mindspore::CNode>()) {
-          MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode";
+          MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode";
+          // get dtype
+          auto abstractBase = input_node->abstract();
+          if (abstractBase == nullptr) {
+            MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
+            return RET_ERROR;
+          }
+          if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
+            MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << input_node->fullname_with_scope();
+            return RET_ERROR;
+          }
+          auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
+          if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
+            MS_LOG(DEBUG) << "this parameter do quant";
+            DoWeightQuant(input_node, primitiveT_value, false, false);
+          } else {
+            MS_LOG(DEBUG) << "this parameter no need to do quant";
+          }
           continue;
         }
         auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
@@ -704,9 +730,16 @@ STATUS PostTrainingQuantizer::QuantNode() {
                         << " PrimitiveTValue is null";
           continue;
         }
+        if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) {
           for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
             primitiveT_value->AddInputQuantParam(quant_param);
           }
+        } else {
+          // do input quant
+          double scale = input_scale[cnode];
+          int32_t zp = input_zero_point[cnode];
+          DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value);
+        }
       }
     } else {
       // do input quant
@@ -715,8 +748,12 @@ STATUS PostTrainingQuantizer::QuantNode() {
       DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
       // do weight quant
       auto weight = cnode->input(2);
-      bool depthwise = op_type == PrimitiveType_DeDepthwiseConv2D;
-      DoWeightQuant(weight, primitiveT_value, depthwise);
+      bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
+      bool perchannel = per_channel_;
+      if (op_type == PrimitiveType_FullConnection) {
+        perchannel = false;
+      }
+      DoWeightQuant(weight, primitiveT_value, perchannel, depthwise);
       // do bias quant
       if (cnode->inputs().size() == 4) {
         auto bias = cnode->input(3);
@@ -60,7 +60,7 @@ struct ConfigParam {
 class PostTrainingQuantizer : public Quantizer {
  public:
   PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
-                        bool per_channel = false);
+                        bool per_channel = true);
 
   STATUS DoQuantize(FuncGraphPtr funcGraph) override;
 
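With the default flipped to true, callers that construct the quantizer without the last argument now get per-channel weight quantization. An illustrative call site, where func_graph and config_path are placeholders:

// Per the new default, per_channel no longer needs to be passed explicitly;
// pass false as the fifth argument to keep the previous per-layer behaviour.
auto quantizer = std::make_unique<PostTrainingQuantizer>(func_graph, config_path, 8, kNumberTypeInt8);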
@@ -96,7 +96,8 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);
   STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);
 
-  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, bool depthwise);
+  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, bool perchannel,
+                       bool depthwise);
 
   STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveTValue> primitiveT_value);
 };
@@ -100,7 +100,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
                                          schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
                                          schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
                                          schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
-                                         schema::PrimitiveType_Reshape, /*schema::PrimitiveType_FullConnection,*/
+                                         schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection,
                                          schema::PrimitiveType_MatMul,
                                          schema::PrimitiveType_Activation};
   return IsContain(uint8OpList, type);