diff --git a/.jenkins/check/config/whitelizard.txt b/.jenkins/check/config/whitelizard.txt index ffa4ea83b76..51795ba665c 100644 --- a/.jenkins/check/config/whitelizard.txt +++ b/.jenkins/check/config/whitelizard.txt @@ -206,4 +206,5 @@ mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/experimental/conv_fp32_ mindspore/mindspore/lite/src/litert/kernel/cpu/control/tensorlist_setitem.cc:mindspore::kernel::TensorListSetItemCPUKernel::Run mindspore/mindspore/python/mindspore/ops/_utils/utils.py:get_broadcast_shape mindspore/mindspore/ccsrc/pybind_api/ir/dtype_py.cc:mindspore::RegTyping -mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor \ No newline at end of file +mindspore/mindspore/ccsrc/pybind_api/ir/tensor_py.cc:mindspore::tensor::RegMetaTensor +mindspore/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc:mindspore::lite::quant::WeightQuantizer::LinearQuant diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 94906ee139a..63cee4aaeb3 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -563,7 +563,7 @@ STATUS AnfTransform::DoSingleGraphQATTransform(const FuncGraphPtr &func_graph, } auto transform_uint8_pass = quant::TransformUint8Pass(func_graph); ret = transform_uint8_pass.Transform(); - if (ret != RET_OK) { + if (ret != RET_OK && ret != RET_NO_CHANGE) { MS_LOG(ERROR) << "Run dtype transform pass failed."; return ret; } diff --git a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc index 59a7d3e50a2..c6a222890c8 100644 --- a/mindspore/lite/tools/converter/config_parser/config_file_parser.cc +++ b/mindspore/lite/tools/converter/config_parser/config_file_parser.cc @@ -145,6 +145,7 @@ int ConfigFileParser::ParseCommonQuantString(const std::mapenable_encode)) { + MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true or false."; + return RET_INPUT_PARAM_INVALID; + } + if (common_quant->quant_type == schema::QuantType_QUANT_WEIGHT && + (common_quant->bit_num != kQuantBitNumInt8 && common_quant->bit_num != kQuantBitNumInt16)) { + if (!common_quant->enable_encode) { + MS_LOG(ERROR) << "INPUT ILLEGAL: enable_encode should be true when parameter bit_num belongs to [0,7] or [9,15]."; + return RET_INPUT_PARAM_INVALID; + } + } + return RET_OK; +} + int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant) { MS_ASSERT(common_quant != nullptr); @@ -101,6 +118,12 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str return ret; } + ret = ParseEnableEncode(common_quant_string, common_quant); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse enable encode failed."; + return ret; + } + ret = ParseFilter(common_quant_string, common_quant); if (ret != RET_OK) { MS_LOG(ERROR) << "Parse filter failed."; diff --git a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h index 4f9e38164dd..b2af98ca663 100644 --- a/mindspore/lite/tools/converter/config_parser/quant_param_parser.h +++ b/mindspore/lite/tools/converter/config_parser/quant_param_parser.h @@ -36,6 +36,7 @@ class QuantParamParser { quant::ActivationQuantizedMethod *activation_quant_method); static int ParseFilter(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant); static int ParseBitNum(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant); + static int ParseEnableEncode(const CommonQuantString &common_quant_string, quant::CommonQuantParam *common_quant); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc index 8085a886150..16dd5bbaeaf 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_helper/transform_uint8_pass.cc @@ -35,7 +35,9 @@ int TransformUint8Pass::Transform() { continue; } auto status = DoNodeDTypeTrans(cnode); - if (status != RET_OK) { + if (status == RET_NO_CHANGE) { + return status; + } else if (status != RET_OK) { MS_LOG(ERROR) << "DoNodeDTypeTrans failed, cnode name: " << cnode->fullname_with_scope(); return status; } @@ -153,8 +155,8 @@ int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) { CHECK_NULL_RETURN(curr_quant_param_holder); TypeId cnode_dtype = kTypeUnknown; if (opt::GetDataTypeFromAnfNode(cnode, &cnode_dtype) != RET_OK) { - MS_LOG(ERROR) << "Get data type failed, cnode name: " << cnode->fullname_with_scope(); - return RET_ERROR; + MS_LOG(INFO) << "Get data type failed, cnode name: " << cnode->fullname_with_scope(); + return RET_NO_CHANGE; } if (cnode_dtype == kNumberTypeUInt8) { MS_LOG(INFO) << "cnode dtype kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope(); @@ -169,8 +171,8 @@ int TransformUint8Pass::DoNodeDTypeTrans(const CNodePtr &cnode) { } // update output quant param zp if (curr_quant_param_holder->get_output_quant_params().empty()) { - MS_LOG(ERROR) << "output quant params empty."; - return RET_ERROR; + MS_LOG(INFO) << "output quant params empty."; + return RET_NO_CHANGE; } auto out_quant_params = curr_quant_param_holder->get_output_quant_params()[0]; for (auto &quant_param : out_quant_params) { @@ -242,6 +244,10 @@ bool TransformUint8Pass::CheckNeedDTypeTrans(const CNodePtr &cnode) { MS_LOG(DEBUG) << "dtype not kNumberTypeUInt8, cnode name: " << cnode->fullname_with_scope(); return false; } + auto curr_quant_param_holder = GetCNodeQuantHolder(cnode); + if (curr_quant_param_holder->get_output_quant_params().empty()) { + return false; + } return true; } diff --git a/mindspore/lite/tools/converter/quantizer/quant_params.h b/mindspore/lite/tools/converter/quantizer/quant_params.h index 870b8bd8526..6d3a1b2e64e 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_params.h +++ b/mindspore/lite/tools/converter/quantizer/quant_params.h @@ -93,6 +93,7 @@ struct CommonQuantParam { DebugMode debug_mode = DETAIL; std::set skip_quant_node; int thread_num = 4; + bool enable_encode = true; }; struct MixedBitWeightQuantParam { diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 5a1aca6c62d..88120c9e408 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -180,7 +180,8 @@ int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } - if (compression && !is_mixed_bit_) { + bool is_compression = (compression && !is_mixed_bit_ && enable_encode_); + if (is_compression) { status = DoCompression(cnode, parameter, idx); if (status != RET_OK) { MS_LOG(ERROR) << cnode->fullname_with_scope() << " compression failed."; diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h index a01e5df4560..2ae6f561013 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -58,6 +58,7 @@ class WeightQuantizer : public Quantizer { explicit WeightQuantizer(const std::shared_ptr ¶m, double init_scale = 0) : Quantizer(param) { bit_num_ = param_->commonQuantParam.bit_num; + enable_encode_ = param_->commonQuantParam.enable_encode; if (bit_num_ == 0) { type_id_ = kNumberTypeInt16; is_mixed_bit_ = true; @@ -133,6 +134,7 @@ class WeightQuantizer : public Quantizer { std::set skip_quant_node_; std::unique_ptr quant_strategy_; schema::QuantType quant_type_{schema::QuantType_WeightQuant}; + bool enable_encode_{true}; }; } // namespace mindspore::lite::quant #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_