add parameter enable_encode & modify qat

This commit is contained in:
albert-yan 2022-10-19 11:11:16 +08:00
parent 255023c539
commit a8df67c458
10 changed files with 45 additions and 8 deletions

View File

@ -206,4 +206,5 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/conv_fp32_
mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:mindspore::kernel::TensorListSetItemCPUKernel::Run
mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape
mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant

View File

@ -563,7 +563,7 @@ STATUS AnfTransform::DoSingleGraphQATTransform(const FuncGraphPtr &func_graph,
}
auto transform_uint8_pass = quant::TransformUint8Pass(func_graph);
ret = transform_uint8_pass.Transform();
if (ret != RET_OK) {
if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run dtype transform pass failed.";
return ret;
}

View File

@ -145,6 +145,7 @@ int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::ma
{"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
{"skip_quant_node", common_quant_string_.skip_quant_node},
{"debug_info_save_path", common_quant_string_.debug_info_save_path},
{"enable_encode", common_quant_string_.enable_encode},
};
return SetMapData(map, parse_map, kCommonQuantParam);
}

View File

@ -43,6 +43,7 @@ struct CommonQuantString {
std::string min_quant_weight_channel;
std::string skip_quant_node;
std::string debug_info_save_path;
std::string enable_encode;
};
struct MixedBitWeightQuantString {

View File

@ -84,6 +84,23 @@ int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string,
return RET_OK;
}
int QuantParamParser::ParseEnableEncode(const CommonQuantString &common_quant_string,
quant::CommonQuantParam *common_quant) {
if (!common_quant_string.enable_encode.empty() &&
!ConvertBool(common_quant_string.enable_encode, &common_quant->enable_encode)) {
MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true or false.";
return RET_INPUT_PARAM_INVALID;
}
if (common_quant->quant_type == schema::QuantType_QUANT_WEIGHT &&
(common_quant->bit_num != kQuantBitNumInt8 && common_quant->bit_num != kQuantBitNumInt16)) {
if (!common_quant->enable_encode) {
MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true when parameter bit_num belongs to [0,7] or [9,15].";
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
quant::CommonQuantParam *common_quant) {
MS_ASSERT(common_quant != nullptr);
@ -101,6 +118,12 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
return ret;
}
ret = ParseEnableEncode(common_quant_string, common_quant);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse enable encode failed.";
return ret;
}
ret = ParseFilter(common_quant_string, common_quant);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse filter failed.";

View File

@ -36,6 +36,7 @@ class QuantParamParser {
quant::ActivationQuantizedMethod *activation_quant_method);
static int ParseFilter(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
static int ParseBitNum(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
static int ParseEnableEncode(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
};
} // namespace lite
} // namespace mindspore

View File

@ -35,7 +35,9 @@ int TransformUint8Pass::Transform() {
continue;
}
auto status = DoNodeDTypeTrans(cnode);
if (status != RET_OK) {
if (status == RET_NO_CHANGE) {
return status;
} else if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeDTypeTrans failed, cnode name: " << cnode->fullname_with_scope();
return status;
}
@ -153,8 +155,8 @@ int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) {
CHECK_NULL_RETURN(curr_quant_param_holder);
TypeId cnode_dtype = kTypeUnknown;
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
return RET_ERROR;
MS_LOG(INFO) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
return RET_NO_CHANGE;
}
if (cnode_dtype == kNumberTypeUInt8) {
MS_LOG(INFO) << "cnode dtype kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
@ -169,8 +171,8 @@ int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) {
}
// update output quant param zp
if (curr_quant_param_holder->get_output_quant_params().empty()) {
MS_LOG(ERROR) << "output quant params empty.";
return RET_ERROR;
MS_LOG(INFO) << "output quant params empty.";
return RET_NO_CHANGE;
}
auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0];
for (auto &quant_param : out_quant_params) {
@ -242,6 +244,10 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
MS_LOG(DEBUG) << "dtype not kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
return false;
}
auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
if (curr_quant_param_holder->get_output_quant_params().empty()) {
return false;
}
return true;
}

View File

@ -93,6 +93,7 @@ struct CommonQuantParam {
DebugMode debug_mode = DETAIL;
std::set<std::string> skip_quant_node;
int thread_num = 4;
bool enable_encode = true;
};
struct MixedBitWeightQuantParam {

View File

@ -180,7 +180,8 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
if (compression && !is_mixed_bit_) {
bool is_compression = (compression && !is_mixed_bit_ && enable_encode_);
if (is_compression) {
status = DoCompression(cnode, parameter, idx);
if (status != RET_OK) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " compression failed.";

View File

@ -58,6 +58,7 @@ class WeightQuantizer : public Quantizer {
explicit WeightQuantizer(const std::shared_ptr<ConverterPara> &param, double init_scale = 0) : Quantizer(param) {
bit_num_ = param_->commonQuantParam.bit_num;
enable_encode_ = param_->commonQuantParam.enable_encode;
if (bit_num_ == 0) {
type_id_ = kNumberTypeInt16;
is_mixed_bit_ = true;
@ -133,6 +134,7 @@ class WeightQuantizer : public Quantizer {
std::set<std::string> skip_quant_node_;
std::unique_ptr<QuantStrategy> quant_strategy_;
schema::QuantType quant_type_{schema::QuantType_WeightQuant};
bool enable_encode_{true};
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_