diff --git a/mindspore/lite/src/ops/adder.cc b/mindspore/lite/src/ops/adder.cc index e43bd1b10e8..6320c48cb99 100644 --- a/mindspore/lite/src/ops/adder.cc +++ b/mindspore/lite/src/ops/adder.cc @@ -21,7 +21,7 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #ifdef PRIMITIVE_WRITEABLE -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index b1961c234bd..3bb9fa81e64 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -24,7 +24,7 @@ #include "src/common/log_adapter.h" #ifdef PRIMITIVE_WRITEABLE #include -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 1605a12a0b5..d6878602ed0 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -21,8 +21,7 @@ #include "src/common/log_adapter.h" #ifdef PRIMITIVE_WRITEABLE #include - -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index 19a626fe3f7..ad5bef82133 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -19,7 +19,7 @@ #include #include #ifdef PRIMITIVE_WRITEABLE -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE #include "src/ops/ops_register.h" diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index bd11e88fd8e..9e50dc222eb 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -18,7 +18,7 @@ #include #include #ifdef PRIMITIVE_WRITEABLE -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/maximum.cc b/mindspore/lite/src/ops/maximum.cc index 15899974d29..1d17d3e6551 100644 --- a/mindspore/lite/src/ops/maximum.cc +++ b/mindspore/lite/src/ops/maximum.cc @@ -19,8 +19,7 @@ #include "src/common/log_adapter.h" #ifdef PRIMITIVE_WRITEABLE #include - -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/maximum_grad.cc b/mindspore/lite/src/ops/maximum_grad.cc index cbe6a46428e..634e4d5853a 100644 --- a/mindspore/lite/src/ops/maximum_grad.cc +++ b/mindspore/lite/src/ops/maximum_grad.cc @@ -19,7 +19,7 @@ #include "src/common/log_adapter.h" #ifdef PRIMITIVE_WRITEABLE #include -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/minimum_grad.cc b/mindspore/lite/src/ops/minimum_grad.cc index 73c66aa836b..6c5df183f56 100644 --- a/mindspore/lite/src/ops/minimum_grad.cc +++ b/mindspore/lite/src/ops/minimum_grad.cc @@ -19,7 +19,7 @@ #include "src/common/log_adapter.h" #ifdef PRIMITIVE_WRITEABLE #include -#include "tools/converter/quantizer/quantize_util.h" +#include "src/param_value_lite.h" #endif #ifndef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 991ca08ed10..0ad8da75566 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -387,7 +387,9 @@ void PrimitiveC::set_input_quant_params(const std::vector &input_quant_param) { - MS_ASSERT(index < this->input_quant_param_.size()); + if (index >= this->input_quant_param_.size()) { + this->input_quant_param_.resize(index + 1); + } this->input_quant_param_.at(index) = input_quant_param; } @@ -493,7 +495,7 @@ std::shared_ptr GetTupleGetItemPrim() { } template ::value>> -std::shared_ptr NewPrimitiveC(const Primitive &prim, const std::vector &inputs, +std::shared_ptr NewPrimitiveC(const mindspore::Primitive &prim, const std::vector &inputs, const schema::QuantType &quantType) { auto primc = std::make_shared(); if (primc == nullptr) { diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index b4743fb6992..2b3296fa162 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -204,7 +204,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); return nullptr; } - this->mQuantizer = std::make_unique(new_graph, config->quantWeightSize, + this->mQuantizer = std::make_unique(new_graph, config->configFile, config->quantWeightSize, config->quantWeightChannel, config->bitNum); if (mQuantizer == nullptr) { MS_LOG(ERROR) << "New WeightQuantizer failed"; diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 9d4a665b0af..09c9049bd38 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -32,8 +32,6 @@ #include "tools/anf_exporter/anf_exporter.h" #include "tools/anf_importer/import_from_mindir.h" #include "proto/onnx.pb.h" -#include "tools/converter/quantizer/post_training_quantizer.h" -#include "tools/converter/quantizer/quant_cast.h" #include "include/version.h" namespace mindspore { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc index 566709a848e..b0789675376 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc @@ -16,7 +16,6 @@ #include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h" #include "tools/converter/converter_context.h" -#include "tools/converter/quantizer/quantize_util.h" #include "tools/common/tensor_util.h" namespace mindspore::lite { diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index b37f7dc3fb3..176ce6a7171 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -6,7 +6,6 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4) file(GLOB QUANTIZER ${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc - ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 825fab53cad..202c5784df0 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -39,6 +39,7 @@ #include "tools/common/tensor_util.h" #include "src/common/file_utils.h" #include "src/common/utils.h" +#include "tools/converter/quantizer/weight_quantizer.h" using std::string; using std::vector; @@ -380,182 +381,16 @@ STATUS Calibrator::AddQuantizedOp(const CNodePtr &node) { return RET_OK; } -void Calibrator::AddImage(const string &file, size_t index) { - if (index >= images_.size()) { - MS_LOG(ERROR) << "images_ size: " << images_.size() << " but index: " << index; - return; - } - auto exist = [](const string &file) { - struct stat buf {}; - return stat(file.c_str(), &buf) == 0; - }; - if (exist(file)) { - this->images_[index].push_back(file); - } else { - MS_LOG(WARNING) << "invalid image file path: " << file; - } -} - STATUS Calibrator::GenerateInputData(size_t input_index, size_t image_index, mindspore::tensor::MSTensor *tensor) const { - MS_ASSERT(tensor != nullptr); - if (input_index >= images_.size()) { - MS_LOG(ERROR) << "images_ size: " << images_.size() << " but input_index: " << input_index; - return RET_ERROR; - } - if (image_index >= images_[input_index].size()) { - MS_LOG(ERROR) << "images_[input_index] size: " << images_[input_index].size() - << " but image_index: " << image_index; - return RET_ERROR; - } - string path = images_[input_index][image_index]; - MS_LOG(INFO) << "read image: " << path; - size_t size; - char *bin_buf = ReadFile(path.c_str(), &size); - if (bin_buf == nullptr) { - MS_LOG(ERROR) << "ReadFile return nullptr"; - return RET_NULL_PTR; - } - auto data = tensor->MutableData(); - if (data == nullptr) { - MS_LOG(ERROR) << "Get tensor MutableData return nullptr"; - return RET_NULL_PTR; - } - if (size != tensor->Size()) { - MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size - << " input tensor size: " << tensor->Size(); - return RET_ERROR; - } - auto ret = memcpy_s(data, tensor->Size(), bin_buf, size); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s error: " << ret; - delete[] bin_buf; - return RET_ERROR; - } - delete[] bin_buf; - return RET_OK; + return CopyInputDataToTensor(input_index, image_index, images_, tensor); } STATUS Calibrator::CollectImages() { - this->images_.resize(config_param_.image_paths.size()); - auto input_i = 0; - bool multi_input = config_param_.image_paths.size() > 1; - for (const auto &image_path : config_param_.image_paths) { - DIR *root = opendir(image_path.c_str()); - if (root == nullptr) { - MS_LOG(ERROR) << "invalid image path: " << image_path; - return RET_PARAM_INVALID; - } - struct dirent *image_dir = readdir(root); - size_t count = 0; - while (image_dir != nullptr) { - string file_name(image_dir->d_name); - if (file_name != "." && file_name != "..") { - const std::string file_path = image_path + "/" + file_name; - if (multi_input || config_param_.batch_count == 0) { - this->AddImage(file_path, input_i); - count++; - } else if (count < config_param_.batch_count) { - this->AddImage(file_path, input_i); - count++; - } else { - break; - } - } - image_dir = readdir(root); - } - std::sort(images_[input_i].begin(), images_[input_i].end()); - if (config_param_.batch_count != 0 && config_param_.batch_count < images_[input_i].size()) { - images_[input_i].resize(config_param_.batch_count); - } - closedir(root); - input_i++; - } - return RET_OK; + return CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); } -STATUS Calibrator::ReadConfig() { - if (config_path_.empty() || config_path_.length() > PATH_MAX) { - MS_LOG(ERROR) << "invalid config path!"; - return RET_PARAM_INVALID; - } - // check whether config file path is valid - char *resolved_path = new (std::nothrow) char[PATH_MAX]{0}; - if (resolved_path == nullptr) { - MS_LOG(ERROR) << "New an object failed."; - return RET_ERROR; - } -#ifdef _WIN32 - if (_fullpath(resolved_path, config_path_.c_str(), 1024) != nullptr) { - config_path_ = string(resolved_path); - } -#else - if (realpath(config_path_.c_str(), resolved_path) != nullptr) { - config_path_ = string(resolved_path); - } -#endif - std::ifstream fs(config_path_.c_str(), std::ifstream::in); - if (!fs.is_open()) { - MS_LOG(ERROR) << "config proto file %s open failed: " << config_path_; - delete[] resolved_path; - return RET_PARAM_INVALID; - } - std::string line; - while (std::getline(fs, line)) { - auto index = line.find('='); - if (index == std::string::npos) { - MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; - delete[] resolved_path; - return RET_PARAM_INVALID; - } - auto key = line.substr(0, index); - auto value = line.substr(index + 1); - Trim(&key); - Trim(&value); - if (key == "image_path") { - auto &raw_image_paths = value; - auto ind = raw_image_paths.find(','); - while (ind != std::string::npos) { - auto image_path = raw_image_paths.substr(0, ind); - Trim(&image_path); - config_param_.image_paths.push_back(image_path); - raw_image_paths = raw_image_paths.substr(ind + 1); - Trim(&raw_image_paths); - ind = raw_image_paths.find(','); - } - config_param_.image_paths.push_back(raw_image_paths); - } else if (key == "batch_count") { - config_param_.batch_count = std::stoul(value); - } else if (key == "thread_num") { - config_param_.thread_num = std::stoul(value); - } else if (key == "method_x") { - if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { - MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; - } else { - config_param_.method_x = value; - } - } else if (key == "bias_correction") { - std::for_each(value.begin(), value.end(), ::tolower); - if (value == "true") { - config_param_.bias_correction = true; - } - } else { - MS_LOG(WARNING) << "unsupported parameter"; - } - } - - for (const auto &path : config_param_.image_paths) { - MS_LOG(DEBUG) << "calibration data_path: " << path; - } - MS_LOG(DEBUG) << "batch_count: " << config_param_.batch_count << " " - << "method_x: " << config_param_.method_x << " " - << "thread_num: " << config_param_.thread_num << " " - << "bias_correction: " << config_param_.bias_correction; - - delete[] resolved_path; - fs.close(); - return RET_OK; -} +STATUS Calibrator::ReadConfig() { return ParseConfigFile(config_path_, &config_param_); } Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min) : config_path_(std::move(path)), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {} @@ -621,8 +456,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct return RET_OK; } -STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr primitive_c, - bool perchanel) const { +STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, + std::shared_ptr primitive_c, bool perchanel) const { MS_ASSERT(weight != nullptr); MS_ASSERT(lite_primitive != nullptr); // perlayer @@ -640,8 +475,21 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::share MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value"; return RET_NULL_PTR; } - auto status = QuantFilter(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max, quant_min, - bit_num, perchanel); + auto bit_num_t = bit_num; + auto quant_max_t = quant_max; + auto quant_min_t = quant_min; + if (calibrator_->config_param_.mixed) { + auto opname_iter = opname_bit_.find(op_name); + if (opname_iter == opname_bit_.end()) { + MS_LOG(WARNING) << op_name << " not in the opname_bit_ map"; + } else { + bit_num_t = opname_iter->second; + quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; + quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); + } + } + auto status = QuantFilter(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max_t, + quant_min_t, bit_num_t, perchanel); if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed: " << status; return status; @@ -921,7 +769,7 @@ STATUS PostTrainingQuantizer::QuantNode() { } if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) { MS_LOG(DEBUG) << "this parameter do quant"; - DoWeightQuant(input_node, primitive_c, false); + DoWeightQuant(op_name, input_node, primitive_c, false); } else { MS_LOG(DEBUG) << "this parameter no need to do quant"; } @@ -943,7 +791,7 @@ STATUS PostTrainingQuantizer::QuantNode() { op_type == PrimitiveType_FullConnection) { perchannel = true; } - DoWeightQuant(weight, primitive_c, perchannel); + DoWeightQuant(op_name, weight, primitive_c, perchannel); // do bias quant if (cnode->inputs().size() == 4) { auto bias = cnode->input(3); @@ -982,18 +830,8 @@ STATUS PostTrainingQuantizer::UpdateDivergInverval() { * 3. save quantied node **/ STATUS PostTrainingQuantizer::PreProcess() { - if (this->calibrator_ == nullptr) { - MS_LOG(ERROR) << "calibrator is null!"; - return RET_ERROR; - } - // 1. generate config param - STATUS status = calibrator_->ReadConfig(); - if (status != RET_OK) { - MS_LOG(ERROR) << "read proto text failed!"; - return status; - } // 2. collect image files - status = calibrator_->CollectImages(); + auto status = calibrator_->CollectImages(); if (status != RET_OK) { MS_LOG(ERROR) << "collect images failed!"; return status; @@ -1560,55 +1398,49 @@ STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->Com STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { MS_LOG(INFO) << "start to parse config file"; - STATUS status = PreProcess(); + if (this->calibrator_ == nullptr) { + MS_LOG(ERROR) << "calibrator is null!"; + return RET_ERROR; + } + // 1. generate config param + STATUS status = calibrator_->ReadConfig(); + if (status != RET_OK) { + MS_LOG(ERROR) << "read proto text failed!"; + return status; + } + + if (calibrator_->config_param_.mixed) { + // get opname_bit map + auto weight_quant_func_graph = CopyFuncGraph(func_graph); + if (weight_quant_func_graph == nullptr) { + MS_LOG(ERROR) << "CopyFuncGraph error"; + return RET_ERROR; + } + WeightQuantizer weight_quantizer(weight_quant_func_graph, calibrator_->config_param_); + weight_quantizer.flags = flags; + status = weight_quantizer.DoQuantize(weight_quant_func_graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "Do mix weight quant error"; + return RET_ERROR; + } + opname_bit_ = weight_quantizer.opname_bit_; + } + + status = PreProcess(); if (status != RET_OK) { MS_LOG(ERROR) << "do pre process failed!"; return status; } + // anf -- fb - auto meta_graph = Export(func_graph, true, true); - if (meta_graph == nullptr) { - MS_LOG(ERROR) << "Export to meta_graph return nullptr"; - return RET_ERROR; - } - - // transform - GraphDefTransform transform; - transform.SetGraphDef(meta_graph); flags.quantType = schema::QuantType_QUANT_NONE; - status = transform.Transform(flags); - if (status != RET_OK) { - MS_LOG(ERROR) << "FBTransform model failed " << status; - return RET_ERROR; - } MS_LOG(INFO) << "start create session"; - flatbuffers::FlatBufferBuilder builder(1024); - auto offset = schema::MetaGraph::Pack(builder, meta_graph); - builder.Finish(offset); - schema::FinishMetaGraphBuffer(builder, offset); - size_t size = builder.GetSize(); - auto *content = reinterpret_cast(builder.GetBufferPointer()); - if (content == nullptr) { - MS_LOG(ERROR) << "GetBufferPointer nullptr"; - return RET_ERROR; - } - auto model = lite::Model::Import(content, size); - - Context ctx; - ctx.thread_num_ = calibrator_->GetThreadNum(); - - fp32_session_ = dynamic_cast(session::LiteSession::CreateSession(&ctx)); + fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum()); if (fp32_session_ == nullptr) { MS_LOG(ERROR) << "create session failed!"; return RET_ERROR; } - auto ret = fp32_session_->CompileGraph(model); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "compile graph error"; - return RET_ERROR; - } - MS_LOG(INFO) << "start to update divergence's max value"; status = DoInference(); if (status != RET_OK) { @@ -1647,49 +1479,13 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { if (calibrator_->GetBiasCorrection()) { // init in8 session - // anf -- fb - auto int8_meta_graph = Export(func_graph, true, true); - if (int8_meta_graph == nullptr) { - MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr"; - return RET_ERROR; - } - - // transform - GraphDefTransform fb_transform; - fb_transform.SetGraphDef(int8_meta_graph); + MS_LOG(INFO) << "create quant session"; flags.quantType = schema::QuantType_PostTraining; - status = fb_transform.Transform(flags); - if (status != RET_OK) { - MS_LOG(ERROR) << "FBTransform model failed " << status; - return RET_ERROR; - } - MS_LOG(INFO) << "start create quantized session"; - flatbuffers::FlatBufferBuilder int8_builder(1024); - auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph); - int8_builder.Finish(int8_offset); - schema::FinishMetaGraphBuffer(int8_builder, int8_offset); - size = int8_builder.GetSize(); - auto *int8_content = reinterpret_cast(int8_builder.GetBufferPointer()); - if (int8_content == nullptr) { - MS_LOG(ERROR) << "GetBufferPointer nullptr"; - return RET_ERROR; - } - auto int8_model = lite::Model::Import(int8_content, size); - - Context int8_ctx; - int8_ctx.thread_num_ = calibrator_->GetThreadNum(); - int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU; - - int8_session_ = dynamic_cast(session::LiteSession::CreateSession(&int8_ctx)); + int8_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum()); if (int8_session_ == nullptr) { MS_LOG(ERROR) << "create session failed!"; return RET_ERROR; } - ret = int8_session_->CompileGraph(int8_model); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "compile graph error"; - return RET_ERROR; - } MS_LOG(INFO) << "do bias correction"; status = BiasCorrection(func_graph); diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 0665eecc070..3a9b70ee31c 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -28,6 +28,8 @@ #include "tools/converter/quantizer/quantizer.h" #include "tools/converter/converter.h" #include "include/ms_tensor.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "tools/converter/quantizer/weight_quantizer.h" namespace mindspore::lite::quant { class Calibrator; @@ -38,19 +40,8 @@ struct MaxMin { float max; }; -const char kMethodMaxMin[] = "MAX_MIN"; -const char kMethodKL[] = "KL"; -const char kMethodOutlier[] = "RemovalOutlier"; constexpr int kDefaultBinNumber = 2048; -struct ConfigParam { - std::vector image_paths; - uint32_t batch_count{100}; - std::string method_x{kMethodKL}; - uint32_t thread_num{1}; - bool bias_correction{false}; -}; - class PostTrainingQuantizer : public Quantizer { public: PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8, @@ -64,14 +55,16 @@ class PostTrainingQuantizer : public Quantizer { int quant_min{INT8_MIN}; private: + std::map opname_bit_; + bool per_channel_{true}; TypeId target_type_{kNumberTypeInt8}; std::unique_ptr calibrator_; - mindspore::lite::LiteSession *fp32_session_; - mindspore::lite::LiteSession *int8_session_; + session::LiteSession *fp32_session_{nullptr}; + session::LiteSession *int8_session_{nullptr}; std::map> fp32_op_input_map; // concurency std::map> fp32_op_output_ch_mean_map; // concurency @@ -112,7 +105,8 @@ class PostTrainingQuantizer : public Quantizer { STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, const std::shared_ptr &) const; - STATUS DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr primitive_c, bool perchannel) const; + STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, std::shared_ptr primitive_c, + bool perchannel) const; STATUS DoBiasQuant(const AnfNodePtr &bias, const std::shared_ptr &primitive_c); STATUS Int8Inference(); @@ -213,13 +207,13 @@ class Calibrator { std::unordered_map>> *GetOutputDivergInfo(); + PostQuantConfig config_param_; + private: std::vector> images_; // multi_input, echo input has multi input data std::string config_path_; - ConfigParam config_param_; - std::unordered_map>> inputs_diverg_info_; std::unordered_map>> outputs_diverg_info_; @@ -227,8 +221,6 @@ class Calibrator { size_t bit_num_; int quant_max_; int quant_min_; - - void AddImage(const std::string &file, size_t index); }; } // namespace mindspore::lite::quant #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 5de9b7d43db..1ec399c060c 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -17,6 +17,8 @@ #include "mindspore/lite/tools/converter/quantizer/quantize_util.h" #include #include +#include +#include #include #include #include @@ -26,6 +28,8 @@ #include "src/common/utils.h" #include "abstract/abstract_value.h" #include "securec/include/securec.h" +#include "tools/anf_exporter/anf_exporter.h" +#include "mindspore/lite/include/version.h" using std::string; using std::vector; @@ -83,10 +87,10 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const { bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { MS_ASSERT(node != nullptr); - if (!node->isa()) { + if (!node->isa()) { return false; } - auto cnode = std::dynamic_pointer_cast(node); + auto cnode = std::dynamic_pointer_cast(node); auto type = NodePrimitiveType(cnode); static const std::vector int8OpList = { schema::PrimitiveType_Conv2D, @@ -475,4 +479,307 @@ schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode) { } return (schema::PrimitiveType)primitive_c->Type(); } + +STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { + if (post_quant_config == nullptr) { + MS_LOG(ERROR) << "post_quant_config is null."; + return RET_PARAM_INVALID; + } + + if (config_file.empty() || config_file.length() > PATH_MAX) { + MS_LOG(ERROR) << "invalid config path!"; + return RET_PARAM_INVALID; + } + // check whether config file path is valid + auto resolved_path = std::make_unique(PATH_MAX); + if (resolved_path == nullptr) { + MS_LOG(ERROR) << "New an object failed."; + return RET_ERROR; + } +#ifdef _WIN32 + if (_fullpath(resolved_path.get(), config_file.c_str(), 1024) != nullptr) { + config_file = string(resolved_path.get()); + } +#else + if (realpath(config_file.c_str(), resolved_path.get()) != nullptr) { + config_file = string(resolved_path.get()); + } +#endif + std::ifstream fs(config_file.c_str(), std::ifstream::in); + if (!fs.is_open()) { + MS_LOG(ERROR) << "config file open failed: " << config_file; + return RET_PARAM_INVALID; + } + std::string line; + while (std::getline(fs, line)) { + Trim(&line); + if (line.empty()) { + continue; + } + auto index = line.find('='); + if (index == std::string::npos) { + MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check"; + return RET_PARAM_INVALID; + } + auto key = line.substr(0, index); + auto value = line.substr(index + 1); + Trim(&key); + Trim(&value); + if (key == "image_path") { + auto &raw_image_paths = value; + auto ind = raw_image_paths.find(','); + while (ind != std::string::npos) { + auto image_path = raw_image_paths.substr(0, ind); + Trim(&image_path); + post_quant_config->image_paths.push_back(image_path); + raw_image_paths = raw_image_paths.substr(ind + 1); + Trim(&raw_image_paths); + ind = raw_image_paths.find(','); + } + post_quant_config->image_paths.push_back(raw_image_paths); + } else if (key == "batch_count") { + post_quant_config->batch_count = std::stoul(value); + } else if (key == "thread_num") { + post_quant_config->thread_num = std::stoul(value); + } else if (key == "method_x") { + if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) { + MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value."; + } else { + post_quant_config->method_x = value; + } + } else if (key == "bias_correction") { + std::for_each(value.begin(), value.end(), ::tolower); + if (value == "true") { + post_quant_config->bias_correction = true; + } + } else if (key == "mixed") { + std::for_each(value.begin(), value.end(), ::tolower); + if (value == "true") { + post_quant_config->mixed = true; + } + } else if (key == "mean_error_threshold") { + post_quant_config->mean_error_threshold = std::stof(value); + } else { + MS_LOG(WARNING) << "unsupported parameter: " << key; + } + } + + for (const auto &path : post_quant_config->image_paths) { + MS_LOG(DEBUG) << "calibration data_path: " << path; + } + MS_LOG(DEBUG) << "batch_count: " << post_quant_config->batch_count << "\n" + << "method_x: " << post_quant_config->method_x << "\n" + << "thread_num: " << post_quant_config->thread_num << "\n" + << "bias_correction: " << post_quant_config->bias_correction << "\n" + << "mixed: " << post_quant_config->mixed << "\n" + << "mean_error_threshold: " << post_quant_config->mean_error_threshold; + post_quant_config->inited = true; + fs.close(); + return RET_OK; +} + +session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, + int thread_num) { + auto meta_graph = Export(func_graph, true, true); + if (meta_graph == nullptr) { + MS_LOG(ERROR) << "Export to meta_graph failed"; + return nullptr; + } + + // transform + GraphDefTransform fb_transform; + fb_transform.SetGraphDef(meta_graph); + auto status = fb_transform.Transform(flags); + if (status != RET_OK) { + MS_LOG(ERROR) << "FBTransform model failed"; + return nullptr; + } + meta_graph->version = Version(); + + flatbuffers::FlatBufferBuilder builder(1024); + auto offset = schema::MetaGraph::Pack(builder, meta_graph); + builder.Finish(offset); + schema::FinishMetaGraphBuffer(builder, offset); + auto size = builder.GetSize(); + auto *content = reinterpret_cast(builder.GetBufferPointer()); + if (content == nullptr) { + MS_LOG(ERROR) << "GetBufferPointer return null"; + return nullptr; + } + auto model = lite::Model::Import(content, size); + if (model == nullptr) { + MS_LOG(ERROR) << "Import model failed"; + return nullptr; + } + + Context ctx; + ctx.thread_num_ = thread_num; + + auto session = session::LiteSession::CreateSession(&ctx); + if (session == nullptr) { + MS_LOG(ERROR) << "create session failed."; + return nullptr; + } + + status = session->CompileGraph(model); + if (status != RET_OK) { + MS_LOG(ERROR) << "CompileGraph error"; + return nullptr; + } + model->Free(); + return session; +} + +STATUS CollectCalibInputs(const std::vector &input_dirs, size_t count_limited, + std::vector> *inputs) { + if (inputs == nullptr) { + MS_LOG(ERROR) << "inputs is null"; + return RET_ERROR; + } + auto AddImage = [&inputs](const std::string &file, size_t index) { + if (index >= inputs->size()) { + MS_LOG(ERROR) << "images_ size: " << inputs->size() << " but input index: " << index; + return; + } + struct stat buf {}; + if (stat(file.c_str(), &buf) == 0) { + inputs->at(index).push_back(file); + } else { + MS_LOG(WARNING) << "invalid image file path: " << file; + } + }; + + inputs->resize(input_dirs.size()); + auto input_i = 0; + bool multi_input = input_dirs.size() > 1; + for (const auto &image_path : input_dirs) { + DIR *root = opendir(image_path.c_str()); + if (root == nullptr) { + MS_LOG(ERROR) << "invalid image path: " << image_path; + return RET_PARAM_INVALID; + } + struct dirent *image_dir = readdir(root); + size_t count = 0; + while (image_dir != nullptr) { + string file_name(image_dir->d_name); + if (file_name != "." && file_name != "..") { + const std::string file_path = image_path + "/" + file_name; + if (multi_input || count == 0) { + AddImage(file_path, input_i); + count++; + } else if (count < count_limited) { + AddImage(file_path, input_i); + count++; + } else { + break; + } + } + image_dir = readdir(root); + } + std::sort(inputs->at(input_i).begin(), inputs->at(input_i).end()); + if (count_limited != 0 && count_limited < inputs->at(input_i).size()) { + inputs->at(input_i).resize(count_limited); + } + closedir(root); + input_i++; + } + return RET_OK; +} + +STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, + const std::vector> &images, mindspore::tensor::MSTensor *tensor) { + MS_ASSERT(tensor != nullptr); + if (input_index >= images.size()) { + MS_LOG(ERROR) << "images_ size: " << images.size() << " but input_index: " << input_index; + return RET_ERROR; + } + if (image_index >= images[input_index].size()) { + MS_LOG(ERROR) << "images_[input_index] size: " << images[input_index].size() << " but image_index: " << image_index; + return RET_ERROR; + } + string path = images[input_index][image_index]; + MS_LOG(INFO) << "read image: " << path; + size_t size; + char *bin_buf = ReadFile(path.c_str(), &size); + if (bin_buf == nullptr) { + MS_LOG(ERROR) << "ReadFile return nullptr"; + return RET_NULL_PTR; + } + auto data = tensor->MutableData(); + if (data == nullptr) { + MS_LOG(ERROR) << "Get tensor MutableData return nullptr"; + return RET_NULL_PTR; + } + if (size != tensor->Size()) { + MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size + << " input tensor size: " << tensor->Size(); + return RET_ERROR; + } + auto ret = memcpy_s(data, tensor->Size(), bin_buf, size); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s error: " << ret; + delete[] bin_buf; + return RET_ERROR; + } + delete[] bin_buf; + return RET_OK; +} + +FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) { + Cloner cloner({func_graph}, true, true, true, std::make_shared(), nullptr); + auto new_func_graph = cloner[func_graph]; + + std::map old_cnode_map; + for (const auto &cnode : func_graph->GetOrderedCnodes()) { + old_cnode_map[cnode->fullname_with_scope()] = cnode; + } + + for (auto &cnode : new_func_graph->GetOrderedCnodes()) { + auto cnode_name = cnode->fullname_with_scope(); + auto old_cnode_iter = old_cnode_map.find(cnode_name); + if (old_cnode_iter == old_cnode_map.end()) { + MS_LOG(ERROR) << "can not find node: " << cnode_name; + return nullptr; + } + auto old_cnode = old_cnode_iter->second; + auto inputs = cnode->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + auto input_node = inputs[i]; + if (input_node->isa()) { + auto param_node = input_node->cast(); + if (param_node->has_default()) { + ParamValueLitePtr old_param_value = std::static_pointer_cast(param_node->default_param()); + auto new_param_value = std::make_shared(); + + auto copyed_data = malloc(old_param_value->tensor_size()); + if (copyed_data == nullptr) { + MS_LOG(ERROR) << "malloc data error, size: " << old_param_value->tensor_size(); + return nullptr; + } + memcpy(copyed_data, old_param_value->tensor_addr(), old_param_value->tensor_size()); + + new_param_value->set_tensor_size(old_param_value->tensor_size()); + new_param_value->set_tensor_addr(copyed_data); + new_param_value->set_tensor_shape(old_param_value->tensor_shape()); + new_param_value->set_format(old_param_value->format()); + new_param_value->set_tensor_type(old_param_value->tensor_type()); + + param_node->set_default_param(new_param_value); + } + + auto old_abstract_base = param_node->abstract(); + if (!utils::isa(old_abstract_base)) { + MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name(); + return nullptr; + } + auto old_abstract = utils::cast(old_abstract_base); + auto new_abstract = std::make_shared(old_abstract->element()->GetTypeTrack(), + old_abstract->GetShapeTrack()); + param_node->set_abstract(new_abstract); + } + } // end inputs loop + } // end cnodes loop + return new_func_graph; +} + } // namespace mindspore::lite::quant diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index b8eaf9e00ab..7fae9738efd 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -17,6 +17,8 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H +#include +#include #include #include #include @@ -35,11 +37,29 @@ #include "ir/primitive.h" #include "abstract/dshape.h" #include "tools/converter/quantizer/bitpacking.h" +#include "src/lite_session.h" +#include "tools/converter/graphdef_transform.h" +#include "src/common/file_utils.h" namespace mindspore::lite::quant { static constexpr size_t UINT8_QUANTIZATION = 8; static constexpr size_t WEIGHT_INDEX = 1; +const char kMethodMaxMin[] = "MAX_MIN"; +const char kMethodKL[] = "KL"; +const char kMethodOutlier[] = "RemovalOutlier"; + +struct PostQuantConfig { + std::vector image_paths; + uint32_t batch_count{100}; + std::string method_x{kMethodKL}; + uint32_t thread_num{1}; + bool bias_correction{false}; + bool mixed{false}; + float mean_error_threshold{0.04}; + bool inited{false}; +}; + /** * 1. when op's weight size > mWeightSize just skip * 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization @@ -320,6 +340,21 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr &input_dirs, size_t count_limited, + std::vector> *inputs); + +STATUS CopyInputDataToTensor(size_t input_index, size_t image_index, + const std::vector> &images, mindspore::tensor::MSTensor *tensor); + +FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &); } // namespace mindspore::lite::quant #endif diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 8eda191dec6..04038d6ce73 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "src/common/common.h" #include "ir/dtype/type_id.h" @@ -36,6 +37,7 @@ bool WeightQuantizer::IsPosNum(const std::string &str) { } return true; } + STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { MS_ASSERT(config != nullptr); if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) { @@ -57,28 +59,57 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { } return RET_OK; } -WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, + +WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) { + quant_strategy_ = std::make_unique(0, 0); + config_param_ = config; +} + +WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize, const std::string &convWeightChannelThreshold, const std::string &bitNum) : Quantizer(graph) { + this->config_file_ = config_file; auto quantSize = static_cast(std::stoull(weightSize)); - this->bitNum = static_cast(std::stoull(bitNum)); + this->bit_num_ = static_cast(std::stoull(bitNum)); auto convQuantWeightChannelThreshold = static_cast(std::stoull(convWeightChannelThreshold)); - mStrategy = std::make_unique(quantSize, convQuantWeightChannelThreshold); - quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1; - quant_min = -(1 << (unsigned int)(this->bitNum - 1)); + quant_strategy_ = std::make_unique(quantSize, convQuantWeightChannelThreshold); + quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1; + quant_min = -(1 << (unsigned int)(this->bit_num_ - 1)); // parse type_id - if (this->bitNum > 0 && this->bitNum <= 8) { + if (this->bit_num_ > 0 && this->bit_num_ <= 8) { type_id = kNumberTypeInt8; - } else if (this->bitNum <= 16) { + } else if (this->bit_num_ <= 16) { type_id = kNumberTypeInt16; } else { MS_LOG(ERROR) << "invalid input bits"; } } +WeightQuantizer::~WeightQuantizer() { delete fp32_session_; } + +STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, + std::shared_ptr primitive_c) { + // set dtype + param_value->set_tensor_type(type_id); + auto abstract_base = param_node->abstract(); + if (abstract_base == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); + return RET_ERROR; + } + if (!utils::isa(abstract_base)) { + MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); + return RET_ERROR; + } + auto abstract_tensor = utils::cast(abstract_base); + abstract_tensor->element()->set_type(TypeIdToType(type_id)); + primitive_c->set_quant_type(schema::QuantType_WeightQuant); + + return RET_OK; +} + STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { for (auto &cnode : nodes) { - if (!mStrategy->CanConvOpQuantized(cnode)) { + if (!quant_strategy_->CanConvOpQuantized(cnode)) { continue; } @@ -108,36 +139,28 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { } auto status = RET_ERROR; if (type_id == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); + status = + QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); } else if (type_id == kNumberTypeInt16) { status = - QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); + QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); } if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } - // set dtype - param_value->set_tensor_type(type_id); - auto abstractBase = param_node->abstract(); - if (abstractBase == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); + status = SetAbstract(param_value, param_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetAbstract failed : " << status; return RET_ERROR; } - if (!utils::isa(abstractBase)) { - MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); - return RET_ERROR; - } - auto abstractTensor = utils::cast(abstractBase); - abstractTensor->element()->set_type(TypeIdToType(type_id)); - primitive_c->set_quant_type(schema::QuantType_WeightQuant); } return RET_OK; } STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { for (auto &node : nodes) { - if (!mStrategy->CanMulOpQuantized(node)) { + if (!quant_strategy_->CanMulOpQuantized(node)) { continue; } auto already_quant = false; @@ -186,38 +209,271 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { auto status = RET_ERROR; if (type_id == kNumberTypeInt8) { - status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); + status = + QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); } else if (type_id == kNumberTypeInt16) { status = - QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); + QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true); } if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed : " << status; return status; } - param_value->set_tensor_type(type_id); - // set dtype - auto abstractBase = param_node->abstract(); - if (abstractBase == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); + status = SetAbstract(param_value, param_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetAbstract failed : " << status; return RET_ERROR; } - if (!utils::isa(abstractBase)) { - MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); - return RET_ERROR; - } - auto abstractTensor = utils::cast(abstractBase); - abstractTensor->element()->set_type(TypeIdToType(type_id)); - primitive_c->set_quant_type(schema::QuantType_WeightQuant); } return RET_OK; } -STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) { - MS_ASSERT(funcGraph != nullptr); +constexpr float relative_tolerance = 1e-5; +constexpr float abs_tolerance = 1e-4; + +template +float CompareOutputData(const std::unordered_map &expected_tensor, + const std::unordered_map &compare_tensor) { + auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); }; + + float total_mean_error = 0.0f; + int tensor_cnt = expected_tensor.size(); + + if (tensor_cnt <= 0) { + MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt; + return RET_ERROR; + } + + for (const auto &exp_tensor_pair : expected_tensor) { + float mean_error = 0.0f; + int error_cnt = 0; + + auto exp_tensor_name = exp_tensor_pair.first; + auto exp_tensor = exp_tensor_pair.second; + auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name); + if (cmp_tensor_find_iter == compare_tensor.end()) { + MS_LOG(ERROR) << "can not find: " << exp_tensor_name; + return RET_ERROR; + } + auto cmp_tensor = cmp_tensor_find_iter->second; + + auto exp_tensor_shape = exp_tensor->shape(); + auto cmp_tensor_shape = cmp_tensor->shape(); + if (exp_tensor_shape != cmp_tensor_shape) { + MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum() + << " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum(); + return RET_ERROR; + } + auto exp_data = static_cast(exp_tensor->MutableData()); + auto cmp_data = static_cast(cmp_tensor->MutableData()); + auto elem_cnt = exp_tensor->ElementsNum(); + for (int i = 0; i < elem_cnt; i++) { + if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) { + MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i; + return RET_ERROR; + } + auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]); + auto abs_error = std::fabs(exp_data[i] - cmp_data[i]); + if (abs_error > tolerance) { + if (fabs(exp_data[i] == 0)) { + if (abs_error > 1e-5) { + mean_error += abs_error; + error_cnt++; + } else { + // it is ok, very close to 0 + continue; + } + } else { + mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN); + error_cnt++; + } + } else { + // it is ok, no error + continue; + } + } // end one tensor data loop + total_mean_error += mean_error / elem_cnt; + } // end tensor loop + return total_mean_error / tensor_cnt; +} + +STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) { + // 0.1 Create Fp32 Session + flags.quantType = schema::QuantType_QUANT_NONE; + fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num); + if (fp32_session_ == nullptr) { + MS_LOG(ERROR) << "CreateSessoin fail"; + return RET_ERROR; + } + auto fp32_inputs = fp32_session_->GetInputs(); + // 0.2 Parse input calib files + auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_); + if (status != RET_OK) { + MS_LOG(ERROR) << "CollectCalibInputs fail"; + return RET_ERROR; + } + + auto cnodes = func_graph->GetOrderedCnodes(); + for (auto iter = cnodes.end(); iter != cnodes.begin();) { + auto cnode = *(--iter); + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << "primitive_c is null."; + return RET_ERROR; + } + auto op_name = cnode->fullname_with_scope(); + MS_LOG(DEBUG) << "process node: " << op_name + << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type()); + if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) { + auto input_node = cnode->input(2); + if (!input_node->isa()) { + MS_LOG(WARNING) << op_name << " the second input is not parameter"; + continue; + } + auto param_node = input_node->cast(); + if (!param_node->has_default()) { + MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; + continue; + } + auto param_value = std::static_pointer_cast(param_node->default_param()); + if (param_value == nullptr) { + MS_LOG(WARNING) << op_name << " the second input can not convert to parameter"; + continue; + } + if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) { + MS_LOG(WARNING) << op_name << " the second input type is not float"; + continue; + } + // copy origin data in case to recover + auto *raw_data = static_cast(param_value->tensor_addr()); + auto elem_count = param_value->tensor_shape_size(); + auto origin_data = malloc(sizeof(float) * elem_count); + auto ret = memcpy_s(origin_data, sizeof(float) * elem_count, raw_data, param_value->tensor_size()); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy fail: " + << " dst size: " << sizeof(float) * elem_count << " src size: " << param_value->tensor_size(); + return RET_ERROR; + } + // 1. try quant + for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) { + type_id = TypeId::kNumberTypeInt8; + int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1; + int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1)); + + if (type_id == TypeId::kNumberTypeInt8) { + status = QuantFilter(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, + quant_min_t, bit_num_t, true); + } else if (type_id == TypeId::kNumberTypeInt16) { + status = QuantFilter(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t, + quant_min_t, bit_num_t, true); + } else { + MS_LOG(ERROR) << "unexpected type_id: " << type_id; + return RET_ERROR; + } + if (status != RET_OK) { + MS_LOG(ERROR) << "quant filter fail."; + return RET_ERROR; + } + status = SetAbstract(param_value, param_node, primitive_c); + if (status != RET_OK) { + MS_LOG(ERROR) << "SetAbstract failed : " << status; + return RET_ERROR; + } + // 2. evaluate the quant + // 2.1 create quant session, get input, output tensor + flags.quantType = schema::QuantType_WeightQuant; + auto quant_session = + std::unique_ptr(CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num)); + if (quant_session == nullptr) { + MS_LOG(ERROR) << "create session error: " << status; + return RET_ERROR; + } + auto quant_inputs = quant_session->GetInputs(); + + auto mean_error = 0.0f; + if (fp32_inputs.size() != images_.size()) { + MS_LOG(ERROR) << "model's input tensor cnt: " << fp32_inputs.size() << " != " << images_.size(); + return RET_ERROR; + } + auto image_cnt = images_.at(0).size(); + for (size_t i = 0; i < image_cnt; i++) { + // set multi-input data + for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) { + status = CopyInputDataToTensor(input_index, i, images_, fp32_inputs[input_index]); + if (status != RET_OK) { + MS_LOG(ERROR) << "generate input data from images failed!"; + return RET_ERROR; + } + status = CopyInputDataToTensor(input_index, i, images_, quant_inputs[input_index]); + if (status != RET_OK) { + MS_LOG(ERROR) << "generate input data from images failed!"; + return RET_ERROR; + } + } + std::future fp32_inference = std::async( + std::launch::async, [](session::LiteSession *fp32_session) -> STATUS { return fp32_session->RunGraph(); }, + fp32_session_); + + status = quant_session->RunGraph(); + if (status != RET_OK) { + MS_LOG(ERROR) << "quant session run error"; + return RET_ERROR; + } + status = fp32_inference.get(); + if (status != RET_OK) { + MS_LOG(ERROR) << "fp32 session run error"; + return RET_ERROR; + } + // 3. compare betwen quant and fp32 + auto fp32_outputs = fp32_session_->GetOutputs(); + auto quant_outputs = quant_session->GetOutputs(); + mean_error += CompareOutputData(fp32_outputs, quant_outputs); + } // end_for: calib data loop + mean_error = mean_error / image_cnt; + + if (mean_error <= config_param_.mean_error_threshold) { + MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error; + opname_bit_[op_name] = bit_num_t; + break; + } else if (bit_num_t != 8) { + // recover + param_value->set_tensor_size(sizeof(float) * elem_count); + ret = memcpy_s(raw_data, param_value->tensor_size(), origin_data, sizeof(float) * elem_count); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy fail: " + << " src size: " << sizeof(float) * elem_count << " dst size: " << param_value->tensor_size(); + return RET_ERROR; + } + } else { + MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error; + opname_bit_[op_name] = bit_num_t; + } + } // end bit loop + free(origin_data); + } // if: conv and matmul + } // end loop: all cnode + return RET_OK; +} + +STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { + MS_ASSERT(func_graph != nullptr); STATUS ret; - auto cnodes = funcGraph->GetOrderedCnodes(); + auto cnodes = func_graph->GetOrderedCnodes(); + + if (!config_file_.empty()) { + ret = ParseConfigFile(config_file_, &config_param_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ReadConfig error."; + return RET_ERROR; + } + } + + if (config_param_.mixed) { + MS_LOG(INFO) << "Do mixed bit quantization"; + return DoMiexedQuant(func_graph); + } + ret = DoConvQuantize(cnodes); if (ret != RET_OK) { MS_LOG(ERROR) << "DoConvQuantize failed :" << ret; diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h index 7e91153ee4d..7e7494d2b23 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.h @@ -17,9 +17,12 @@ #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H +#include #include +#include #include #include +#include #include "tools/converter/quantizer/quantizer.h" #include "tools/converter/quantizer/quantize_util.h" #include "ir/func_graph.h" @@ -27,27 +30,37 @@ #include "include/model.h" #include "base/base.h" #include "abstract/dshape.h" +#include "src/lite_session.h" namespace mindspore::lite::quant { class WeightQuantizer : public Quantizer { public: - WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize, const std::string &covWeightChannelThreshold, - const std::string &bitNum); + WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize, + const std::string &covWeightChannelThreshold, const std::string &bitNum); + WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config); + ~WeightQuantizer(); - ~WeightQuantizer() = default; - - STATUS DoQuantize(FuncGraphPtr funcGraph) override; + STATUS DoQuantize(FuncGraphPtr func_graph) override; STATUS DoConvQuantize(const std::list &nodes); STATUS DoMulQuantize(const std::list &nodes); static STATUS WeightQuantInputCheck(const converter::Flags *config); static bool IsPosNum(const std::string &str); + int quant_max; int quant_min; TypeId type_id{kTypeUnknown}; + std::map opname_bit_; private: - std::unique_ptr mStrategy; - size_t bitNum; + std::unique_ptr quant_strategy_; + size_t bit_num_; + std::string config_file_; + PostQuantConfig config_param_; + std::vector> images_; // multi_input, [[mode_input_0], [model_input_1]...] + session::LiteSession *fp32_session_ = nullptr; + + STATUS DoMiexedQuant(FuncGraphPtr); + STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr primitive_c); }; } // namespace mindspore::lite::quant #endif