From 389380e25f034e2a59697e72f53ad607d10b6fb8 Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Wed, 27 Jan 2021 10:43:11 +0800 Subject: [PATCH] huffman code support 1~8 bit && change it to internal interface --- mindspore/lite/src/dequant.cc | 19 +- mindspore/lite/src/dequant.h | 17 +- mindspore/lite/src/huffman_decode.cc | 8 +- mindspore/lite/src/huffman_decode.h | 6 +- mindspore/lite/src/lite_session.cc | 37 +- mindspore/lite/test/run_benchmark_nets.sh | 4 +- .../lite/tools/converter/anf_transform.cc | 31 +- .../lite/tools/converter/anf_transform.h | 2 +- .../lite/tools/converter/converter_flags.cc | 72 ++-- .../lite/tools/converter/converter_flags.h | 17 +- .../converter/quantizer/huffman_encode.cc | 107 +++-- .../converter/quantizer/huffman_encode.h | 13 +- .../converter/quantizer/weight_quantizer.cc | 370 ++++++++---------- .../converter/quantizer/weight_quantizer.h | 20 +- 14 files changed, 379 insertions(+), 344 deletions(-) diff --git a/mindspore/lite/src/dequant.cc b/mindspore/lite/src/dequant.cc index 22064b37e15..1c37aa340db 100644 --- a/mindspore/lite/src/dequant.cc +++ b/mindspore/lite/src/dequant.cc @@ -14,7 +14,10 @@ * limitations under the License. */ #include +#include +#include #include "src/dequant.h" +#include "src/huffman_decode.h" namespace mindspore::lite { float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { @@ -34,13 +37,24 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { } } -void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { +int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { MS_ASSERT(input_tensor != nullptr); MS_ASSERT(unpack_int_data != nullptr); auto quant_params = input_tensor->quantParams(); if (quant_params == nullptr) { MS_LOG(ERROR) << "low bits quantparams is empty."; - return; + return RET_ERROR; + } + auto enable_huffman_code = input_tensor->enableHuffmanCode(); + if (enable_huffman_code) { + std::string encode_str(input_tensor->data()->begin(), input_tensor->data()->end()); + auto huffman_decode = std::make_unique(); + auto ret = huffman_decode->DoHuffmanDecode(encode_str, unpack_int_data); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoHuffmanDecode failed."; + return ret; + } + return RET_OK; } int origin_bit = quant_params->Get(0)->numBits(); if (origin_bit < 8 && origin_bit > 0) { @@ -48,6 +62,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i } else if (origin_bit < 16 && origin_bit > 8) { UnPackUtil(input_tensor, origin_bit, unpack_int_data); } + return RET_OK; } std::map> DequantUtil::DequantTensor(const std::vector &in_tensors, diff --git a/mindspore/lite/src/dequant.h b/mindspore/lite/src/dequant.h index 094b8468ef6..5191a769bd1 100644 --- a/mindspore/lite/src/dequant.h +++ b/mindspore/lite/src/dequant.h @@ -31,7 +31,7 @@ class DequantUtil { public: static float *DequantWeight(lite::Tensor *input_tensor); - static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); + static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); static std::map> DequantTensor(const std::vector &in_tensors, TypeId data_type, bool need_restore = true); @@ -110,6 +110,21 @@ class DequantUtil { return dequant_datas; } + template + static void UnpackUtil(const T1 *weight_data, int pack_size, int origin_bit, void *unpack_int_data) { + if (weight_data == nullptr || unpack_int_data == nullptr) { + MS_LOG(ERROR) << "data is nullptr"; + return; + } + std::queue unpack_bit_data; + size_t count = 0; + for (int i = 0; i < pack_size; ++i) { + T2 pack_data = (static_cast(static_cast(weight_data)))[i]; + bool is_last = i == pack_size - 1; + UnPackData(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last); + } + } + private: template static void UnPackData(int origin_bit, const T2 &packed_data, std::queue *unpack_bit_data, void *unpack_int, diff --git a/mindspore/lite/src/huffman_decode.cc b/mindspore/lite/src/huffman_decode.cc index d4176770dfc..8432571f5fd 100644 --- a/mindspore/lite/src/huffman_decode.cc +++ b/mindspore/lite/src/huffman_decode.cc @@ -19,7 +19,7 @@ namespace mindspore { namespace lite { -STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) { +STATUS HuffmanDecode::DoHuffmanDecode(const std::string &input_str, void *decoded_data) { if (decoded_data == nullptr) { MS_LOG(ERROR) << "decoded_data is nullptr."; return RET_ERROR; @@ -64,7 +64,7 @@ STATUS huffman_decode::DoHuffmanDecode(const std::string &input_str, void *decod return RET_OK; } -STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) { +STATUS HuffmanDecode::RebuildHuffmanTree(std::string keys, std::string codes, const HuffmanNodePtr &root) { HuffmanNodePtr cur_node, tmp_node, new_node; auto huffman_keys = Str2Vec(std::move(keys)); @@ -121,7 +121,7 @@ STATUS huffman_decode::RebuildHuffmanTree(std::string keys, std::string codes, c return RET_OK; } -STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) { +STATUS HuffmanDecode::DoHuffmanDecompress(HuffmanNodePtr root, std::string encoded_data, std::string *decoded_str) { HuffmanNodePtr cur_node = root; bool pseudo_eof = false; size_t pos = 0; @@ -157,7 +157,7 @@ STATUS huffman_decode::DoHuffmanDecompress(HuffmanNodePtr root, std::string enco return RET_OK; } -huffman_decode::~huffman_decode() { +HuffmanDecode::~HuffmanDecode() { for (auto &node : this->huffman_nodes_) { delete node; } diff --git a/mindspore/lite/src/huffman_decode.h b/mindspore/lite/src/huffman_decode.h index dec0182d7e9..9f155370828 100644 --- a/mindspore/lite/src/huffman_decode.h +++ b/mindspore/lite/src/huffman_decode.h @@ -38,11 +38,11 @@ struct HuffmanNode { }; using HuffmanNodePtr = HuffmanNode *; -class huffman_decode { +class HuffmanDecode { public: - huffman_decode() = default; + HuffmanDecode() = default; - ~huffman_decode(); + ~HuffmanDecode(); STATUS DoHuffmanDecode(const std::string &input_str, void *decoded_data); diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index dad924459a5..1e0f72d6e58 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -28,7 +28,6 @@ #include "src/kernel_registry.h" #include "src/lite_model.h" #include "src/dequant.h" -#include "src/huffman_decode.h" #if SUPPORT_NPU #include "src/runtime/agent/npu/npu_manager.h" #include "src/runtime/agent/npu/optimizer/npu_pass_manager.h" @@ -96,13 +95,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde int org_size = dst_tensor->Size(); return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16); }; - auto NeedHuffmanDecode = [&src_tensor, &dst_tensor]() -> bool { - auto data_type = src_tensor->dataType(); - auto enable_huffman_code = src_tensor->enableHuffmanCode(); - int pack_size = src_tensor->data()->size(); - int org_size = dst_tensor->Size(); - return (pack_size != org_size) && (data_type == kNumberTypeInt8) && enable_huffman_code; - }; auto src_category = TensorCategory(src_tensor); if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) && src_tensor->data() != nullptr && src_tensor->data()->size() > 0) { @@ -116,21 +108,6 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde return RET_ERROR; } } else { - if (NeedHuffmanDecode()) { - auto dst_data = dst_tensor->MutableData(); - if (dst_data == nullptr) { - MS_LOG(ERROR) << "Data from tensor is nullptr"; - return RET_NULL_PTR; - } - std::string encode_str(src_tensor->data()->begin(), src_tensor->data()->end()); - auto huffman_decode = std::make_unique(); - auto ret = huffman_decode->DoHuffmanDecode(encode_str, dst_data); - if (ret != RET_OK) { - MS_LOG(ERROR) << "DoHuffmanDecode failed."; - return ret; - } - copyed_tensor_idxes_.emplace_back(tensor_index); - } if (WeightTensorNeedCopy(model, tensor_index)) { auto dst_data = dst_tensor->MutableData(); if (dst_data == nullptr) { @@ -138,7 +115,11 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde return RET_NULL_PTR; } if (NeedUnPack()) { - DequantUtil::UnPackToInt(src_tensor, dst_data); + auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data); + if (ret != RET_OK) { + MS_LOG(ERROR) << "unpack to int failed."; + return RET_NULL_PTR; + } } else { memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size()); } @@ -148,9 +129,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde auto dst_data = dst_tensor->MutableData(); if (dst_data == nullptr) { MS_LOG(ERROR) << "Data from tensor is nullptr"; - return RET_NULL_PTR; + return RET_ERROR; + } + auto ret = DequantUtil::UnPackToInt(src_tensor, dst_data); + if (ret != RET_OK) { + MS_LOG(ERROR) << "unpack to int failed."; + return RET_ERROR; } - DequantUtil::UnPackToInt(src_tensor, dst_data); copyed_tensor_idxes_.emplace_back(tensor_index); } else { dst_tensor->set_data(const_cast(src_tensor->data()->data())); diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index 6ca5b82b5c6..fe8c5ad0848 100755 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -227,8 +227,8 @@ function Run_Converter() { fi model_name=`echo ${weight_quant_line_info}|awk -F ' ' '{print $1}'` echo ${model_name} >> "${run_converter_log_file}" - echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true' >> "${run_converter_log_file}" - ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 --enableHuffmanCode=true + echo './converter_lite --fmk=TFLITE --modelFile='${models_path}'/'${model_name}' --outputFile='${ms_models_path}'/'${model_name}'--quantType=WeightQuant --bitNum=8 --quantWeightChannel=0' >> "${run_converter_log_file}" + ./converter_lite --fmk=TFLITE --modelFile=$models_path/${model_name} --outputFile=${ms_models_path}/${model_name}_weightquant --quantType=WeightQuant --bitNum=8 --quantWeightChannel=0 if [ $? = 0 ]; then converter_result='converter weight_quant '${model_name}' pass';echo ${converter_result} >> ${run_converter_result_file} else diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 62547ce35e1..786b2ef5fb1 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -215,26 +215,14 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla const FuncGraphPtr &new_graph) { // quant if (config->quantType == schema::QuantType_PostTraining) { - if (!quant::WeightQuantizer::IsPosNum(config->bitNum)) { - MS_LOG(ERROR) << "bitNum must be valid pos num."; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); - return RET_ERROR; - } - this->mQuantizer = - std::make_unique(new_graph, config->configFile, std::stoi(config->bitNum)); + this->mQuantizer = std::make_unique(new_graph, config->configFile, config->bitNum); if (mQuantizer == nullptr) { MS_LOG(ERROR) << "New PostTrainingQuantizer failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); return RET_ERROR; } } else if (config->quantType == schema::QuantType_WeightQuant) { - if (quant::WeightQuantizer::WeightQuantInputCheck(config) != RET_OK) { - MS_LOG(ERROR) << "weight quant input param error"; - ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); - return RET_ERROR; - } - this->mQuantizer = std::make_unique(new_graph, config->configFile, config->quantWeightSize, - config->quantWeightChannel, config->bitNum); + this->mQuantizer = std::make_unique(new_graph, *config); if (mQuantizer == nullptr) { MS_LOG(ERROR) << "New WeightQuantizer failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED); @@ -253,10 +241,15 @@ int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Fla return RET_OK; } -int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph) { - if (config->quantType == schema::QuantType_WeightQuant && config->bitNum == "8" && config->enableHuffmanCode) { - auto huffman_encode = std::make_unique(); - auto status = huffman_encode->DoHuffmanEncode(new_graph); +int AnfTransform::DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, + bool enableHuffmanCode) { + if (config->quantType == schema::QuantType_WeightQuant && enableHuffmanCode) { + if (config->bitNum < 16 && config->bitNum > 8) { + MS_LOG(WARNING) << "don't support huffman encode when 8 < bitNum < 16 currently."; + return RET_OK; + } + auto huffman_encode = std::make_unique(); + auto status = huffman_encode->DoHuffmanEncode(new_graph, config->bitNum); if (status != RET_OK) { MS_LOG(ERROR) << "Huffman encode failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); @@ -320,7 +313,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap return nullptr; } - status = DoHuffmanEncode(config, new_graph); + status = DoHuffmanEncode(config, new_graph, false); if (status != RET_OK) { MS_LOG(ERROR) << "Do HuffmanCode failed."; return nullptr; diff --git a/mindspore/lite/tools/converter/anf_transform.h b/mindspore/lite/tools/converter/anf_transform.h index 2491cf32b75..4ed86d16dff 100644 --- a/mindspore/lite/tools/converter/anf_transform.h +++ b/mindspore/lite/tools/converter/anf_transform.h @@ -59,7 +59,7 @@ class AnfTransform { int DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config, const FuncGraphPtr &new_graph); - int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph); + int DoHuffmanEncode(const converter::Flags *config, const FuncGraphPtr &new_graph, bool enableHuffmanCode); }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/converter_flags.cc b/mindspore/lite/tools/converter/converter_flags.cc index 940b03968dd..1705729b2f3 100644 --- a/mindspore/lite/tools/converter/converter_flags.cc +++ b/mindspore/lite/tools/converter/converter_flags.cc @@ -38,16 +38,12 @@ Flags::Flags() { "UINT8 | DEFAULT", "DEFAULT"); AddFlag(&Flags::quantTypeIn, "quantType", "Quantization Type. PostTraining | WeightQuant", ""); - AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8"); - AddFlag(&Flags::quantWeightSize, "quantWeightSize", "Weight quantization size threshold", "0"); - AddFlag(&Flags::quantWeightChannel, "quantWeightChannel", "Channel threshold for weight quantization", "16"); + AddFlag(&Flags::bitNumIn, "bitNum", "Weight quantization bitNum", "8"); + AddFlag(&Flags::quantWeightSizeIn, "quantWeightSize", "Weight quantization size threshold", "0"); + AddFlag(&Flags::quantWeightChannelIn, "quantWeightChannel", "Channel threshold for weight quantization", "16"); AddFlag(&Flags::configFile, "configFile", "Configuration for post-training.", ""); - AddFlag(&Flags::enableHuffmanCodeIn, "enableHuffmanCode", - "whether the weight quant model is going to use huffman code." - "true | false", - "false"); AddFlag(&Flags::trainModelIn, "trainModel", - "whether the model is going to be trained on device." + "whether the model is going to be trained on device. " "true | false", "false"); } @@ -107,7 +103,41 @@ int Flags::InitFmk() { return RET_OK; } -int Flags::InitQuantType() { +bool Flags::IsValidNum(const std::string &str, int *num) { + char *ptr; + *num = strtol(str.c_str(), &ptr, 10); + return ptr == (str.c_str() + str.size()); +} + +int Flags::QuantParamInputCheck() { + if (!Flags::IsValidNum(this->quantWeightChannelIn, &this->quantWeightChannel)) { + std::cerr << "quantWeightChannel should be a valid number."; + return RET_INPUT_PARAM_INVALID; + } + if (this->quantWeightChannel < 0) { + std::cerr << "quantWeightChannel should be greater than or equal to zero."; + return RET_INPUT_PARAM_INVALID; + } + if (!Flags::IsValidNum(this->quantWeightSizeIn, &this->quantWeightSize)) { + std::cerr << "quantWeightSize should be a valid number."; + return RET_INPUT_PARAM_INVALID; + } + if (this->quantWeightSize < 0) { + std::cerr << "quantWeightSize should be greater than or equal to zero."; + return RET_INPUT_PARAM_INVALID; + } + if (!Flags::IsValidNum(this->bitNumIn, &this->bitNum)) { + std::cerr << "bitNum should be a valid number."; + return RET_INPUT_PARAM_INVALID; + } + if (this->bitNum <= 0 || this->bitNum > 16) { + std::cerr << "bitNum should be greater than zero and lesser than 16 currently."; + return RET_INPUT_PARAM_INVALID; + } + return RET_OK; +} + +int Flags::InitQuantParam() { if (this->quantTypeIn == "WeightQuant") { this->quantType = QuantType_WeightQuant; } else if (this->quantTypeIn == "PostTraining") { @@ -118,19 +148,9 @@ int Flags::InitQuantType() { std::cerr << "INPUT ILLEGAL: quantType must be WeightQuant|PostTraining"; return RET_INPUT_PARAM_INVALID; } - return RET_OK; -} -int Flags::InitHuffmanCode() { - if (this->enableHuffmanCodeIn == "true") { - this->enableHuffmanCode = true; - } else if (this->enableHuffmanCodeIn == "false") { - this->enableHuffmanCode = false; - } else { - std::cerr << "INPUT ILLEGAL: trainModel must be true|false "; - return RET_INPUT_PARAM_INVALID; - } - return RET_OK; + auto ret = QuantParamInputCheck(); + return ret; } int Flags::InitTrainModel() { @@ -218,15 +238,9 @@ int Flags::Init(int argc, const char **argv) { return RET_INPUT_PARAM_INVALID; } - ret = InitQuantType(); + ret = InitQuantParam(); if (ret != RET_OK) { - std::cerr << "Init quant type failed."; - return RET_INPUT_PARAM_INVALID; - } - - ret = InitHuffmanCode(); - if (ret != RET_OK) { - std::cerr << "Init huffman code failed."; + std::cerr << "Init quant param failed."; return RET_INPUT_PARAM_INVALID; } diff --git a/mindspore/lite/tools/converter/converter_flags.h b/mindspore/lite/tools/converter/converter_flags.h index 214b98b514f..92c0cc2376b 100644 --- a/mindspore/lite/tools/converter/converter_flags.h +++ b/mindspore/lite/tools/converter/converter_flags.h @@ -49,9 +49,11 @@ class Flags : public virtual mindspore::lite::FlagParser { int InitFmk(); - int InitQuantType(); + bool IsValidNum(const std::string &str, int *num); - int InitHuffmanCode(); + int QuantParamInputCheck(); + + int InitQuantParam(); int InitTrainModel(); @@ -76,12 +78,13 @@ class Flags : public virtual mindspore::lite::FlagParser { TypeId inputDataType; TypeId outputDataType; // used for post-trainning-weight - std::string quantWeightSize; - std::string bitNum; + std::string quantWeightSizeIn; + int quantWeightSize; + std::string bitNumIn; + int bitNum; std::string configFile; - std::string quantWeightChannel; - std::string enableHuffmanCodeIn; - bool enableHuffmanCode = false; + std::string quantWeightChannelIn; + int quantWeightChannel; std::string trainModelIn; bool trainModel = false; }; diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc index 4dad3fe68e7..b3b50c8a055 100644 --- a/mindspore/lite/tools/converter/quantizer/huffman_encode.cc +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.cc @@ -18,18 +18,51 @@ #include #include -#include -#include -#include "securec/include/securec.h" -#include "src/param_value_lite.h" +#include "src/dequant.h" namespace mindspore { namespace lite { -STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { +STATUS HuffmanEncode::GetParamValueLitePtr(const std::shared_ptr &input_node, ParamValueLitePtr *param_value) { + if (!input_node->isa()) { + return RET_CONTINUE; + } + auto abstract_base = input_node->abstract(); + if (abstract_base == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + if (!utils::isa(abstract_base)) { + MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + auto abstract_tensor = utils::cast(abstract_base); + if (abstract_tensor->element() == nullptr) { + MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + auto tensor_type = abstract_tensor->element()->GetTypeTrack(); + MS_ASSERT(tensor_type != nullptr); + auto tensor_type_id = tensor_type->type_id(); + if (tensor_type_id != kNumberTypeInt8) { + return RET_CONTINUE; + } + auto param_node = input_node->cast(); + if (param_node == nullptr) { + MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope(); + return RET_ERROR; + } + if (!param_node->has_default()) { + MS_LOG(WARNING) << "param_node don't have default: " << input_node->fullname_with_scope(); + return RET_CONTINUE; + } + *param_value = std::static_pointer_cast(param_node->default_param()); + return RET_OK; +} + +STATUS HuffmanEncode::DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num) { auto cnodes = func_graph->GetOrderedCnodes(); - STATUS status; for (auto &cnode : cnodes) { auto primitive_c = GetValueNode>(cnode->input(0)); if (primitive_c == nullptr) { @@ -41,45 +74,33 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { } for (size_t i = 1; i < cnode->inputs().size(); i++) { auto input_node = cnode->input(i); - if (!input_node->isa()) { + ParamValueLitePtr param_value; + auto status = GetParamValueLitePtr(input_node, ¶m_value); + if (status == RET_CONTINUE) { continue; - } - auto abstract_base = input_node->abstract(); - if (abstract_base == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope(); + } else if (status == RET_ERROR) { + MS_LOG(ERROR) << "Get param value lite ptr failed. " << cnode->fullname_with_scope(); return RET_ERROR; } - if (!utils::isa(abstract_base)) { - MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << input_node->fullname_with_scope(); - return RET_ERROR; - } - auto abstract_tensor = utils::cast(abstract_base); - if (abstract_tensor->element() == nullptr) { - MS_LOG(ERROR) << "abstract tensor element is nullptr, " << input_node->fullname_with_scope(); - return RET_ERROR; - } - auto tensor_type = abstract_tensor->element()->GetTypeTrack(); - MS_ASSERT(tensor_type != nullptr); - auto tensor_type_id = tensor_type->type_id(); - if (tensor_type_id != kNumberTypeInt8) { - continue; - } - auto param_node = input_node->cast(); - if (param_node == nullptr) { - MS_LOG(ERROR) << "parameter node is nullptr, " << input_node->fullname_with_scope(); - return RET_ERROR; - } - if (!param_node->has_default()) { - MS_LOG(WARNING) << "param_node don't have default: " << cnode->fullname_with_scope(); - continue; - } - ParamValueLitePtr param_value = std::static_pointer_cast(param_node->default_param()); size_t elem_count = param_value->tensor_shape_size(); + size_t packed_size = param_value->tensor_size(); auto *raw_datas = static_cast(param_value->tensor_addr()); if (raw_datas == nullptr) { MS_LOG(ERROR) << "rawDatas is nullptr"; return RET_ERROR; } + if (bit_num < 8 && bit_num > 0) { + auto dst_data = new (std::nothrow) int8_t[elem_count]; + if (dst_data == nullptr) { + MS_LOG(ERROR) << "new int8_t[] failed"; + return RET_ERROR; + } + DequantUtil::UnpackUtil(raw_datas, packed_size, bit_num, dst_data); + if (memcpy_s(raw_datas, elem_count, dst_data, elem_count) != EOK) { + MS_LOG(ERROR) << "memcpy_s failed."; + return RET_MEMORY_FAILED; + } + } HuffmanPriorityQueue pq; status = GetHuffmanPriorityQueue(raw_datas, elem_count, &pq); if (status != RET_OK) { @@ -97,12 +118,14 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { return status; } size_t ch_size = huffman_encoded_str_.length(); - if (ch_size < elem_count) { + if (ch_size < packed_size) { auto encode_data = new (std::nothrow) char[ch_size]; if (encode_data == nullptr) { MS_LOG(ERROR) << "new char[] failed."; + delete[] raw_datas; return RET_MEMORY_FAILED; } + delete[] raw_datas; if (memcpy_s(encode_data, ch_size, huffman_encoded_str_.c_str(), ch_size) != EOK) { MS_LOG(ERROR) << "memcpy_s failed."; delete[] encode_data; @@ -118,7 +141,7 @@ STATUS huffman_encode::DoHuffmanEncode(const FuncGraphPtr &func_graph) { return RET_SUCCESS; } -STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) { +STATUS HuffmanEncode::GetHuffmanPriorityQueue(const int8_t *data, const size_t data_size, HuffmanPriorityQueue *pq) { MS_ASSERT(data != nullptr); std::map freq_map; @@ -166,7 +189,7 @@ STATUS huffman_encode::GetHuffmanPriorityQueue(const int8_t *data, const size_t return RET_OK; } -void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) { +void HuffmanEncode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_left_node) { if (is_left_node) { node->code = node->parent->code + "0"; } else { @@ -185,7 +208,7 @@ void huffman_encode::GenerateHuffmanTable(const HuffmanNodePtr node, bool is_lef } } -STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { +STATUS HuffmanEncode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { HuffmanNodePtr root = nullptr; while (!pq->empty()) { @@ -228,7 +251,7 @@ STATUS huffman_encode::BuildHuffmanTree(HuffmanPriorityQueue *pq) { return RET_OK; } -STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) { +STATUS HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t data_size) { unsigned char out_c; string code_str; std::map::iterator iter; @@ -270,7 +293,7 @@ STATUS huffman_encode::DoHuffmanCompress(const int8_t *input_datas, const size_t return RET_OK; } -huffman_encode::~huffman_encode() { +HuffmanEncode::~HuffmanEncode() { for (auto &node : this->huffman_nodes_) { delete node; } diff --git a/mindspore/lite/tools/converter/quantizer/huffman_encode.h b/mindspore/lite/tools/converter/quantizer/huffman_encode.h index f7418d9ba42..ff02e452638 100644 --- a/mindspore/lite/tools/converter/quantizer/huffman_encode.h +++ b/mindspore/lite/tools/converter/quantizer/huffman_encode.h @@ -23,9 +23,12 @@ #include #include #include +#include #include #include "src/common/log_adapter.h" #include "src/ops/primitive_c.h" +#include "securec/include/securec.h" +#include "src/param_value_lite.h" #include "ir/func_graph.h" namespace mindspore { @@ -49,13 +52,15 @@ struct cmp { }; using HuffmanPriorityQueue = std::priority_queue, cmp>; -class huffman_encode { +class HuffmanEncode { public: - huffman_encode() = default; + HuffmanEncode() = default; - ~huffman_encode(); + ~HuffmanEncode(); - STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph); + STATUS GetParamValueLitePtr(const std::shared_ptr &input_node, ParamValueLitePtr *param_value); + + STATUS DoHuffmanEncode(const FuncGraphPtr &func_graph, const int &bit_num); private: std::map huffman_table_; diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 09a037986b3..993250587c9 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -25,52 +25,16 @@ using std::string; using std::vector; namespace mindspore::lite::quant { -bool WeightQuantizer::IsPosNum(const std::string &str) { - for (size_t i = 0; i < str.size(); i++) { - if (str.at(i) < '0' || str.at(i) > '9') { - return false; - } - if (str.at(i) == '0' && i == 0 && str.size() != 1) { - return false; - } - } - return true; -} - -STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { - MS_ASSERT(config != nullptr); - if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) { - MS_LOG(ERROR) << "quantWeightChannel must be valid pos num."; - return RET_ERROR; - } - if (!WeightQuantizer::IsPosNum(config->quantWeightSize)) { - MS_LOG(ERROR) << "quantWeightSize must be valid pos num."; - return RET_ERROR; - } - if (!WeightQuantizer::IsPosNum(config->bitNum)) { - MS_LOG(ERROR) << "bitNum must be valid pos num."; - return RET_ERROR; - } - int bitNum = std::stoi(config->bitNum); - if (bitNum <= 0 || bitNum > 16) { - MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently."; - return RET_ERROR; - } - return RET_OK; -} - WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) { quant_strategy_ = std::make_unique(0, 0); config_param_ = config; } -WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize, - const std::string &convWeightChannelThreshold, const std::string &bitNum) - : Quantizer(graph) { - this->config_file_ = config_file; - auto quantSize = static_cast(std::stoull(weightSize)); - this->bit_num_ = static_cast(std::stoull(bitNum)); - auto convQuantWeightChannelThreshold = static_cast(std::stoull(convWeightChannelThreshold)); +WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(graph) { + this->config_file_ = config.configFile; + auto quantSize = config.quantWeightSize; + this->bit_num_ = config.bitNum; + auto convQuantWeightChannelThreshold = config.quantWeightChannel; quant_strategy_ = std::make_unique(quantSize, convQuantWeightChannelThreshold); quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1)); @@ -222,7 +186,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) { return RET_OK; } -STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { +STATUS WeightQuantizer::DoLstmQuantize(CNodePtr cnode) { MS_ASSERT(cnode != nullptr); auto op_name = cnode->fullname_with_scope(); @@ -233,110 +197,29 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) { MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size(); return RET_ERROR; } - { - auto weight_i = cnode->input(2); - ParameterPtr param_node; - ParamValueLitePtr param_value; - GetLiteParameter(weight_i, ¶m_node, ¶m_value); - if (param_node == nullptr || param_value == nullptr) { - MS_LOG(ERROR) << "GetLiteParameter error"; - return RET_ERROR; - } - if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { - MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; - return RET_OK; - } - if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { - MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " - << quant_strategy_->mWeightSize; - return RET_OK; - } - auto status = RET_ERROR; - if (type_id_ == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, - false, 1); - } else if (type_id_ == kNumberTypeInt16) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, - false, 1); - } + + auto status = ProcessLstmWeightByIndex(cnode, primitive_c, 2); + if (status != RET_OK) { + MS_LOG(ERROR) << "Process lstm weight i failed."; + return RET_ERROR; + } + status = ProcessLstmWeightByIndex(cnode, primitive_c, 3); + if (status != RET_OK) { + MS_LOG(ERROR) << "Process lstm weight h failed."; + return RET_ERROR; + } + if (cnode->inputs().size() > 4) { + status = ProcessLstmWeightByIndex(cnode, primitive_c, 4); if (status != RET_OK) { - MS_LOG(ERROR) << "QuantFilter failed : " << status; - return status; - } - status = SetAbstract(param_value, param_node, primitive_c); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetAbstract failed : " << status; + MS_LOG(ERROR) << "Process lstm bias failed."; return RET_ERROR; } } - { - auto weight_h = cnode->input(3); - ParameterPtr param_node; - ParamValueLitePtr param_value; - GetLiteParameter(weight_h, ¶m_node, ¶m_value); - if (param_node == nullptr || param_value == nullptr) { - MS_LOG(ERROR) << "GetLiteParameter error"; - return RET_ERROR; - } - if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { - MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; - return RET_ERROR; - } - auto status = RET_ERROR; - if (type_id_ == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, - false, 2); - } else if (type_id_ == kNumberTypeInt16) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, - false, 2); - } - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantFilter failed : " << status; - return status; - } - status = SetAbstract(param_value, param_node, primitive_c); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetAbstract failed : " << status; - return RET_ERROR; - } - } - { - if (cnode->inputs().size() > 4) { - auto bias = cnode->input(4); - ParameterPtr param_node; - ParamValueLitePtr param_value; - GetLiteParameter(bias, ¶m_node, ¶m_value); - if (param_node == nullptr || param_value == nullptr) { - MS_LOG(ERROR) << "GetLiteParameter error"; - return RET_ERROR; - } - if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { - MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; - return RET_ERROR; - } - auto status = RET_ERROR; - if (type_id_ == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, - false, 3); - } else if (type_id_ == kNumberTypeInt16) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, - false, 3); - } - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantFilter failed : " << status; - return status; - } - status = SetAbstract(param_value, param_node, primitive_c); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetAbstract failed : " << status; - return RET_ERROR; - } - } - } - return RET_OK; + + return status; } -STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) { +STATUS WeightQuantizer::DoGatherQuantize(CNodePtr cnode) { auto primitive_c = GetValueNode>(cnode->input(0)); MS_ASSERT(primitive_c != nullptr); @@ -375,6 +258,46 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) { return RET_OK; } +STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr &primitive_c, + const int &index) { + auto op_name = cnode->fullname_with_scope(); + auto weight_i = cnode->input(index); + ParameterPtr param_node; + ParamValueLitePtr param_value; + GetLiteParameter(weight_i, ¶m_node, ¶m_value); + if (param_node == nullptr || param_value == nullptr) { + MS_LOG(ERROR) << "GetLiteParameter error"; + return RET_ERROR; + } + if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { + MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant"; + return RET_OK; + } + if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) { + MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < " + << quant_strategy_->mWeightSize; + return RET_OK; + } + auto status = RET_ERROR; + if (type_id_ == kNumberTypeInt8) { + status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, + false, index - 1); + } else if (type_id_ == kNumberTypeInt16) { + status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, + false, index - 1); + } + if (status != RET_OK) { + MS_LOG(ERROR) << "QuantFilter failed : " << status; + return status; + } + status = SetAbstract(param_value, param_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetAbstract failed : " << status; + return RET_ERROR; + } + return RET_OK; +} + constexpr float relative_tolerance = 1e-5; constexpr float abs_tolerance = 1e-4; @@ -510,37 +433,28 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) { return RET_OK; } -STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { - // 0.2 Parse input calib files - auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); - if (status != RET_OK) { - MS_LOG(ERROR) << "CollectCalibInputs fail"; - return RET_ERROR; - } - - MS_LOG(DEBUG) << "run fp32 model"; - status = RunFp32Graph(func_graph); - if (status != RET_OK) { - return RET_ERROR; - } - +STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) { auto cnodes = func_graph->GetOrderedCnodes(); + int status = RET_OK; for (auto &cnode : cnodes) { auto op_type = NodePrimitiveType(cnode); if (op_type == schema::PrimitiveType_Lstm) { - status = DoLstmQuntize(cnode); + status = DoLstmQuantize(cnode); if (status != RET_OK) { - MS_LOG(ERROR) << "DoLstmQuntize error"; + MS_LOG(ERROR) << "DoLstmQuantize error"; return RET_ERROR; } } else if (op_type == schema::PrimitiveType_Gather) { - status = DoGatherQuntize(cnode); + status = DoGatherQuantize(cnode); if (status != RET_OK) { - MS_LOG(ERROR) << "DoGatherQuntize error"; + MS_LOG(ERROR) << "DoGatherQuantize error"; return RET_ERROR; } } } + return status; +} +STATUS WeightQuantizer::CheckImageCnt() { auto image_cnt = images_.at(0).size(); if (!config_param_.input_shapes.empty()) { if (config_param_.input_shapes.size() != image_cnt) { @@ -548,7 +462,62 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { return RET_ERROR; } } + return RET_OK; +} +STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr &input_node, const std::string &op_name, + ParameterPtr *param_node, ParamValueLitePtr *param_value) { + if (!input_node->isa()) { + MS_LOG(WARNING) << op_name << " the second input is not parameter"; + return RET_CONTINUE; + } + *param_node = input_node->cast(); + if (!(*param_node)->has_default()) { + MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; + return RET_CONTINUE; + } + *param_value = std::static_pointer_cast((*param_node)->default_param()); + if (*param_value == nullptr) { + MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; + return RET_CONTINUE; + } + if ((*param_value)->tensor_type() != TypeId::kNumberTypeFloat32) { + MS_LOG(WARNING) << op_name << " the second input type is not float"; + return RET_CONTINUE; + } + return RET_OK; +} +STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, + const ParamValueLitePtr ¶m_value, const std::shared_ptr &primitive_c) { + int status; + type_id_ = TypeId::kNumberTypeInt8; + int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; + int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); + if (type_id_ == TypeId::kNumberTypeInt8) { + status = QuantFilter(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, + bit_num_t, true); + } else if (type_id_ == TypeId::kNumberTypeInt16) { + status = QuantFilter(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t, + bit_num_t, true); + } else { + MS_LOG(ERROR) << "unexpected type_id_: " << type_id_; + return RET_ERROR; + } + if (status != RET_OK) { + MS_LOG(ERROR) << "quant filter failed."; + return RET_ERROR; + } + status = SetAbstract(param_value, param_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetAbstract failed : " << status; + return RET_ERROR; + } + return status; +} +STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) { + auto cnodes = func_graph->GetOrderedCnodes(); + auto image_cnt = images_.at(0).size(); + int status = RET_OK; for (auto iter = cnodes.end(); iter != cnodes.begin();) { auto cnode = *(--iter); auto primitive_c = GetValueNode>(cnode->input(0)); @@ -561,22 +530,10 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type()); if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { auto input_node = cnode->input(2); - if (!input_node->isa()) { - MS_LOG(WARNING) << op_name << " the second input is not parameter"; - continue; - } - auto param_node = input_node->cast(); - if (!param_node->has_default()) { - MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; - continue; - } - auto param_value = std::static_pointer_cast(param_node->default_param()); - if (param_value == nullptr) { - MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; - continue; - } - if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { - MS_LOG(WARNING) << op_name << " the second input type is not float"; + ParameterPtr param_node; + ParamValueLitePtr param_value; + status = GetParamNodeAndValue(input_node, op_name, ¶m_node, ¶m_value); + if (status == RET_CONTINUE) { continue; } // copy origin data in case to recover @@ -591,27 +548,9 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { } // 1. try quant for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { - type_id_ = TypeId::kNumberTypeInt8; - int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; - int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); - - if (type_id_ == TypeId::kNumberTypeInt8) { - status = QuantFilter(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, - quant_min_t, bit_num_t, true); - } else if (type_id_ == TypeId::kNumberTypeInt16) { - status = QuantFilter(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, - quant_min_t, bit_num_t, true); - } else { - MS_LOG(ERROR) << "unexpected type_id_: " << type_id_; - return RET_ERROR; - } + status = TryQuant(bit_num_t, param_node, param_value, primitive_c); if (status != RET_OK) { - MS_LOG(ERROR) << "quant filter fail."; - return RET_ERROR; - } - status = SetAbstract(param_value, param_node, primitive_c); - if (status != RET_OK) { - MS_LOG(ERROR) << "SetAbstract failed : " << status; + MS_LOG(ERROR) << "TryQuant failed."; return RET_ERROR; } // 2. evaluate the quant @@ -679,6 +618,41 @@ STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { free(origin_data); } // if: conv and matmul } // end loop: all cnode + return status; +} + +STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) { + // 0.2 Parse input calib files + auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); + if (status != RET_OK) { + MS_LOG(ERROR) << "CollectCalibInputs failed."; + return RET_ERROR; + } + + status = RunFp32Graph(func_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "RunFp32Graph failed."; + return RET_ERROR; + } + + status = DoMixedQuantize(func_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoMixedQuantize failed."; + return RET_ERROR; + } + + status = CheckImageCnt(); + if (status != RET_OK) { + MS_LOG(ERROR) << "CheckImageCnt failed."; + return RET_ERROR; + } + + status = DoQuantSearch(func_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoQuantSearch failed."; + return RET_ERROR; + } + for (const auto &kv : opname_bit_) { MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second; } @@ -709,15 +683,15 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) { return RET_ERROR; } } else if (op_type == schema::PrimitiveType_Lstm) { - auto status = DoLstmQuntize(cnode); + auto status = DoLstmQuantize(cnode); if (status != RET_OK) { - MS_LOG(ERROR) << "DoLstmQuntize error"; + MS_LOG(ERROR) << "DoLstmQuantize error"; return RET_ERROR; } } else if (op_type == schema::PrimitiveType_Gather) { - auto status = DoGatherQuntize(cnode); + auto status = DoGatherQuantize(cnode); if (status != RET_OK) { - MS_LOG(ERROR) << "DoGatherQuntize error"; + MS_LOG(ERROR) << "DoGatherQuantize error"; return RET_ERROR; } } else { diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h index 0d749b3aaf3..791c3c9bc13 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -36,18 +36,18 @@ namespace mindspore::lite::quant { class WeightQuantizer : public Quantizer { public: - WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize, - const std::string &covWeightChannelThreshold, const std::string &bitNum); + WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config); WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config); ~WeightQuantizer(); STATUS DoQuantize(FuncGraphPtr func_graph) override; STATUS DoConvQuantize(CNodePtr); STATUS DoMulQuantize(CNodePtr); - STATUS DoLstmQuntize(CNodePtr cnode); - STATUS DoGatherQuntize(CNodePtr cnode); - static STATUS WeightQuantInputCheck(const converter::Flags *config); - static bool IsPosNum(const std::string &str); + STATUS DoLstmQuantize(CNodePtr cnode); + STATUS DoGatherQuantize(CNodePtr cnode); + + STATUS ProcessLstmWeightByIndex(const CNodePtr &cnode, const std::shared_ptr &primitive_c, + const int &index); int quant_max_{127}; int quant_min_{-128}; @@ -66,6 +66,14 @@ class WeightQuantizer : public Quantizer { STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr primitive_c); STATUS DoFixedQuant(FuncGraphPtr); STATUS RunFp32Graph(FuncGraphPtr); + + STATUS DoMixedQuantize(const FuncGraphPtr &func_graph); + STATUS CheckImageCnt(); + STATUS GetParamNodeAndValue(const std::shared_ptr &input_node, const std::string &op_name, + ParameterPtr *param_node, ParamValueLitePtr *param_value); + STATUS TryQuant(const int &bit_num_t, const ParameterPtr ¶m_node, const ParamValueLitePtr ¶m_value, + const std::shared_ptr &primitive_c); + STATUS DoQuantSearch(const FuncGraphPtr &func_graph); }; } // namespace mindspore::lite::quant #endif