add parameter enable_encode & modify qat
This commit is contained in:
parent
255023c539
commit
a8df67c458
|
@ -206,4 +206,5 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/conv_fp32_
|
|||
mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:mindspore::kernel::TensorListSetItemCPUKernel::Run
|
||||
mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
|
||||
mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor
|
||||
mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant
|
||||
|
|
|
@ -563,7 +563,7 @@ STATUS AnfTransform::DoSingleGraphQATTransform(const FuncGraphPtr &func_graph,
|
|||
}
|
||||
auto transform_uint8_pass = quant::TransformUint8Pass(func_graph);
|
||||
ret = transform_uint8_pass.Transform();
|
||||
if (ret != RET_OK) {
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "Run dtype transform pass failed.";
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -145,6 +145,7 @@ int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::ma
|
|||
{"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
|
||||
{"skip_quant_node", common_quant_string_.skip_quant_node},
|
||||
{"debug_info_save_path", common_quant_string_.debug_info_save_path},
|
||||
{"enable_encode", common_quant_string_.enable_encode},
|
||||
};
|
||||
return SetMapData(map, parse_map, kCommonQuantParam);
|
||||
}
|
||||
|
|
|
@ -43,6 +43,7 @@ struct CommonQuantString {
|
|||
std::string min_quant_weight_channel;
|
||||
std::string skip_quant_node;
|
||||
std::string debug_info_save_path;
|
||||
std::string enable_encode;
|
||||
};
|
||||
|
||||
struct MixedBitWeightQuantString {
|
||||
|
|
|
@ -84,6 +84,23 @@ int QuantParamParser::ParseBitNum(const CommonQuantString &common_quant_string,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int QuantParamParser::ParseEnableEncode(const CommonQuantString &common_quant_string,
|
||||
quant::CommonQuantParam *common_quant) {
|
||||
if (!common_quant_string.enable_encode.empty() &&
|
||||
!ConvertBool(common_quant_string.enable_encode, &common_quant->enable_encode)) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true or false.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (common_quant->quant_type == schema::QuantType_QUANT_WEIGHT &&
|
||||
(common_quant->bit_num != kQuantBitNumInt8 && common_quant->bit_num != kQuantBitNumInt16)) {
|
||||
if (!common_quant->enable_encode) {
|
||||
MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true when parameter bit_num belongs to [0,7] or [9,15].";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string,
|
||||
quant::CommonQuantParam *common_quant) {
|
||||
MS_ASSERT(common_quant != nullptr);
|
||||
|
@ -101,6 +118,12 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
|
|||
return ret;
|
||||
}
|
||||
|
||||
ret = ParseEnableEncode(common_quant_string, common_quant);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse enable encode failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = ParseFilter(common_quant_string, common_quant);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Parse filter failed.";
|
||||
|
|
|
@ -36,6 +36,7 @@ class QuantParamParser {
|
|||
quant::ActivationQuantizedMethod *activation_quant_method);
|
||||
static int ParseFilter(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
|
||||
static int ParseBitNum(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
|
||||
static int ParseEnableEncode(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant);
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,7 +35,9 @@ int TransformUint8Pass::Transform() {
|
|||
continue;
|
||||
}
|
||||
auto status = DoNodeDTypeTrans(cnode);
|
||||
if (status != RET_OK) {
|
||||
if (status == RET_NO_CHANGE) {
|
||||
return status;
|
||||
} else if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoNodeDTypeTrans failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return status;
|
||||
}
|
||||
|
@ -153,8 +155,8 @@ int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) {
|
|||
CHECK_NULL_RETURN(curr_quant_param_holder);
|
||||
TypeId cnode_dtype = kTypeUnknown;
|
||||
if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_ERROR;
|
||||
MS_LOG(INFO) << "Get data type failed, cnode name: " << cnode->fullname_with_scope();
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
if (cnode_dtype == kNumberTypeUInt8) {
|
||||
MS_LOG(INFO) << "cnode dtype kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
|
||||
|
@ -169,8 +171,8 @@ int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) {
|
|||
}
|
||||
// update output quant param zp
|
||||
if (curr_quant_param_holder->get_output_quant_params().empty()) {
|
||||
MS_LOG(ERROR) << "output quant params empty.";
|
||||
return RET_ERROR;
|
||||
MS_LOG(INFO) << "output quant params empty.";
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0];
|
||||
for (auto &quant_param : out_quant_params) {
|
||||
|
@ -242,6 +244,10 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) {
|
|||
MS_LOG(DEBUG) << "dtype not kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope();
|
||||
return false;
|
||||
}
|
||||
auto curr_quant_param_holder = GetCNodeQuantHolder(cnode);
|
||||
if (curr_quant_param_holder->get_output_quant_params().empty()) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -93,6 +93,7 @@ struct CommonQuantParam {
|
|||
DebugMode debug_mode = DETAIL;
|
||||
std::set<std::string> skip_quant_node;
|
||||
int thread_num = 4;
|
||||
bool enable_encode = true;
|
||||
};
|
||||
|
||||
struct MixedBitWeightQuantParam {
|
||||
|
|
|
@ -180,7 +180,8 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr
|
|||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
if (compression && !is_mixed_bit_) {
|
||||
bool is_compression = (compression && !is_mixed_bit_ && enable_encode_);
|
||||
if (is_compression) {
|
||||
status = DoCompression(cnode, parameter, idx);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << cnode->fullname_with_scope() << " compression failed.";
|
||||
|
|
|
@ -58,6 +58,7 @@ class WeightQuantizer : public Quantizer {
|
|||
|
||||
explicit WeightQuantizer(const std::shared_ptr<ConverterPara> ¶m, double init_scale = 0) : Quantizer(param) {
|
||||
bit_num_ = param_->commonQuantParam.bit_num;
|
||||
enable_encode_ = param_->commonQuantParam.enable_encode;
|
||||
if (bit_num_ == 0) {
|
||||
type_id_ = kNumberTypeInt16;
|
||||
is_mixed_bit_ = true;
|
||||
|
@ -133,6 +134,7 @@ class WeightQuantizer : public Quantizer {
|
|||
std::set<std::string> skip_quant_node_;
|
||||
std::unique_ptr<QuantStrategy> quant_strategy_;
|
||||
schema::QuantType quant_type_{schema::QuantType_WeightQuant};
|
||||
bool enable_encode_{true};
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_
|
||||
|
|
Loading…
Reference in New Issue