weight quant support ON_THE_FLY
This commit is contained in:
parent
b4e6ec78be
commit
8534315dc6
|
@ -25,6 +25,7 @@ namespace lite {
|
|||
namespace {
|
||||
constexpr auto kCommonQuantParam = "common_quant_param";
|
||||
constexpr auto kFullQuantParam = "full_quant_param";
|
||||
constexpr auto kWeightQuantParam = "weight_quant_param";
|
||||
constexpr auto kMixedBitWeightQuantParam = "mixed_bit_weight_quant_param";
|
||||
constexpr auto kDataPreprocessParam = "data_preprocess_param";
|
||||
constexpr auto kRegistry = "registry";
|
||||
|
@ -93,7 +94,12 @@ int ConfigFileParser::ParseConfigParam(std::map<std::string, std::map<std::strin
|
|||
MS_LOG(ERROR) << "ParseMicroParamString failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = ParseWeightQuantString(*maps);
|
||||
(void)maps->erase(kWeightQuantParam);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ParseWeightQuantString failed.";
|
||||
return ret;
|
||||
}
|
||||
for (const auto &config_info : *maps) {
|
||||
ConverterInnerContext::GetInstance()->SetExternalUsedConfigInfos(config_info.first, config_info.second);
|
||||
}
|
||||
|
@ -226,5 +232,14 @@ int ConfigFileParser::ParseMicroParamString(const std::map<std::string, std::map
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ConfigFileParser::ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps) {
|
||||
if (maps.find(kWeightQuantParam) != maps.end()) {
|
||||
const auto &map = maps.at(kWeightQuantParam);
|
||||
std::map<std::string, std::string &> parse_map{{"dequant_strategy", weight_quant_string_.dequant_strategy}};
|
||||
return SetMapData(map, parse_map, kWeightQuantParam);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -53,6 +53,10 @@ struct MixedBitWeightQuantString {
|
|||
std::string max_iterations;
|
||||
};
|
||||
|
||||
struct WeightQuantString {
|
||||
std::string dequant_strategy;
|
||||
};
|
||||
|
||||
struct FullQuantString {
|
||||
std::string activation_quant_method;
|
||||
std::string bias_correction;
|
||||
|
@ -98,6 +102,7 @@ class ConfigFileParser {
|
|||
CommonQuantString GetCommonQuantString() const { return this->common_quant_string_; }
|
||||
MixedBitWeightQuantString GetMixedBitWeightQuantString() const { return this->mixed_bit_quant_string_; }
|
||||
FullQuantString GetFullQuantString() const { return this->full_quant_string_; }
|
||||
WeightQuantString GetWeightQuantString() const { return this->weight_quant_string_; }
|
||||
RegistryInfoString GetRegistryInfoString() const { return this->registry_info_string_; }
|
||||
AclOptionCfgString GetAclOptionCfgString() { return this->acl_option_cfg_string_; }
|
||||
MicroParamString GetMicroParamString() { return this->micro_param_string_; }
|
||||
|
@ -107,6 +112,7 @@ class ConfigFileParser {
|
|||
int ParseCommonQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
|
||||
int ParseMixedBitQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
|
||||
int ParseFullQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
|
||||
int ParseWeightQuantString(const std::map<std::string, std::map<std::string, std::string>> &maps);
|
||||
int ParseRegistryInfoString(const std::map<std::string, std::map<std::string, std::string>> &maps);
|
||||
int ParseAclOptionCfgString(const std::map<std::string, std::map<std::string, std::string>> &maps);
|
||||
int SetMapData(const std::map<std::string, std::string> &input_map,
|
||||
|
@ -118,6 +124,7 @@ class ConfigFileParser {
|
|||
CommonQuantString common_quant_string_;
|
||||
MixedBitWeightQuantString mixed_bit_quant_string_;
|
||||
FullQuantString full_quant_string_;
|
||||
WeightQuantString weight_quant_string_;
|
||||
RegistryInfoString registry_info_string_;
|
||||
AclOptionCfgString acl_option_cfg_string_;
|
||||
MicroParamString micro_param_string_;
|
||||
|
|
|
@ -250,5 +250,18 @@ int QuantParamParser::ParseActivationQuantizedMethod(const std::string &activati
|
|||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
int QuantParamParser::ParseWeightQuant(const WeightQuantString &weight_quant_string,
|
||||
quant::WeightQuantParam *weight_quant) {
|
||||
if (!weight_quant_string.dequant_strategy.empty()) {
|
||||
if (weight_quant_string.dequant_strategy == "ON_THE_FLY") {
|
||||
weight_quant->dequant_strategy = quant::ON_THE_FLY;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: dequant_strategy must be ON_THE_FLY.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,7 @@ class QuantParamParser {
|
|||
static int ParseMixedBitWeightQuant(const MixedBitWeightQuantString &mixed_bit_weight_quant_string,
|
||||
quant::MixedBitWeightQuantParam *mixed_bit_weight_quant);
|
||||
static int ParseFullQuant(const FullQuantString &full_quant_string, quant::FullQuantParam *full_quant);
|
||||
static int ParseWeightQuant(const WeightQuantString &weight_quant_string, quant::WeightQuantParam *weight_quant);
|
||||
|
||||
private:
|
||||
static int ParseQuantType(const std::string &quant_type_str, schema::QuantType *quant_type);
|
||||
|
|
|
@ -384,6 +384,11 @@ int ConverterImpl::InitConfigParam(const std::shared_ptr<ConverterPara> ¶m)
|
|||
MS_LOG(ERROR) << "Parse full quant param failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = lite::QuantParamParser::ParseWeightQuant(config_parser.GetWeightQuantString(), ¶m->weightQuantParam);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse full quant param failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = lite::QuantParamParser::ParseMixedBitWeightQuant(config_parser.GetMixedBitWeightQuantString(),
|
||||
¶m->mixedBitWeightQuantParam);
|
||||
if (ret != RET_OK) {
|
||||
|
|
|
@ -66,6 +66,7 @@ struct ConverterPara {
|
|||
lite::quant::CommonQuantParam commonQuantParam;
|
||||
lite::quant::MixedBitWeightQuantParam mixedBitWeightQuantParam;
|
||||
lite::quant::FullQuantParam fullQuantParam;
|
||||
lite::quant::WeightQuantParam weightQuantParam;
|
||||
lite::preprocess::DataPreProcessParam dataPreProcessParam;
|
||||
lite::acl::AclModelOptionCfg aclModelOptionCfgParam;
|
||||
lite::micro::MicroParam microParam;
|
||||
|
|
|
@ -83,6 +83,11 @@ enum InsertDirection {
|
|||
BACKWARD,
|
||||
};
|
||||
|
||||
enum DequantStrategy {
|
||||
DEFAULT, // initial phase to dequant
|
||||
ON_THE_FLY,
|
||||
};
|
||||
|
||||
struct CommonQuantParam {
|
||||
schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
|
||||
int bit_num = 8;
|
||||
|
@ -96,6 +101,10 @@ struct CommonQuantParam {
|
|||
bool enable_encode = true;
|
||||
};
|
||||
|
||||
struct WeightQuantParam {
|
||||
DequantStrategy dequant_strategy = DEFAULT;
|
||||
};
|
||||
|
||||
struct MixedBitWeightQuantParam {
|
||||
double init_scale = 0.02;
|
||||
bool auto_tune = false;
|
||||
|
|
|
@ -238,7 +238,8 @@ int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &pa
|
|||
auto tensor_quant_params = quant_param_holder->get_input_quant_params();
|
||||
MS_CHECK_GT(static_cast<int>(tensor_quant_params.size()), idx - 1, RET_ERROR);
|
||||
auto quant_params = tensor_quant_params.at(idx - 1);
|
||||
mindspore::TensorCompressionType compress_type = inference_dequant_ ? mindspore::kFSEInfer : mindspore::kFSE;
|
||||
mindspore::TensorCompressionType compress_type =
|
||||
param_->weightQuantParam.dequant_strategy == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
|
||||
status = fse_encoder.Compress(parameter, quant_params, compress_type);
|
||||
if (status == RET_OK) {
|
||||
quant_param_holder->ClearQuantParams();
|
||||
|
@ -385,7 +386,7 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
MS_LOG(ERROR) << "Weight Quant failed.";
|
||||
return ret;
|
||||
}
|
||||
if (inference_dequant_) {
|
||||
if (param_->weightQuantParam.dequant_strategy == ON_THE_FLY) {
|
||||
return InsertQuantDtypeNode(func_graph);
|
||||
}
|
||||
return MarkGraphWeightQuantType(func_graph);
|
||||
|
|
|
@ -122,7 +122,6 @@ class WeightQuantizer : public Quantizer {
|
|||
private:
|
||||
bool is_auto_tune_{false};
|
||||
bool is_mixed_bit_{false};
|
||||
bool inference_dequant_{false};
|
||||
bool linear_quant_{true};
|
||||
size_t bit_num_{8};
|
||||
double mixed_bit_init_scale_ = 0.02;
|
||||
|
|
Loading…
Reference in New Issue