From b5f8556d63630f3f9e05e445cb9b2187fe51f770 Mon Sep 17 00:00:00 2001 From: xutianchun Date: Tue, 25 Aug 2020 18:12:27 +0800 Subject: [PATCH] post quantization code review --- mindspore/lite/tools/converter/converter.cc | 30 ++----- mindspore/lite/tools/converter/converter.h | 2 +- .../quantizer/post_training_quantizer.cc | 83 ++++++++++--------- .../quantizer/post_training_quantizer.h | 18 ++-- .../converter/quantizer/quantize_util.cc | 12 +-- .../tools/converter/quantizer/quantizer.h | 2 +- 6 files changed, 58 insertions(+), 89 deletions(-) diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 75c24ca7fcf..da6ba860cd1 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -15,9 +15,9 @@ */ #include "tools/converter/converter.h" +#include #include #include -#include #include "tools/converter/converter_flags.h" #include "src/common/common.h" #include "src/common/file_utils.h" @@ -141,31 +141,11 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { return meta_graph; } -void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) { +void Converter::CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags) { auto type = flags->quantType; - switch (type) { - case mindspore::schema::QuantType_AwareTraining: { - // mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean)); - break; - } - // case mindspore::schema::QuantType_WeightQuant: { - // MS_LOG(INFO) << "create WeightQuantizer!"; - // mQuantizer.reset( - // new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, - // flags->bitNum)); - // break; - // } - case mindspore::schema::QuantType_PostTraining: { - MS_LOG(INFO) << "create PostTrainningQuantizer!"; - mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8)); - break; - } - case mindspore::schema::QuantType_QUANT_NONE: - MS_LOG(INFO) << "Not do quantization for model!"; - break; - default: - MS_LOG(INFO) << "will support quntizer type " << flags->quantTypeIn.c_str() << " in the future!"; - break; + if (type == mindspore::schema::QuantType_PostTraining) { + MS_LOG(INFO) << "create post training quantizer."; + mQuantizer.reset(new quant::PostTrainingQuantizer(func_graph, flags->configFile, 8)); } } int RunConverter(int argc, const char **argv) { diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index 1002946f70d..e71b34ce970 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -34,7 +34,7 @@ class Converter { Converter(); virtual ~Converter(); virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags); - void CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags); + void CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags); void FreeFuncGraph(const FuncGraphPtr &func_graph); protected: diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 7a24b7728e8..3aa6d3d2f04 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -80,7 +80,7 @@ struct DivergInfo { this->interval = max_value / static_cast(bin_num); } - STATUS UpdateHistogram(const std::vector &data, const std::vector &shape) { + STATUS UpdateHistogram(const std::vector &data) { for (auto value : data) { if (value == 0) { continue; @@ -235,7 +235,7 @@ struct DivergInfo { return std::make_pair(this->cnode, zero_point); } }; -std::unordered_map Calibrator::GetResult( +std::unordered_map Calibrator::GetScale( std::unordered_map> *diverg_info) { std::unordered_map result; for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) { @@ -246,9 +246,9 @@ std::unordered_map Calibrator::GetResult( return result; } std::unordered_map Calibrator::GetZeropoint( - std::unordered_map> *mDivergInfo) { + std::unordered_map> *diverg_info) { std::unordered_map result; - for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) { + for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) { DivergInfo *info = iter->second.get(); auto zeropoint = info->GetZeropoint(); result.insert(zeropoint); @@ -257,9 +257,9 @@ std::unordered_map Calibrator::GetZeropoint( } std::map Calibrator::GetMinMax( - std::unordered_map> *mDivergInfo) { + std::unordered_map> *diverg_info) { std::map result; - for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) { + for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) { DivergInfo *info = iter->second.get(); mindspore::lite::quant::MaxMin input_maxmin{}; input_maxmin.min = info->min; @@ -284,10 +284,10 @@ std::unordered_map> *Calibrator::GetOut return &this->output_diverg_info_; } -STATUS Calibrator::RecordMaxValue(std::string opName, vector data, - std::unordered_map> *mDivergInfo) { - auto got = (*mDivergInfo).find(opName); - if (got != (*mDivergInfo).end()) { +STATUS Calibrator::RecordMaxValue(const std::string &op_name, const vector &data, + std::unordered_map> *diverg_info) { + auto got = (*diverg_info).find(op_name); + if (got != (*diverg_info).end()) { ((*got).second)->RecordMaxValue(data); } return RET_OK; @@ -332,11 +332,11 @@ STATUS Calibrator::UpdateDivergInverval(std::unordered_map data, vector shape, +STATUS Calibrator::UpdateDataFrequency(const std::string &op_name, const vector &data, std::unordered_map> *diverg_info) { auto got = (*diverg_info).find(op_name); if (got != (*diverg_info).end()) { - ((*got).second)->UpdateHistogram(data, shape); + ((*got).second)->UpdateHistogram(data); } return RET_OK; } @@ -347,10 +347,10 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) { return RET_ERROR; } string node_name = node->fullname_with_scope(); - std::unique_ptr input_diverg = - std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x)); - std::unique_ptr output_diverg = - std::unique_ptr(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x)); + std::unique_ptr input_diverg = std::unique_ptr( + new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x)); + std::unique_ptr output_diverg = std::unique_ptr( + new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x)); input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); @@ -359,29 +359,33 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) { void Calibrator::AddImage(const string file) { auto exist = [](const string file) { - struct stat buf; + struct stat buf{}; return stat(file.c_str(), &buf) == 0; }; if (exist(file)) { MS_LOG(INFO) << "load image: " << file; this->images_.push_back(file); } else { - MS_LOG(WARNING) << "Invaild image file path: " << file; + MS_LOG(WARNING) << "invalid image file path: " << file; } } -STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTensor *tensor) const { +STATUS Calibrator::GenerateInputData(int index, mindspore::tensor::MSTensor *tensor) const { string path = images_[index]; MS_LOG(INFO) << "read image: " << path; size_t size; - char *binBuf = ReadFile(path.c_str(), &size); + char *bin_buf = ReadFile(path.c_str(), &size); auto data = tensor->MutableData(); if (size != tensor->Size()) { MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size << " input tensor size: " << tensor->Size(); return RET_ERROR; } - memcpy(data, binBuf, size); + auto ret = memcpy_s(data, tensor->Size(), bin_buf, size); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s error: " << ret; + return RET_ERROR; + } return RET_OK; } @@ -467,7 +471,7 @@ STATUS Calibrator::ReadConfig() { } MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " " << "batch_count: " << config_param_.batch_count << " " - << "mothod_x: " << config_param_.method_x << " " + << "method_x: " << config_param_.method_x << " " << "thread_num: " << config_param_.thread_num; delete[] resolved_path; @@ -475,8 +479,8 @@ STATUS Calibrator::ReadConfig() { return RET_OK; } -Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) - : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} +Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min) + : config_path_(path), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {} PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type, bool per_channel) @@ -669,11 +673,11 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptrcalibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo()); - auto input_scale = this->calibrator_->GetResult(this->calibrator_->GetInputDivergInfo()); + auto input_scale = this->calibrator_->GetScale(this->calibrator_->GetInputDivergInfo()); auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo()); auto output_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetOutputDivergInfo()); - auto output_scale = this->calibrator_->GetResult(this->calibrator_->GetOutputDivergInfo()); + auto output_scale = this->calibrator_->GetScale(this->calibrator_->GetOutputDivergInfo()); auto output_zeropoint = this->calibrator_->GetZeropoint(this->calibrator_->GetOutputDivergInfo()); auto cnodes = funcGraph->GetOrderedCnodes(); @@ -803,7 +807,7 @@ STATUS PostTrainingQuantizer::PreProcess() { // from user input QuantStrategy strategy(10); auto cnodes = funcGraph->GetOrderedCnodes(); - for (auto cnode : cnodes) { + for (auto &cnode : cnodes) { AnfNodePtr anf = std::dynamic_pointer_cast(cnode); if (strategy.CanOpPostQuantized(anf)) { MS_LOG(INFO) << "node: " << cnode->fullname_with_scope() << " will be quantized"; @@ -813,16 +817,15 @@ STATUS PostTrainingQuantizer::PreProcess() { return RET_OK; } -STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &nodeName, - const std::vector &tensorVec) const { - if (tensorVec.size() < 1) { - MS_LOG(ERROR) << "node: " << nodeName << " input tensors is 0"; +STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &node_name, + const std::vector &tensor_vec) const { + if (tensor_vec.size() < 1) { + MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0"; return RET_ERROR; } - auto *tensor = tensorVec[0]; + auto *tensor = tensor_vec[0]; if (tensor->data_type() != kNumberTypeFloat32) { - //&& tensor->RefCount() != MSCONST_WEIGHT_REFCOUNT - MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize" + MS_LOG(DEBUG) << "node: " << node_name << " will not quantize" << " tensor data_type: " << tensor->data_type(); return RET_ERROR; } @@ -856,8 +859,8 @@ STATUS PostTrainingQuantizer::DoInference() { } auto tensor = beforeInputs[0]; const float *tData = static_cast(tensor->MutableData()); - size_t shapeSize = tensor->ElementsNum(); - vector data(tData, tData + shapeSize); + size_t elem_count = tensor->ElementsNum(); + vector data(tData, tData + elem_count); this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetInputDivergInfo()); return true; }; @@ -871,8 +874,8 @@ STATUS PostTrainingQuantizer::DoInference() { } auto tensor = afterOutputs[0]; const float *tensor_data = static_cast(tensor->MutableData()); - size_t shape_size = tensor->ElementsNum(); - vector data(tensor_data, tensor_data + shape_size); + size_t elem_count = tensor->ElementsNum(); + vector data(tensor_data, tensor_data + elem_count); this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetOutputDivergInfo()); return true; }; @@ -910,7 +913,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() { const float *tensor_data = static_cast(tensor->MutableData()); size_t shape_size = tensor->ElementsNum(); vector data(tensor_data, tensor_data + shape_size); - this->calibrator_->UpdateDataFrequency(callParam.name_callback_param, data, tensor->shape(), + this->calibrator_->UpdateDataFrequency(callParam.name_callback_param, data, this->calibrator_->GetInputDivergInfo()); return true; }; @@ -926,7 +929,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() { const float *tenosr_data = static_cast(tensor->MutableData()); size_t shape_size = tensor->ElementsNum(); vector data(tenosr_data, tenosr_data + shape_size); - this->calibrator_->UpdateDataFrequency(call_param.name_callback_param, data, tensor->shape(), + this->calibrator_->UpdateDataFrequency(call_param.name_callback_param, data, this->calibrator_->GetOutputDivergInfo()); return true; }; diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index e2cdfdfecb5..a61daf30ed4 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -39,14 +39,9 @@ struct MaxMin { float max; }; -enum ImageFormat { - RGB = 0, - GRAY = 1, - BGR = 2, -}; - const char kMethodMaxMin[] = "MAX_MIN"; const char kMethodKL[] = "KL"; +constexpr int kDefaultBinNumber = 2048; struct ConfigParam { // ImageFormat imageFormat; @@ -78,7 +73,8 @@ class PostTrainingQuantizer : public Quantizer { STATUS PreProcess(); - STATUS CheckTensorVec(const std::string &nodeName, const std::vector &tensorVec) const; + STATUS CheckTensorVec(const std::string &node_name, + const std::vector &tensor_vec) const; STATUS DoInference(); @@ -105,7 +101,7 @@ struct DivergInfo; class Calibrator { public: - explicit Calibrator(std::string path, size_t quant_size, int quant_max, int quant_msin); + explicit Calibrator(std::string path, size_t bit_num, int quant_max, int quant_min); ~Calibrator() = default; @@ -123,18 +119,18 @@ class Calibrator { STATUS AddQuantizedOp(CNodePtr node); - STATUS RecordMaxValue(std::string opName, std::vector data, + STATUS RecordMaxValue(const std::string &op_name, const std::vector &data, std::unordered_map> *diverg_info); STATUS UpdateDivergInverval(std::unordered_map> *diverg_info); - STATUS UpdateDataFrequency(std::string op_name, std::vector data, std::vector shape, + STATUS UpdateDataFrequency(const std::string& op_name, const std::vector& data, std::unordered_map> *diverg_info); void Dump(); STATUS ComputeThreshold(); - std::unordered_map GetResult( + std::unordered_map GetScale( std::unordered_map> *diverg_info); std::unordered_map GetZeropoint( diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 6e7690de3aa..6753973ac78 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -349,16 +349,12 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti quant_datas[index] = quant_data; } } - auto ret = memcpy_s(const_cast(raw_datas), weight->tensor_size(), quant_datas.data(), + auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; } - if (quantType == QuantType_WeightQuant) { - PostBitPack(const_cast(raw_datas), elem_count, bitNum); - } - weight->set_tensor_size(elem_count * sizeof(int8_t)); } else { // channel at first @@ -407,9 +403,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; } - if (quantType == QuantType_WeightQuant) { - PostBitPack(const_cast(raw_datas), elem_count, bitNum); - } weight->set_tensor_size(elem_count * sizeof(int8_t)); } @@ -441,9 +434,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti MS_LOG(ERROR) << "memcpy error: " << ret; return RET_ERROR; } - if (quantType == QuantType_WeightQuant) { - PostBitPack(raw_datas, elem_count, bitNum); - } weight->set_tensor_size(elem_count * sizeof(int8_t)); } if (quant_params.empty()) { diff --git a/mindspore/lite/tools/converter/quantizer/quantizer.h b/mindspore/lite/tools/converter/quantizer/quantizer.h index 963a9635527..3fe37379b3e 100644 --- a/mindspore/lite/tools/converter/quantizer/quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/quantizer.h @@ -51,7 +51,7 @@ class Quantizer { virtual STATUS DetermineNodeQuantType(); - virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; + virtual STATUS DoQuantize(FuncGraphPtr func_graph) = 0; mindspore::lite::converter::Flags flags; protected: