post training quant mixed precision

This commit is contained in:
xutianchun 2021-01-03 17:03:03 +08:00
parent 21addb331d
commit bb1c4e3c6a
19 changed files with 740 additions and 345 deletions

View File

@@ -21,7 +21,7 @@
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE

View File

@@ -24,7 +24,7 @@
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE

View File

@@ -21,8 +21,7 @@
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE

View File

@@ -19,7 +19,7 @@
#include <memory>
#include <string>
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"

View File

@@ -18,7 +18,7 @@
#include <memory>
#include <utility>
#ifdef PRIMITIVE_WRITEABLE
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE

View File

@@ -19,8 +19,7 @@
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE

View File

@@ -19,7 +19,7 @@
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE

View File

@@ -19,7 +19,7 @@
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#include "src/param_value_lite.h"
#endif
#ifndef PRIMITIVE_WRITEABLE

View File

@@ -387,7 +387,9 @@ void PrimitiveC::set_input_quant_params(const std::vector<std::vector<schema::Qu
}
void PrimitiveC::set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
MS_ASSERT(index < this->input_quant_param_.size());
if (index >= this->input_quant_param_.size()) {
this->input_quant_param_.resize(index + 1);
}
this->input_quant_param_.at(index) = input_quant_param;
}
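The new guard grows input_quant_param_ on demand instead of only asserting on an out-of-range index. A minimal standalone sketch of this grow-on-write pattern (hypothetical names, not part of the commit):

#include <cstddef>
#include <vector>

// Sketch: make sure the vector is large enough before an indexed write.
template <typename T>
void SetAt(std::vector<T> *vec, size_t index, const T &value) {
  if (index >= vec->size()) {
    vec->resize(index + 1);  // gap entries are value-initialized
  }
  vec->at(index) = value;
}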
@@ -493,7 +495,7 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
}
template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs,
const schema::QuantType &quantType) {
auto primc = std::make_shared<T>();
if (primc == nullptr) {

View File

@@ -204,7 +204,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize,
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
config->quantWeightChannel, config->bitNum);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New WeightQuantizer failed";

View File

@@ -32,8 +32,6 @@
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/anf_importer/import_from_mindir.h"
#include "proto/onnx.pb.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "include/version.h"
namespace mindspore {

View File

@@ -16,7 +16,6 @@
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/common/tensor_util.h"
namespace mindspore::lite {

View File

@@ -6,7 +6,6 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4)
file(GLOB QUANTIZER
${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc

View File

@@ -39,6 +39,7 @@
#include "tools/common/tensor_util.h"
#include "src/common/file_utils.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/weight_quantizer.h"
using std::string;
using std::vector;
@@ -380,182 +381,16 @@ STATUS Calibrator::AddQuantizedOp(const CNodePtr &node) {
return RET_OK;
}
void Calibrator::AddImage(const string &file, size_t index) {
if (index >= images_.size()) {
MS_LOG(ERROR) << "images_ size: " << images_.size() << " but index: " << index;
return;
}
auto exist = [](const string &file) {
struct stat buf {};
return stat(file.c_str(), &buf) == 0;
};
if (exist(file)) {
this->images_[index].push_back(file);
} else {
MS_LOG(WARNING) << "invalid image file path: " << file;
}
}
STATUS Calibrator::GenerateInputData(size_t input_index, size_t image_index,
mindspore::tensor::MSTensor *tensor) const {
MS_ASSERT(tensor != nullptr);
if (input_index >= images_.size()) {
MS_LOG(ERROR) << "images_ size: " << images_.size() << " but input_index: " << input_index;
return RET_ERROR;
}
if (image_index >= images_[input_index].size()) {
MS_LOG(ERROR) << "images_[input_index] size: " << images_[input_index].size()
<< " but image_index: " << image_index;
return RET_ERROR;
}
string path = images_[input_index][image_index];
MS_LOG(INFO) << "read image: " << path;
size_t size;
char *bin_buf = ReadFile(path.c_str(), &size);
if (bin_buf == nullptr) {
MS_LOG(ERROR) << "ReadFile return nullptr";
return RET_NULL_PTR;
}
auto data = tensor->MutableData();
if (data == nullptr) {
MS_LOG(ERROR) << "Get tensor MutableData return nullptr";
return RET_NULL_PTR;
}
if (size != tensor->Size()) {
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
<< " input tensor size: " << tensor->Size();
return RET_ERROR;
}
auto ret = memcpy_s(data, tensor->Size(), bin_buf, size);
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s error: " << ret;
delete[] bin_buf;
return RET_ERROR;
}
delete[] bin_buf;
return RET_OK;
return CopyInputDataToTensor(input_index, image_index, images_, tensor);
}
STATUS Calibrator::CollectImages() {
this->images_.resize(config_param_.image_paths.size());
auto input_i = 0;
bool multi_input = config_param_.image_paths.size() > 1;
for (const auto &image_path : config_param_.image_paths) {
DIR *root = opendir(image_path.c_str());
if (root == nullptr) {
MS_LOG(ERROR) << "invalid image path: " << image_path;
return RET_PARAM_INVALID;
}
struct dirent *image_dir = readdir(root);
size_t count = 0;
while (image_dir != nullptr) {
string file_name(image_dir->d_name);
if (file_name != "." && file_name != "..") {
const std::string file_path = image_path + "/" + file_name;
if (multi_input || config_param_.batch_count == 0) {
this->AddImage(file_path, input_i);
count++;
} else if (count < config_param_.batch_count) {
this->AddImage(file_path, input_i);
count++;
} else {
break;
}
}
image_dir = readdir(root);
}
std::sort(images_[input_i].begin(), images_[input_i].end());
if (config_param_.batch_count != 0 && config_param_.batch_count < images_[input_i].size()) {
images_[input_i].resize(config_param_.batch_count);
}
closedir(root);
input_i++;
}
return RET_OK;
return CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
}
STATUS Calibrator::ReadConfig() {
if (config_path_.empty() || config_path_.length() > PATH_MAX) {
MS_LOG(ERROR) << "invalid config path!";
return RET_PARAM_INVALID;
}
// check whether config file path is valid
char *resolved_path = new (std::nothrow) char[PATH_MAX]{0};
if (resolved_path == nullptr) {
MS_LOG(ERROR) << "New an object failed.";
return RET_ERROR;
}
#ifdef _WIN32
if (_fullpath(resolved_path, config_path_.c_str(), 1024) != nullptr) {
config_path_ = string(resolved_path);
}
#else
if (realpath(config_path_.c_str(), resolved_path) != nullptr) {
config_path_ = string(resolved_path);
}
#endif
std::ifstream fs(config_path_.c_str(), std::ifstream::in);
if (!fs.is_open()) {
MS_LOG(ERROR) << "config proto file %s open failed: " << config_path_;
delete[] resolved_path;
return RET_PARAM_INVALID;
}
std::string line;
while (std::getline(fs, line)) {
auto index = line.find('=');
if (index == std::string::npos) {
MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
delete[] resolved_path;
return RET_PARAM_INVALID;
}
auto key = line.substr(0, index);
auto value = line.substr(index + 1);
Trim(&key);
Trim(&value);
if (key == "image_path") {
auto &raw_image_paths = value;
auto ind = raw_image_paths.find(',');
while (ind != std::string::npos) {
auto image_path = raw_image_paths.substr(0, ind);
Trim(&image_path);
config_param_.image_paths.push_back(image_path);
raw_image_paths = raw_image_paths.substr(ind + 1);
Trim(&raw_image_paths);
ind = raw_image_paths.find(',');
}
config_param_.image_paths.push_back(raw_image_paths);
} else if (key == "batch_count") {
config_param_.batch_count = std::stoul(value);
} else if (key == "thread_num") {
config_param_.thread_num = std::stoul(value);
} else if (key == "method_x") {
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
config_param_.method_x = value;
}
} else if (key == "bias_correction") {
std::transform(value.begin(), value.end(), value.begin(), ::tolower);
if (value == "true") {
config_param_.bias_correction = true;
}
} else {
MS_LOG(WARNING) << "unsupported parameter";
}
}
for (const auto &path : config_param_.image_paths) {
MS_LOG(DEBUG) << "calibration data_path: " << path;
}
MS_LOG(DEBUG) << "batch_count: " << config_param_.batch_count << " "
<< "method_x: " << config_param_.method_x << " "
<< "thread_num: " << config_param_.thread_num << " "
<< "bias_correction: " << config_param_.bias_correction;
delete[] resolved_path;
fs.close();
return RET_OK;
}
STATUS Calibrator::ReadConfig() { return ParseConfigFile(config_path_, &config_param_); }
Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min)
: config_path_(std::move(path)), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {}
@@ -621,8 +456,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
return RET_OK;
}
STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c,
bool perchanel) const {
STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
std::shared_ptr<PrimitiveC> primitive_c, bool perchanel) const {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive_c != nullptr);
// perlayer
@@ -640,8 +475,21 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::share
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_NULL_PTR;
}
auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max, quant_min,
bit_num, perchanel);
auto bit_num_t = bit_num;
auto quant_max_t = quant_max;
auto quant_min_t = quant_min;
if (calibrator_->config_param_.mixed) {
auto opname_iter = opname_bit_.find(op_name);
if (opname_iter == opname_bit_.end()) {
MS_LOG(WARNING) << op_name << " not in the opname_bit_ map";
} else {
bit_num_t = opname_iter->second;
quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
}
}
auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max_t,
quant_min_t, bit_num_t, perchanel);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
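The quant_max_t/quant_min_t expressions above compute the symmetric signed range of a b-bit quantizer, [-2^(b-1), 2^(b-1)-1]. A tiny standalone illustration of the ranges the mixed search can choose from (not part of the commit):

#include <cstdio>

int main() {
  for (int bits = 2; bits <= 8; bits++) {
    int quant_max = (1 << (unsigned int)(bits - 1)) - 1;  // 8 bits -> 127
    int quant_min = -(1 << (unsigned int)(bits - 1));     // 8 bits -> -128
    printf("bits=%d range=[%d, %d]\n", bits, quant_min, quant_max);
  }
  return 0;
}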
@@ -921,7 +769,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
}
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
MS_LOG(DEBUG) << "this parameter do quant";
DoWeightQuant(input_node, primitive_c, false);
DoWeightQuant(op_name, input_node, primitive_c, false);
} else {
MS_LOG(DEBUG) << "this parameter no need to do quant";
}
@@ -943,7 +791,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
op_type == PrimitiveType_FullConnection) {
perchannel = true;
}
DoWeightQuant(weight, primitive_c, perchannel);
DoWeightQuant(op_name, weight, primitive_c, perchannel);
// do bias quant
if (cnode->inputs().size() == 4) {
auto bias = cnode->input(3);
@@ -982,18 +830,8 @@ STATUS PostTrainingQuantizer::UpdateDivergInverval() {
* 3. save quantized node
**/
STATUS PostTrainingQuantizer::PreProcess() {
if (this->calibrator_ == nullptr) {
MS_LOG(ERROR) << "calibrator is null!";
return RET_ERROR;
}
// 1. generate config param
STATUS status = calibrator_->ReadConfig();
if (status != RET_OK) {
MS_LOG(ERROR) << "read proto text failed!";
return status;
}
// 2. collect image files
status = calibrator_->CollectImages();
auto status = calibrator_->CollectImages();
if (status != RET_OK) {
MS_LOG(ERROR) << "collect images failed!";
return status;
@@ -1560,55 +1398,49 @@ STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->Com
STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(INFO) << "start to parse config file";
STATUS status = PreProcess();
if (this->calibrator_ == nullptr) {
MS_LOG(ERROR) << "calibrator is null!";
return RET_ERROR;
}
// 1. generate config param
STATUS status = calibrator_->ReadConfig();
if (status != RET_OK) {
MS_LOG(ERROR) << "read proto text failed!";
return status;
}
if (calibrator_->config_param_.mixed) {
// get opname_bit map
auto weight_quant_func_graph = CopyFuncGraph(func_graph);
if (weight_quant_func_graph == nullptr) {
MS_LOG(ERROR) << "CopyFuncGraph error";
return RET_ERROR;
}
WeightQuantizer weight_quantizer(weight_quant_func_graph, calibrator_->config_param_);
weight_quantizer.flags = flags;
status = weight_quantizer.DoQuantize(weight_quant_func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do mix weight quant error";
return RET_ERROR;
}
opname_bit_ = weight_quantizer.opname_bit_;
}
status = PreProcess();
if (status != RET_OK) {
MS_LOG(ERROR) << "do pre process failed!";
return status;
}
// anf -- fb
auto meta_graph = Export(func_graph, true, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
return RET_ERROR;
}
// transform
GraphDefTransform transform;
transform.SetGraphDef(meta_graph);
flags.quantType = schema::QuantType_QUANT_NONE;
status = transform.Transform(flags);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
return RET_ERROR;
}
MS_LOG(INFO) << "start create session";
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
size_t size = builder.GetSize();
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
if (content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer nullptr";
return RET_ERROR;
}
auto model = lite::Model::Import(content, size);
Context ctx;
ctx.thread_num_ = calibrator_->GetThreadNum();
fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx));
fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
if (fp32_session_ == nullptr) {
MS_LOG(ERROR) << "create session failed!";
return RET_ERROR;
}
auto ret = fp32_session_->CompileGraph(model);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "compile graph error";
return RET_ERROR;
}
MS_LOG(INFO) << "start to update divergence's max value";
status = DoInference();
if (status != RET_OK) {
@@ -1647,49 +1479,13 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
if (calibrator_->GetBiasCorrection()) {
// init in8 session
// anf -- fb
auto int8_meta_graph = Export(func_graph, true, true);
if (int8_meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr";
return RET_ERROR;
}
// transform
GraphDefTransform fb_transform;
fb_transform.SetGraphDef(int8_meta_graph);
MS_LOG(INFO) << "create quant session";
flags.quantType = schema::QuantType_PostTraining;
status = fb_transform.Transform(flags);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
return RET_ERROR;
}
MS_LOG(INFO) << "start create quantized session";
flatbuffers::FlatBufferBuilder int8_builder(1024);
auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph);
int8_builder.Finish(int8_offset);
schema::FinishMetaGraphBuffer(int8_builder, int8_offset);
size = int8_builder.GetSize();
auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer());
if (int8_content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer nullptr";
return RET_ERROR;
}
auto int8_model = lite::Model::Import(int8_content, size);
Context int8_ctx;
int8_ctx.thread_num_ = calibrator_->GetThreadNum();
int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx));
int8_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
if (int8_session_ == nullptr) {
MS_LOG(ERROR) << "create session failed!";
return RET_ERROR;
}
ret = int8_session_->CompileGraph(int8_model);
if (ret != lite::RET_OK) {
MS_LOG(ERROR) << "compile graph error";
return RET_ERROR;
}
MS_LOG(INFO) << "do bias correction";
status = BiasCorrection(func_graph);

View File

@@ -28,6 +28,8 @@
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/converter.h"
#include "include/ms_tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/weight_quantizer.h"
namespace mindspore::lite::quant {
class Calibrator;
@@ -38,19 +40,8 @@ struct MaxMin {
float max;
};
const char kMethodMaxMin[] = "MAX_MIN";
const char kMethodKL[] = "KL";
const char kMethodOutlier[] = "RemovalOutlier";
constexpr int kDefaultBinNumber = 2048;
struct ConfigParam {
std::vector<std::string> image_paths;
uint32_t batch_count{100};
std::string method_x{kMethodKL};
uint32_t thread_num{1};
bool bias_correction{false};
};
class PostTrainingQuantizer : public Quantizer {
public:
PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
@@ -64,14 +55,16 @@ class PostTrainingQuantizer : public Quantizer {
int quant_min{INT8_MIN};
private:
std::map<std::string, int> opname_bit_;
bool per_channel_{true};
TypeId target_type_{kNumberTypeInt8};
std::unique_ptr<Calibrator> calibrator_;
mindspore::lite::LiteSession *fp32_session_;
mindspore::lite::LiteSession *int8_session_;
session::LiteSession *fp32_session_{nullptr};
session::LiteSession *int8_session_{nullptr};
std::map<std::string, std::vector<float>> fp32_op_input_map; // concurrency
std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map; // concurrency
@@ -112,7 +105,8 @@ class PostTrainingQuantizer : public Quantizer {
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min,
const std::shared_ptr<PrimitiveC> &) const;
STATUS DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel) const;
STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c,
bool perchannel) const;
STATUS DoBiasQuant(const AnfNodePtr &bias, const std::shared_ptr<PrimitiveC> &primitive_c);
STATUS Int8Inference();
@@ -213,13 +207,13 @@ class Calibrator {
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
PostQuantConfig config_param_;
private:
std::vector<std::vector<std::string>> images_; // multi_input, each input has multiple data files
std::string config_path_;
ConfigParam config_param_;
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_;
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;
@@ -227,8 +221,6 @@ class Calibrator {
size_t bit_num_;
int quant_max_;
int quant_min_;
void AddImage(const std::string &file, size_t index);
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H

View File

@@ -17,6 +17,8 @@
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <cmath>
#include <string>
#include <map>
#include <fstream>
#include <algorithm>
#include <memory>
#include <vector>
@@ -26,6 +28,8 @@
#include "src/common/utils.h"
#include "abstract/abstract_value.h"
#include "securec/include/securec.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "mindspore/lite/include/version.h"
using std::string;
using std::vector;
@@ -83,10 +87,10 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
MS_ASSERT(node != nullptr);
if (!node->isa<CNode>()) {
if (!node->isa<mindspore::CNode>()) {
return false;
}
auto cnode = std::dynamic_pointer_cast<CNode>(node);
auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
auto type = NodePrimitiveType(cnode);
static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_Conv2D,
@@ -475,4 +479,307 @@ schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode) {
}
return (schema::PrimitiveType)primitive_c->Type();
}
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
if (post_quant_config == nullptr) {
MS_LOG(ERROR) << "post_quant_config is null.";
return RET_PARAM_INVALID;
}
if (config_file.empty() || config_file.length() > PATH_MAX) {
MS_LOG(ERROR) << "invalid config path!";
return RET_PARAM_INVALID;
}
// check whether config file path is valid
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
if (resolved_path == nullptr) {
MS_LOG(ERROR) << "New an object failed.";
return RET_ERROR;
}
#ifdef _WIN32
if (_fullpath(resolved_path.get(), config_file.c_str(), 1024) != nullptr) {
config_file = string(resolved_path.get());
}
#else
if (realpath(config_file.c_str(), resolved_path.get()) != nullptr) {
config_file = string(resolved_path.get());
}
#endif
std::ifstream fs(config_file.c_str(), std::ifstream::in);
if (!fs.is_open()) {
MS_LOG(ERROR) << "config file open failed: " << config_file;
return RET_PARAM_INVALID;
}
std::string line;
while (std::getline(fs, line)) {
Trim(&line);
if (line.empty()) {
continue;
}
auto index = line.find('=');
if (index == std::string::npos) {
MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
return RET_PARAM_INVALID;
}
auto key = line.substr(0, index);
auto value = line.substr(index + 1);
Trim(&key);
Trim(&value);
if (key == "image_path") {
auto &raw_image_paths = value;
auto ind = raw_image_paths.find(',');
while (ind != std::string::npos) {
auto image_path = raw_image_paths.substr(0, ind);
Trim(&image_path);
post_quant_config->image_paths.push_back(image_path);
raw_image_paths = raw_image_paths.substr(ind + 1);
Trim(&raw_image_paths);
ind = raw_image_paths.find(',');
}
post_quant_config->image_paths.push_back(raw_image_paths);
} else if (key == "batch_count") {
post_quant_config->batch_count = std::stoul(value);
} else if (key == "thread_num") {
post_quant_config->thread_num = std::stoul(value);
} else if (key == "method_x") {
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
post_quant_config->method_x = value;
}
} else if (key == "bias_correction") {
std::transform(value.begin(), value.end(), value.begin(), ::tolower);
if (value == "true") {
post_quant_config->bias_correction = true;
}
} else if (key == "mixed") {
std::transform(value.begin(), value.end(), value.begin(), ::tolower);
if (value == "true") {
post_quant_config->mixed = true;
}
} else if (key == "mean_error_threshold") {
post_quant_config->mean_error_threshold = std::stof(value);
} else {
MS_LOG(WARNING) << "unsupported parameter: " << key;
}
}
for (const auto &path : post_quant_config->image_paths) {
MS_LOG(DEBUG) << "calibration data_path: " << path;
}
MS_LOG(DEBUG) << "batch_count: " << post_quant_config->batch_count << "\n"
<< "method_x: " << post_quant_config->method_x << "\n"
<< "thread_num: " << post_quant_config->thread_num << "\n"
<< "bias_correction: " << post_quant_config->bias_correction << "\n"
<< "mixed: " << post_quant_config->mixed << "\n"
<< "mean_error_threshold: " << post_quant_config->mean_error_threshold;
post_quant_config->inited = true;
fs.close();
return RET_OK;
}
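For reference, a minimal calibration config that ParseConfigFile accepts, one key=value pair per line; the paths below are placeholders:

image_path=/data/calib/input0,/data/calib/input1
batch_count=100
method_x=KL
thread_num=4
bias_correction=true
mixed=true
mean_error_threshold=0.04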
session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags,
int thread_num) {
auto meta_graph = Export(func_graph, true, true);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph failed";
return nullptr;
}
// transform
GraphDefTransform fb_transform;
fb_transform.SetGraphDef(meta_graph);
auto status = fb_transform.Transform(flags);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed";
return nullptr;
}
meta_graph->version = Version();
flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
builder.Finish(offset);
schema::FinishMetaGraphBuffer(builder, offset);
auto size = builder.GetSize();
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
if (content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer return null";
return nullptr;
}
auto model = lite::Model::Import(content, size);
if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed";
return nullptr;
}
Context ctx;
ctx.thread_num_ = thread_num;
auto session = session::LiteSession::CreateSession(&ctx);
if (session == nullptr) {
MS_LOG(ERROR) << "create session failed.";
return nullptr;
}
status = session->CompileGraph(model);
if (status != RET_OK) {
MS_LOG(ERROR) << "CompileGraph error";
return nullptr;
}
model->Free();
return session;
}
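A hedged usage sketch for CreateSessionByFuncGraph; the surrounding flags value and thread count are assumptions:

// Assumed context: a valid FuncGraphPtr func_graph and converter::Flags flags.
flags.quantType = schema::QuantType_QUANT_NONE;  // build a float32 session
auto *session = CreateSessionByFuncGraph(func_graph, flags, /*thread_num=*/2);
if (session == nullptr) {
  MS_LOG(ERROR) << "CreateSessionByFuncGraph failed";
  return RET_ERROR;
}
// ... fill session->GetInputs(), call session->RunGraph(), read session->GetOutputs() ...
delete session;  // the caller owns the returned session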
STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited,
std::vector<std::vector<std::string>> *inputs) {
if (inputs == nullptr) {
MS_LOG(ERROR) << "inputs is null";
return RET_ERROR;
}
auto AddImage = [&inputs](const std::string &file, size_t index) {
if (index >= inputs->size()) {
MS_LOG(ERROR) << "images_ size: " << inputs->size() << " but input index: " << index;
return;
}
struct stat buf {};
if (stat(file.c_str(), &buf) == 0) {
inputs->at(index).push_back(file);
} else {
MS_LOG(WARNING) << "invalid image file path: " << file;
}
};
inputs->resize(input_dirs.size());
auto input_i = 0;
bool multi_input = input_dirs.size() > 1;
for (const auto &image_path : input_dirs) {
DIR *root = opendir(image_path.c_str());
if (root == nullptr) {
MS_LOG(ERROR) << "invalid image path: " << image_path;
return RET_PARAM_INVALID;
}
struct dirent *image_dir = readdir(root);
size_t count = 0;
while (image_dir != nullptr) {
string file_name(image_dir->d_name);
if (file_name != "." && file_name != "..") {
const std::string file_path = image_path + "/" + file_name;
if (multi_input || count_limited == 0) {
AddImage(file_path, input_i);
count++;
} else if (count < count_limited) {
AddImage(file_path, input_i);
count++;
} else {
break;
}
}
image_dir = readdir(root);
}
std::sort(inputs->at(input_i).begin(), inputs->at(input_i).end());
if (count_limited != 0 && count_limited < inputs->at(input_i).size()) {
inputs->at(input_i).resize(count_limited);
}
closedir(root);
input_i++;
}
return RET_OK;
}
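The collected layout is inputs[input_index][image_index], one sorted file list per model input. A short iteration sketch, assuming a PostQuantConfig named config is in scope:

std::vector<std::vector<std::string>> calib_inputs;
auto status = CollectCalibInputs(config.image_paths, config.batch_count, &calib_inputs);
if (status != RET_OK) {
  return status;
}
for (size_t input_i = 0; input_i < calib_inputs.size(); input_i++) {
  for (const auto &file : calib_inputs[input_i]) {
    MS_LOG(DEBUG) << "input " << input_i << " calib file: " << file;
  }
}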
STATUS CopyInputDataToTensor(size_t input_index, size_t image_index,
const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor) {
MS_ASSERT(tensor != nullptr);
if (input_index >= images.size()) {
MS_LOG(ERROR) << "images_ size: " << images.size() << " but input_index: " << input_index;
return RET_ERROR;
}
if (image_index >= images[input_index].size()) {
MS_LOG(ERROR) << "images_[input_index] size: " << images[input_index].size() << " but image_index: " << image_index;
return RET_ERROR;
}
string path = images[input_index][image_index];
MS_LOG(INFO) << "read image: " << path;
size_t size;
char *bin_buf = ReadFile(path.c_str(), &size);
if (bin_buf == nullptr) {
MS_LOG(ERROR) << "ReadFile return nullptr";
return RET_NULL_PTR;
}
auto data = tensor->MutableData();
if (data == nullptr) {
MS_LOG(ERROR) << "Get tensor MutableData return nullptr";
return RET_NULL_PTR;
}
if (size != tensor->Size()) {
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
<< " input tensor size: " << tensor->Size();
return RET_ERROR;
}
auto ret = memcpy_s(data, tensor->Size(), bin_buf, size);
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s error: " << ret;
delete[] bin_buf;
return RET_ERROR;
}
delete[] bin_buf;
return RET_OK;
}
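Note the strict size check above: each calibration file must be a raw binary dump whose byte count equals tensor->Size(). A standalone sketch that writes such a file for an assumed 1x224x224x3 float32 input:

#include <cstdio>
#include <vector>

int main() {
  std::vector<float> data(1 * 224 * 224 * 3, 0.5f);  // preprocessed image values
  FILE *fp = fopen("input_0.bin", "wb");
  if (fp == nullptr) {
    return 1;
  }
  fwrite(data.data(), sizeof(float), data.size(), fp);  // raw bytes, no header
  fclose(fp);
  return 0;
}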
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) {
Cloner cloner({func_graph}, true, true, true, std::make_shared<TraceCopy>(), nullptr);
auto new_func_graph = cloner[func_graph];
std::map<std::string, CNodePtr> old_cnode_map;
for (const auto &cnode : func_graph->GetOrderedCnodes()) {
old_cnode_map[cnode->fullname_with_scope()] = cnode;
}
for (auto &cnode : new_func_graph->GetOrderedCnodes()) {
auto cnode_name = cnode->fullname_with_scope();
auto old_cnode_iter = old_cnode_map.find(cnode_name);
if (old_cnode_iter == old_cnode_map.end()) {
MS_LOG(ERROR) << "can not find node: " << cnode_name;
return nullptr;
}
auto old_cnode = old_cnode_iter->second;
auto inputs = cnode->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
auto input_node = inputs[i];
if (input_node->isa<Parameter>()) {
auto param_node = input_node->cast<ParameterPtr>();
if (param_node->has_default()) {
ParamValueLitePtr old_param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
auto new_param_value = std::make_shared<ParamValueLite>();
auto copyed_data = malloc(old_param_value->tensor_size());
if (copyed_data == nullptr) {
MS_LOG(ERROR) << "malloc data error, size: " << old_param_value->tensor_size();
return nullptr;
}
memcpy(copyed_data, old_param_value->tensor_addr(), old_param_value->tensor_size());
new_param_value->set_tensor_size(old_param_value->tensor_size());
new_param_value->set_tensor_addr(copyed_data);
new_param_value->set_tensor_shape(old_param_value->tensor_shape());
new_param_value->set_format(old_param_value->format());
new_param_value->set_tensor_type(old_param_value->tensor_type());
param_node->set_default_param(new_param_value);
}
auto old_abstract_base = param_node->abstract();
if (!utils::isa<abstract::AbstractTensorPtr>(old_abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
return nullptr;
}
auto old_abstract = utils::cast<abstract::AbstractTensorPtr>(old_abstract_base);
auto new_abstract = std::make_shared<abstract::AbstractTensor>(old_abstract->element()->GetTypeTrack(),
old_abstract->GetShapeTrack());
param_node->set_abstract(new_abstract);
}
} // end inputs loop
} // end cnodes loop
return new_func_graph;
}
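CopyFuncGraph deep-copies every parameter's tensor buffer because the mixed-bit search quantizes weights in place; trial quantization must not clobber the float32 data of the graph being converted. A sketch of the intended use, mirroring the mixed branch of PostTrainingQuantizer::DoQuantize:

// Quantize a throwaway copy first; the original graph keeps its fp32 weights.
auto trial_graph = CopyFuncGraph(func_graph);
if (trial_graph == nullptr) {
  MS_LOG(ERROR) << "CopyFuncGraph failed";
  return RET_ERROR;
}
// ... run WeightQuantizer on trial_graph and harvest its opname_bit_ map ...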
} // namespace mindspore::lite::quant

View File

@@ -17,6 +17,8 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
#include <dirent.h>
#include <sys/stat.h>
#include <memory>
#include <string>
#include <cmath>
@@ -35,11 +37,29 @@
#include "ir/primitive.h"
#include "abstract/dshape.h"
#include "tools/converter/quantizer/bitpacking.h"
#include "src/lite_session.h"
#include "tools/converter/graphdef_transform.h"
#include "src/common/file_utils.h"
namespace mindspore::lite::quant {
static constexpr size_t UINT8_QUANTIZATION = 8;
static constexpr size_t WEIGHT_INDEX = 1;
const char kMethodMaxMin[] = "MAX_MIN";
const char kMethodKL[] = "KL";
const char kMethodOutlier[] = "RemovalOutlier";
struct PostQuantConfig {
std::vector<std::string> image_paths;
uint32_t batch_count{100};
std::string method_x{kMethodKL};
uint32_t thread_num{1};
bool bias_correction{false};
bool mixed{false};
float mean_error_threshold{0.04};
bool inited{false};
};
/**
* 1. when op's weight size > mWeightSize just skip
* 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization
@@ -320,6 +340,21 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
return RET_OK;
}
// utils
schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode);
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config);
session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags,
int thread_num);
STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited,
std::vector<std::vector<std::string>> *inputs);
STATUS CopyInputDataToTensor(size_t input_index, size_t image_index,
const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor);
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
} // namespace mindspore::lite::quant
#endif

View File

@@ -18,6 +18,7 @@
#include <list>
#include <string>
#include <vector>
#include <unordered_map>
#include "src/common/common.h"
#include "ir/dtype/type_id.h"
@@ -36,6 +37,7 @@ bool WeightQuantizer::IsPosNum(const std::string &str) {
}
return true;
}
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
MS_ASSERT(config != nullptr);
if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) {
@@ -57,28 +59,57 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
}
return RET_OK;
}
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
config_param_ = config;
}
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize,
const std::string &convWeightChannelThreshold, const std::string &bitNum)
: Quantizer(graph) {
this->config_file_ = config_file;
auto quantSize = static_cast<size_t>(std::stoull(weightSize));
this->bitNum = static_cast<size_t>(std::stoull(bitNum));
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
mStrategy = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1;
quant_min = -(1 << (unsigned int)(this->bitNum - 1));
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min = -(1 << (unsigned int)(this->bit_num_ - 1));
// parse type_id
if (this->bitNum > 0 && this->bitNum <= 8) {
if (this->bit_num_ > 0 && this->bit_num_ <= 8) {
type_id = kNumberTypeInt8;
} else if (this->bitNum <= 16) {
} else if (this->bit_num_ <= 16) {
type_id = kNumberTypeInt16;
} else {
MS_LOG(ERROR) << "invalid input bits";
}
}
WeightQuantizer::~WeightQuantizer() { delete fp32_session_; }
STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
std::shared_ptr<PrimitiveC> primitive_c) {
// set dtype
param_value->set_tensor_type(type_id);
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
abstract_tensor->element()->set_type(TypeIdToType(type_id));
primitive_c->set_quant_type(schema::QuantType_WeightQuant);
return RET_OK;
}
STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
for (auto &cnode : nodes) {
if (!mStrategy->CanConvOpQuantized(cnode)) {
if (!quant_strategy_->CanConvOpQuantized(cnode)) {
continue;
}
@@ -108,36 +139,28 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
}
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
} else if (type_id == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
// set dtype
param_value->set_tensor_type(type_id);
auto abstractBase = param_node->abstract();
if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
abstractTensor->element()->set_type(TypeIdToType(type_id));
primitive_c->set_quant_type(schema::QuantType_WeightQuant);
}
return RET_OK;
}
STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
for (auto &node : nodes) {
if (!mStrategy->CanMulOpQuantized(node)) {
if (!quant_strategy_->CanMulOpQuantized(node)) {
continue;
}
auto already_quant = false;
@@ -186,38 +209,271 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
} else if (type_id == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
param_value->set_tensor_type(type_id);
// set dtype
auto abstractBase = param_node->abstract();
if (abstractBase == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
return RET_ERROR;
}
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
abstractTensor->element()->set_type(TypeIdToType(type_id));
primitive_c->set_quant_type(schema::QuantType_WeightQuant);
}
return RET_OK;
}
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
MS_ASSERT(funcGraph != nullptr);
constexpr float relative_tolerance = 1e-5;
constexpr float abs_tolerance = 1e-4;
template <typename T>
float CompareOutputData(const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &expected_tensor,
const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &compare_tensor) {
auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); };
float total_mean_error = 0.0f;
int tensor_cnt = expected_tensor.size();
if (tensor_cnt <= 0) {
MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt;
return RET_ERROR;
}
for (const auto &exp_tensor_pair : expected_tensor) {
float mean_error = 0.0f;
int error_cnt = 0;
auto exp_tensor_name = exp_tensor_pair.first;
auto exp_tensor = exp_tensor_pair.second;
auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name);
if (cmp_tensor_find_iter == compare_tensor.end()) {
MS_LOG(ERROR) << "can not find: " << exp_tensor_name;
return RET_ERROR;
}
auto cmp_tensor = cmp_tensor_find_iter->second;
auto exp_tensor_shape = exp_tensor->shape();
auto cmp_tensor_shape = cmp_tensor->shape();
if (exp_tensor_shape != cmp_tensor_shape) {
MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum()
<< " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum();
return RET_ERROR;
}
auto exp_data = static_cast<T *>(exp_tensor->MutableData());
auto cmp_data = static_cast<T *>(cmp_tensor->MutableData());
auto elem_cnt = exp_tensor->ElementsNum();
for (int i = 0; i < elem_cnt; i++) {
if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) {
MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i;
return RET_ERROR;
}
auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]);
auto abs_error = std::fabs(exp_data[i] - cmp_data[i]);
if (abs_error > tolerance) {
if (fabs(exp_data[i]) == 0) {
if (abs_error > 1e-5) {
mean_error += abs_error;
error_cnt++;
} else {
// it is ok, very close to 0
continue;
}
} else {
mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN);
error_cnt++;
}
} else {
// it is ok, no error
continue;
}
} // end one tensor data loop
total_mean_error += mean_error / elem_cnt;
} // end tensor loop
return total_mean_error / tensor_cnt;
}
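CompareOutputData averages a per-tensor mean relative error over all output tensors; an element counts as an error only when |exp - cmp| exceeds abs_tolerance + relative_tolerance * |exp|. For example, with exp_data[i] = 2.0 and cmp_data[i] = 2.1, the tolerance is 1e-4 + 1e-5 * 2.0 = 1.2e-4, so the element adds 0.1 / 2.0 = 0.05 to the tensor's error sum, which is then divided by the tensor's element count.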
STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
// 0.1 Create Fp32 Session
flags.quantType = schema::QuantType_QUANT_NONE;
fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
if (fp32_session_ == nullptr) {
MS_LOG(ERROR) << "CreateSessoin fail";
return RET_ERROR;
}
auto fp32_inputs = fp32_session_->GetInputs();
// 0.2 Parse input calib files
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
if (status != RET_OK) {
MS_LOG(ERROR) << "CollectCalibInputs fail";
return RET_ERROR;
}
auto cnodes = func_graph->GetOrderedCnodes();
for (auto iter = cnodes.end(); iter != cnodes.begin();) {
auto cnode = *(--iter);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is null.";
return RET_ERROR;
}
auto op_name = cnode->fullname_with_scope();
MS_LOG(DEBUG) << "process node: " << op_name
<< " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) {
MS_LOG(WARNING) << op_name << " the second input is not parameter";
continue;
}
auto param_node = input_node->cast<ParameterPtr>();
if (!param_node->has_default()) {
MS_LOG(WARNING) << op_name << " the second input parameter has no default value";
continue;
}
auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if (param_value == nullptr) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
continue;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << op_name << " the second input type is not float";
continue;
}
// back up the original data so it can be restored after a failed trial
auto *raw_data = static_cast<float *>(param_value->tensor_addr());
auto elem_count = param_value->tensor_shape_size();
auto origin_data = malloc(sizeof(float) * elem_count);
if (origin_data == nullptr) {
MS_LOG(ERROR) << "malloc failed, size: " << sizeof(float) * elem_count;
return RET_ERROR;
}
auto ret = memcpy_s(origin_data, sizeof(float) * elem_count, raw_data, param_value->tensor_size());
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy fail: "
<< " dst size: " << sizeof(float) * elem_count << " src size: " << param_value->tensor_size();
return RET_ERROR;
}
// 1. try quant
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
type_id = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
if (type_id == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else if (type_id == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else {
MS_LOG(ERROR) << "unexpected type_id: " << type_id;
return RET_ERROR;
}
if (status != RET_OK) {
MS_LOG(ERROR) << "quant filter fail.";
return RET_ERROR;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
// 2. evaluate the quant
// 2.1 create quant session, get input, output tensor
flags.quantType = schema::QuantType_WeightQuant;
auto quant_session =
std::unique_ptr<session::LiteSession>(CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num));
if (quant_session == nullptr) {
MS_LOG(ERROR) << "create session error: " << status;
return RET_ERROR;
}
auto quant_inputs = quant_session->GetInputs();
auto mean_error = 0.0f;
if (fp32_inputs.size() != images_.size()) {
MS_LOG(ERROR) << "model's input tensor cnt: " << fp32_inputs.size() << " != " << images_.size();
return RET_ERROR;
}
auto image_cnt = images_.at(0).size();
for (size_t i = 0; i < image_cnt; i++) {
// set multi-input data
for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) {
status = CopyInputDataToTensor(input_index, i, images_, fp32_inputs[input_index]);
if (status != RET_OK) {
MS_LOG(ERROR) << "generate input data from images failed!";
return RET_ERROR;
}
status = CopyInputDataToTensor(input_index, i, images_, quant_inputs[input_index]);
if (status != RET_OK) {
MS_LOG(ERROR) << "generate input data from images failed!";
return RET_ERROR;
}
}
std::future<STATUS> fp32_inference = std::async(
std::launch::async, [](session::LiteSession *fp32_session) -> STATUS { return fp32_session->RunGraph(); },
fp32_session_);
status = quant_session->RunGraph();
if (status != RET_OK) {
MS_LOG(ERROR) << "quant session run error";
return RET_ERROR;
}
status = fp32_inference.get();
if (status != RET_OK) {
MS_LOG(ERROR) << "fp32 session run error";
return RET_ERROR;
}
// 3. compare between quant and fp32
auto fp32_outputs = fp32_session_->GetOutputs();
auto quant_outputs = quant_session->GetOutputs();
mean_error += CompareOutputData<float>(fp32_outputs, quant_outputs);
} // end_for: calib data loop
mean_error = mean_error / image_cnt;
if (mean_error <= config_param_.mean_error_threshold) {
MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error;
opname_bit_[op_name] = bit_num_t;
break;
} else if (bit_num_t != 8) {
// recover
param_value->set_tensor_size(sizeof(float) * elem_count);
ret = memcpy_s(raw_data, param_value->tensor_size(), origin_data, sizeof(float) * elem_count);
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy fail: "
<< " src size: " << sizeof(float) * elem_count << " dst size: " << param_value->tensor_size();
return RET_ERROR;
}
} else {
MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error;
opname_bit_[op_name] = bit_num_t;
}
} // end bit loop
free(origin_data);
} // if: conv and matmul
} // end loop: all cnode
return RET_OK;
}
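In outline, the mixed-bit search above performs the following for each quantizable weight (a condensed sketch of the control flow, not literal code):

// for each conv/mul-like cnode with a float32 weight:
//   back up the weight's float32 buffer
//   for bit_num_t in 2..8:
//     quantize the weight to bit_num_t bits (symmetric, per-channel)
//     build a weight-quant session and run every calibration input
//     mean_error = average CompareOutputData(fp32 outputs, quant outputs)
//     if mean_error <= mean_error_threshold:
//       opname_bit_[op_name] = bit_num_t; break   // lowest acceptable width
//     else if bit_num_t < 8:
//       restore the float32 buffer and try the next width
//     else:
//       opname_bit_[op_name] = 8                  // fall back to 8 bits
//   free the backup buffer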
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_ASSERT(func_graph != nullptr);
STATUS ret;
auto cnodes = funcGraph->GetOrderedCnodes();
auto cnodes = func_graph->GetOrderedCnodes();
if (!config_file_.empty()) {
ret = ParseConfigFile(config_file_, &config_param_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ReadConfig error.";
return RET_ERROR;
}
}
if (config_param_.mixed) {
MS_LOG(INFO) << "Do mixed bit quantization";
return DoMiexedQuant(func_graph);
}
ret = DoConvQuantize(cnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;

View File

@@ -17,9 +17,12 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
#include <future>
#include <memory>
#include <map>
#include <list>
#include <string>
#include <vector>
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "ir/func_graph.h"
@@ -27,27 +30,37 @@
#include "include/model.h"
#include "base/base.h"
#include "abstract/dshape.h"
#include "src/lite_session.h"
namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer {
public:
WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize, const std::string &covWeightChannelThreshold,
const std::string &bitNum);
WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize,
const std::string &covWeightChannelThreshold, const std::string &bitNum);
WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config);
~WeightQuantizer();
~WeightQuantizer() = default;
STATUS DoQuantize(FuncGraphPtr funcGraph) override;
STATUS DoQuantize(FuncGraphPtr func_graph) override;
STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
static STATUS WeightQuantInputCheck(const converter::Flags *config);
static bool IsPosNum(const std::string &str);
int quant_max;
int quant_min;
TypeId type_id{kTypeUnknown};
std::map<std::string, int> opname_bit_;
private:
std::unique_ptr<QuantStrategy> mStrategy;
size_t bitNum;
std::unique_ptr<QuantStrategy> quant_strategy_;
size_t bit_num_;
std::string config_file_;
PostQuantConfig config_param_;
std::vector<std::vector<std::string>> images_; // multi_input, [[model_input_0], [model_input_1]...]
session::LiteSession *fp32_session_ = nullptr;
STATUS DoMiexedQuant(FuncGraphPtr);
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
};
} // namespace mindspore::lite::quant
#endif