weight quant support ON_THE_FLY

This commit is contained in:
yeyunpeng2020 2022-11-02 10:15:04 +08:00
parent b4e6ec78be
commit 8534315dc6
9 changed files with 55 additions and 4 deletions

View File

@ -25,6 +25,7 @@ namespace lite {
namespace {
constexpr auto kCommonQuantParam = "common_quant_param";
constexpr auto kFullQuantParam = "full_quant_param";
constexpr auto kWeightQuantParam = "weight_quant_param";
constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param";
constexpr auto kDataPreprocessParam = "data_preprocess_param";
constexpr auto kRegistry = "registry";
@ -93,7 +94,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin
MS_LOG(ERROR) << "ParseMicroParamString failed.";
return ret;
}
ret = ParseWeightQuantString(*maps);
(void)maps->erase(kWeightQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ParseWeightQuantString failed.";
return ret;
}
for (const auto &config_info : *maps) {
ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
}
@ -226,5 +232,14 @@ int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map
}
return RET_OK;
}
int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
if (maps.find(kWeightQuantParam) != maps.end()) {
const auto &map = maps.at(kWeightQuantParam);
std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy}};
return SetMapData(map, parse_map, kWeightQuantParam);
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -53,6 +53,10 @@ struct MixedBitWeightQuantString {
std::string max_iterations;
};
struct WeightQuantString {
std::string dequant_strategy;
};
struct FullQuantString {
std::string activation_quant_method;
std::string bias_correction;
@ -98,6 +102,7 @@ class ConfigFileParser {
CommonQuantString GetCommonQuantString() const { return this->common_quant_string_; }
MixedBitWeightQuantString GetMixedBitWeightQuantString() const { return this->mixed_bit_quant_string_; }
FullQuantString GetFullQuantString() const { return this->full_quant_string_; }
WeightQuantString GetWeightQuantString() const { return this->weight_quant_string_; }
RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; }
AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; }
MicroParamString GetMicroParamString() { return this->micro_param_string_; }
@ -107,6 +112,7 @@ class ConfigFileParser {
int ParseCommonQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
int ParseMixedBitQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
int ParseFullQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
int ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
int ParseRegistryInfoString(const std::map<std::string, std::map<std::string, std::string>> &maps);
int ParseAclOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps);
int SetMapData(const std::map<std::string, std::string> &input_map,
@ -118,6 +124,7 @@ class ConfigFileParser {
CommonQuantString common_quant_string_;
MixedBitWeightQuantString mixed_bit_quant_string_;
FullQuantString full_quant_string_;
WeightQuantString weight_quant_string_;
RegistryInfoString registry_info_string_;
AclOptionCfgString acl_option_cfg_string_;
MicroParamString micro_param_string_;

View File

@ -250,5 +250,18 @@ int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activati
return RET_INPUT_PARAM_INVALID;
}
}
int QuantParamParser::ParseWeightQuant(const WeightQuantString &weight_quant_string,
quant::WeightQuantParam *weight_quant) {
if (!weight_quant_string.dequant_strategy.empty()) {
if (weight_quant_string.dequant_strategy == "ON_THE_FLY") {
weight_quant->dequant_strategy = quant::ON_THE_FLY;
} else {
MS_LOG(ERROR) << "INPUT ILLEGAL: dequant_strategy must be ON_THE_FLY.";
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

View File

@ -27,6 +27,7 @@ class QuantParamParser {
static int ParseMixedBitWeightQuant(const MixedBitWeightQuantString &mixed_bit_weight_quant_string,
quant::MixedBitWeightQuantParam *mixed_bit_weight_quant);
static int ParseFullQuant(const FullQuantString &full_quant_string, quant::FullQuantParam *full_quant);
static int ParseWeightQuant(const WeightQuantString &weight_quant_string, quant::WeightQuantParam *weight_quant);
private:
static int ParseQuantType(const std::string &quant_type_str, schema::QuantType *quant_type);

View File

@ -384,6 +384,11 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> &param)
MS_LOG(ERROR) << "Parse full quant param failed.";
return ret;
}
ret = lite::QuantParamParser::ParseWeightQuant(config_parser.GetWeightQuantString(), &param->weightQuantParam);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse full quant param failed.";
return ret;
}
ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(),
&param->mixedBitWeightQuantParam);
if (ret != RET_OK) {

View File

@ -66,6 +66,7 @@ struct ConverterPara {
lite::quant::CommonQuantParam commonQuantParam;
lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
lite::quant::FullQuantParam fullQuantParam;
lite::quant::WeightQuantParam weightQuantParam;
lite::preprocess::DataPreProcessParam dataPreProcessParam;
lite::acl::AclModelOptionCfg aclModelOptionCfgParam;
lite::micro::MicroParam microParam;

View File

@ -83,6 +83,11 @@ enum InsertDirection {
BACKWARD,
};
enum DequantStrategy {
DEFAULT, // initial phase to dequant
ON_THE_FLY,
};
struct CommonQuantParam {
schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
int bit_num = 8;
@ -96,6 +101,10 @@ struct CommonQuantParam {
bool enable_encode = true;
};
struct WeightQuantParam {
DequantStrategy dequant_strategy = DEFAULT;
};
struct MixedBitWeightQuantParam {
double init_scale = 0.02;
bool auto_tune = false;

View File

@ -238,7 +238,8 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
auto tensor_quant_params = quant_param_holder->get_input_quant_params();
MS_CHECK_GT(static_cast<int>(tensor_quant_params.size()), idx - 1, RET_ERROR);
auto quant_params = tensor_quant_params.at(idx - 1);
mindspore::TensorCompressionType compress_type = inference_dequant_ ? mindspore::kFSEInfer : mindspore::kFSE;
mindspore::TensorCompressionType compress_type =
param_->weightQuantParam.dequant_strategy == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
status = fse_encoder.Compress(parameter, quant_params, compress_type);
if (status == RET_OK) {
quant_param_holder->ClearQuantParams();
@ -385,7 +386,7 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(ERROR) << "Weight Quant failed.";
return ret;
}
if (inference_dequant_) {
if (param_->weightQuantParam.dequant_strategy == ON_THE_FLY) {
return InsertQuantDtypeNode(func_graph);
}
return MarkGraphWeightQuantType(func_graph);

View File

@ -122,7 +122,6 @@ class WeightQuantizer : public Quantizer {
private:
bool is_auto_tune_{false};
bool is_mixed_bit_{false};
bool inference_dequant_{false};
bool linear_quant_{true};
size_t bit_num_{8};
double mixed_bit_init_scale_ = 0.02;