forked from mindspore-Ecosystem/mindspore
PerChannel Quantization
parent ec1cf059a7
commit 45268c5289
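This change switches MindSpore Lite's post-training quantizer to per-channel weight quantization by default: instead of one scale/zero-point for a whole filter, each output channel gets its own quantization parameters, which preserves precision when channel magnitudes differ widely. The sketch below only illustrates the idea (symmetric int8, one scale per output channel); the real work in this patch is done by QuantFilter, whose exact scheme is not shown in this diff.

// Minimal sketch of per-channel weight quantization: one scale per output
// channel instead of a single scale for the whole filter. Illustrative only.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// weights[c] holds the float weights of output channel c.
std::vector<double> QuantizeWeightsPerChannel(const std::vector<std::vector<float>> &weights,
                                              std::vector<std::vector<int8_t>> *quantized,
                                              int32_t quant_max = 127) {
  std::vector<double> scales;
  quantized->resize(weights.size());
  for (size_t c = 0; c < weights.size(); ++c) {
    float abs_max = 0.0f;
    for (float w : weights[c]) abs_max = std::max(abs_max, std::fabs(w));
    double scale = abs_max > 0 ? abs_max / quant_max : 1e-8;  // symmetric: zero-point is 0
    scales.push_back(scale);
    (*quantized)[c].reserve(weights[c].size());
    for (float w : weights[c]) {
      auto q = static_cast<int32_t>(std::round(w / scale));
      if (q > quant_max) q = quant_max;
      if (q < -quant_max) q = -quant_max;
      (*quantized)[c].push_back(static_cast<int8_t>(q));
    }
  }
  return scales;
}

int main() {
  // A channel with tiny weights keeps full int8 resolution instead of being
  // crushed by the large-magnitude channel next to it.
  std::vector<std::vector<float>> weights = {{0.01f, -0.02f, 0.015f}, {1.5f, -2.0f, 0.7f}};
  std::vector<std::vector<int8_t>> q;
  auto scales = QuantizeWeightsPerChannel(weights, &q);
  for (size_t c = 0; c < scales.size(); ++c) {
    std::cout << "channel " << c << " scale=" << scales[c] << "\n";
  }
  return 0;
}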
@@ -134,7 +134,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
     for (auto input_quant_param : input_quant_params[i]) {
       std::unique_ptr<schema::QuantParamT> input_quant_param_ptr =
         std::make_unique<schema::QuantParamT>(input_quant_param);
-      MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
+      MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale
                     << " zp: " << input_quant_param_ptr->zeroPoint;
       tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr));
     }
@@ -152,7 +152,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
     if (tensor_output->quantParams.empty()) {
       std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
         std::make_unique<schema::QuantParamT>(output_quant_param);
-      MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
+      MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
                     << " zp: " << output_quant_param_ptr->zeroPoint;
       tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr));
     }
@@ -536,7 +536,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
 }

 STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value,
-                                            bool depthwise) {
+                                            bool perchanel, bool depthwise) {
   // const vector<int> dims = filter->dims;
   // perlayer
   if (!weight->isa<Parameter>()) {
@@ -544,9 +544,17 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
     return RET_PARAM_INVALID;
   }
   auto parameter = std::dynamic_pointer_cast<Parameter>(weight);
+  if (parameter == nullptr) {
+    MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter";
+    return RET_ERROR;
+  }
   ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
+  if (paramValue == nullptr) {
+    MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
+    return RET_ERROR;
+  }
   auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
-                            per_channel_, depthwise);
+                            perchanel, depthwise);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -690,11 +698,29 @@ STATUS PostTrainingQuantizer::QuantNode() {
     auto op_name = cnode->fullname_with_scope();
     auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
     MS_LOG(INFO) << "OpName: " << op_name;
-    if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) {
+    if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
+        op_type != PrimitiveType_FullConnection) {
       for (size_t i = 1; i < cnode->inputs().size(); i++) {
         auto input_node = cnode->input(i);
         if (!input_node->isa<mindspore::CNode>()) {
-          MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode";
+          MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode";
+          // get dtype
+          auto abstractBase = input_node->abstract();
+          if (abstractBase == nullptr) {
+            MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
+            return RET_ERROR;
+          }
+          if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
+            MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << input_node->fullname_with_scope();
+            return RET_ERROR;
+          }
+          auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
+          if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
+            MS_LOG(DEBUG) << "this parameter do quant";
+            DoWeightQuant(input_node, primitiveT_value, false, false);
+          } else {
+            MS_LOG(DEBUG) << "this parameter no need to do quant";
+          }
           continue;
         }
         auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
@@ -704,8 +730,15 @@ STATUS PostTrainingQuantizer::QuantNode() {
                           << " PrimitiveTValue is null";
           continue;
         }
-        for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
-          primitiveT_value->AddInputQuantParam(quant_param);
+        if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) {
+          for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
+            primitiveT_value->AddInputQuantParam(quant_param);
+          }
+        } else {
+          // do input quant
+          double scale = input_scale[cnode];
+          int32_t zp = input_zero_point[cnode];
+          DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value);
         }
       }
     } else {
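In the new else branch above, a node whose upstream producer carries no output quant params falls back to the calibration statistics collected for its own input (input_scale, input_zero_point, input_min_max). As a reference for how activation parameters are commonly derived from a calibrated (min, max) range, here is a minimal sketch of asymmetric affine quantization; it is an illustration, not necessarily the exact formula used by this quantizer's calibrator.

// Sketch: derive an activation scale/zero-point from a calibrated float range.
// Standard asymmetric affine quantization; illustrative, not the patch's code.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

struct ActivationQuantParam {
  double scale;
  int32_t zero_point;
};

ActivationQuantParam CalcQuantParam(float min_v, float max_v,
                                    int32_t quant_min = -128, int32_t quant_max = 127) {
  // The representable range must include zero so that padding/ReLU zeros map exactly.
  min_v = std::min(min_v, 0.0f);
  max_v = std::max(max_v, 0.0f);
  double scale = (static_cast<double>(max_v) - min_v) / (quant_max - quant_min);
  if (scale == 0) scale = 1e-8;  // degenerate range: all calibration values identical
  auto zp = static_cast<int32_t>(std::round(quant_min - min_v / scale));
  zp = std::clamp(zp, quant_min, quant_max);
  return {scale, zp};
}

int main() {
  auto p = CalcQuantParam(-0.5f, 6.0f);
  std::cout << "scale=" << p.scale << " zp=" << p.zero_point << "\n";
  return 0;
}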
@@ -715,8 +748,12 @@ STATUS PostTrainingQuantizer::QuantNode() {
       DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
       // do weight quant
       auto weight = cnode->input(2);
-      bool depthwise = op_type == PrimitiveType_DeDepthwiseConv2D;
-      DoWeightQuant(weight, primitiveT_value, depthwise);
+      bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
+      bool perchannel = per_channel_;
+      if (op_type == PrimitiveType_FullConnection) {
+        perchannel = false;
+      }
+      DoWeightQuant(weight, primitiveT_value, perchannel, depthwise);
       // do bias quant
       if (cnode->inputs().size() == 4) {
         auto bias = cnode->input(3);
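Two details in this hunk: per-channel is explicitly turned off for FullConnection weights, and the depthwise flag now keys off PrimitiveType_DepthwiseConv2D (the old check against PrimitiveType_DeDepthwiseConv2D looks like it could not match a forward depthwise conv). The flag matters in per-channel mode because the "channel" axis of a depthwise filter differs from that of a regular conv filter; the sketch below only illustrates that idea under assumed layouts, and the real axis handling lives in QuantFilter, outside this diff.

// Illustrative only: pick the per-channel quantization axis from an assumed
// filter layout. Real layouts in MindSpore Lite may differ.
#include <cstddef>
#include <iostream>
#include <vector>

// Regular conv filter assumed as [out_channels, kh, kw, in_channels]: quantize
// along axis 0. Depthwise filter assumed as [multiplier, kh, kw, channels]:
// quantize along the last axis, since each input channel has its own kernel.
size_t PerChannelAxis(const std::vector<int> &filter_dims, bool depthwise) {
  return depthwise ? filter_dims.size() - 1 : 0;
}

int main() {
  std::cout << PerChannelAxis({32, 3, 3, 16}, false) << "\n";  // prints 0
  std::cout << PerChannelAxis({1, 3, 3, 16}, true) << "\n";    // prints 3
  return 0;
}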
@@ -60,7 +60,7 @@ struct ConfigParam {
 class PostTrainingQuantizer : public Quantizer {
  public:
   PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
-                        bool per_channel = false);
+                        bool per_channel = true);

   STATUS DoQuantize(FuncGraphPtr funcGraph) override;

@@ -96,7 +96,8 @@ class PostTrainingQuantizer : public Quantizer {
   STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);
   STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveTValue>);

-  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, bool depthwise);
+  STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value, bool perchannel,
+                       bool depthwise);

   STATUS DoBiasQuant(AnfNodePtr bias, std::shared_ptr<PrimitiveTValue> primitiveT_value);
 };
@@ -100,7 +100,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
       schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D,
       schema::PrimitiveType_Add, schema::PrimitiveType_Pooling,
       schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/
-      schema::PrimitiveType_Reshape, /*schema::PrimitiveType_FullConnection,*/
+      schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection,
       schema::PrimitiveType_MatMul,
       schema::PrimitiveType_Activation};
   return IsContain(uint8OpList, type);