From 8534315dc6a2a1c5ed1c1993afb6a25ce41513a5 Mon Sep 17 00:00:00 2001 From: yeyunpeng2020 Date: Wed, 2 Nov 2022 10:15:04 +0800 Subject: [PATCH] weight quant support ON_THE_FLY --- .../config_parser/config_file_parser.cc | 17 ++++++++++++++++- .../config_parser/config_file_parser.h | 7 +++++++ .../config_parser/quant_param_parser.cc | 13 +++++++++++++ .../config_parser/quant_param_parser.h | 1 + mindspore/lite/tools/converter/converter.cc | 5 +++++ .../tools/converter/cxx_api/converter_para.h | 1 + .../tools/converter/quantizer/quant_params.h | 9 +++++++++ .../converter/quantizer/weight_quantizer.cc | 5 +++-- .../converter/quantizer/weight_quantizer.h | 1 - 9 files changed, 55 insertions(+), 4 deletions(-) diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc index c6a222890c8..f35db68c8fa 100644 --- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc +++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc @@ -25,6 +25,7 @@ namespace lite { namespace { constexpr auto kCommonQuantParam = "common_quant_param"; constexpr auto kFullQuantParam = "full_quant_param"; +constexpr auto kWeightQuantParam = "weight_quant_param"; constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param"; constexpr auto kDataPreprocessParam = "data_preprocess_param"; constexpr auto kRegistry = "registry"; @@ -93,7 +94,12 @@ int ConfigFileParser::ParseConfigParam(std::maperase(kWeightQuantParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ParseWeightQuantString failed."; + return ret; + } for (const auto &config_info : *maps) { ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second); } @@ -226,5 +232,14 @@ int ConfigFileParser::ParseMicroParamString(const std::map> &maps) { + if (maps.find(kWeightQuantParam) != maps.end()) { + const auto &map = maps.at(kWeightQuantParam); + std::map 
parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy}}; + return SetMapData(map, parse_map, kWeightQuantParam); + } + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.h b/mindspore/lite/tools/converter/config_parser/config_file_parser.h index 11eb6c71b8a..146d2530e1b 100644 --- a/mindspore/lite/tools/converter/config_parser/config_file_parser.h +++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.h @@ -53,6 +53,10 @@ struct MixedBitWeightQuantString { std::string max_iterations; }; +struct WeightQuantString { + std::string dequant_strategy; +}; + struct FullQuantString { std::string activation_quant_method; std::string bias_correction; @@ -98,6 +102,7 @@ class ConfigFileParser { CommonQuantString GetCommonQuantString() const { return this->common_quant_string_; } MixedBitWeightQuantString GetMixedBitWeightQuantString() const { return this->mixed_bit_quant_string_; } FullQuantString GetFullQuantString() const { return this->full_quant_string_; } + WeightQuantString GetWeightQuantString() const { return this->weight_quant_string_; } RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; } AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; } MicroParamString GetMicroParamString() { return this->micro_param_string_; } @@ -107,6 +112,7 @@ class ConfigFileParser { int ParseCommonQuantString(const std::map> &maps); int ParseMixedBitQuantString(const std::map> &maps); int ParseFullQuantString(const std::map> &maps); + int ParseWeightQuantString(const std::map> &maps); int ParseRegistryInfoString(const std::map> &maps); int ParseAclOptionCfgString(const std::map> &maps); int SetMapData(const std::map &input_map, @@ -118,6 +124,7 @@ class ConfigFileParser { CommonQuantString common_quant_string_; MixedBitWeightQuantString mixed_bit_quant_string_; FullQuantString full_quant_string_; + 
WeightQuantString weight_quant_string_; RegistryInfoString registry_info_string_; AclOptionCfgString acl_option_cfg_string_; MicroParamString micro_param_string_; diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc index 2cd5e4a7473..3853f9dda4f 100644 --- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc +++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.cc @@ -250,5 +250,18 @@ int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activati return RET_INPUT_PARAM_INVALID; } } + +int QuantParamParser::ParseWeightQuant(const WeightQuantString &weight_quant_string, + quant::WeightQuantParam *weight_quant) { + if (!weight_quant_string.dequant_strategy.empty()) { + if (weight_quant_string.dequant_strategy == "ON_THE_FLY") { + weight_quant->dequant_strategy = quant::ON_THE_FLY; + } else { + MS_LOG(ERROR) << "INPUT ILLEGAL: dequant_strategy must be ON_THE_FLY."; + return RET_INPUT_PARAM_INVALID; + } + } + return RET_OK; +} } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h index b2af98ca663..42e9bbc0607 100644 --- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h +++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h @@ -27,6 +27,7 @@ class QuantParamParser { static int ParseMixedBitWeightQuant(const MixedBitWeightQuantString &mixed_bit_weight_quant_string, quant::MixedBitWeightQuantParam *mixed_bit_weight_quant); static int ParseFullQuant(const FullQuantString &full_quant_string, quant::FullQuantParam *full_quant); + static int ParseWeightQuant(const WeightQuantString &weight_quant_string, quant::WeightQuantParam *weight_quant); private: static int ParseQuantType(const std::string &quant_type_str, schema::QuantType *quant_type); diff --git 
a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 723859bc16f..627b5c71c48 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -384,6 +384,11 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr &param) MS_LOG(ERROR) << "Parse full quant param failed."; return ret; } + ret = lite::QuantParamParser::ParseWeightQuant(config_parser.GetWeightQuantString(), &param->weightQuantParam); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse full quant param failed."; + return ret; + } ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(), &param->mixedBitWeightQuantParam); if (ret != RET_OK) { diff --git a/mindspore/lite/tools/converter/cxx_api/converter_para.h b/mindspore/lite/tools/converter/cxx_api/converter_para.h index ade1e2d63af..451560d10b3 100644 --- a/mindspore/lite/tools/converter/cxx_api/converter_para.h +++ b/mindspore/lite/tools/converter/cxx_api/converter_para.h @@ -66,6 +66,7 @@ struct ConverterPara { lite::quant::CommonQuantParam commonQuantParam; lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam; lite::quant::FullQuantParam fullQuantParam; + lite::quant::WeightQuantParam weightQuantParam; lite::preprocess::DataPreProcessParam dataPreProcessParam; lite::acl::AclModelOptionCfg aclModelOptionCfgParam; lite::micro::MicroParam microParam; diff --git a/mindspore/lite/tools/converter/quantizer/quant_params.h b/mindspore/lite/tools/converter/quantizer/quant_params.h index 6d3a1b2e64e..21fc5181e6c 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_params.h +++ b/mindspore/lite/tools/converter/quantizer/quant_params.h @@ -83,6 +83,11 @@ enum InsertDirection { BACKWARD, }; +enum DequantStrategy { + DEFAULT, // initial phase to dequant + ON_THE_FLY, +}; + struct CommonQuantParam { schema::QuantType quant_type = schema::QuantType_QUANT_NONE; int bit_num = 8; @@ -96,6 +101,10 @@ struct CommonQuantParam { bool 
enable_encode = true; }; +struct WeightQuantParam { + DequantStrategy dequant_strategy = DEFAULT; +}; + struct MixedBitWeightQuantParam { double init_scale = 0.02; bool auto_tune = false; diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 0c4df77a6b9..0dcbeb67a81 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -238,7 +238,8 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa auto tensor_quant_params = quant_param_holder->get_input_quant_params(); MS_CHECK_GT(static_cast(tensor_quant_params.size()), idx - 1, RET_ERROR); auto quant_params = tensor_quant_params.at(idx - 1); - mindspore::TensorCompressionType compress_type = inference_dequant_ ? mindspore::kFSEInfer : mindspore::kFSE; + mindspore::TensorCompressionType compress_type = + param_->weightQuantParam.dequant_strategy == ON_THE_FLY ? 
mindspore::kFSEInfer : mindspore::kFSE; status = fse_encoder.Compress(parameter, quant_params, compress_type); if (status == RET_OK) { quant_param_holder->ClearQuantParams(); @@ -385,7 +386,7 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { MS_LOG(ERROR) << "Weight Quant failed."; return ret; } - if (inference_dequant_) { + if (param_->weightQuantParam.dequant_strategy == ON_THE_FLY) { return InsertQuantDtypeNode(func_graph); } return MarkGraphWeightQuantType(func_graph); diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h index cc4ee50a157..edb4d202901 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -122,7 +122,6 @@ class WeightQuantizer : public Quantizer { private: bool is_auto_tune_{false}; bool is_mixed_bit_{false}; - bool inference_dequant_{false}; bool linear_quant_{true}; size_t bit_num_{8}; double mixed_bit_init_scale_ = 0.02;