forked from mindspore-Ecosystem/mindspore
commit
1908d6b8c4
|
@ -21,7 +21,7 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include <float.h>
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -21,8 +21,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include <float.h>
|
||||
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -19,8 +19,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include <float.h>
|
||||
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include <float.h>
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#include "src/common/log_adapter.h"
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
#include <float.h>
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "src/param_value_lite.h"
|
||||
#endif
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
|
|
|
@ -389,7 +389,9 @@ void PrimitiveC::set_input_quant_params(const std::vector<std::vector<schema::Qu
|
|||
}
|
||||
|
||||
void PrimitiveC::set_input_quant_param(const size_t &index, const std::vector<schema::QuantParamT> &input_quant_param) {
|
||||
MS_ASSERT(index < this->input_quant_param_.size());
|
||||
if (index >= this->input_quant_param_.size()) {
|
||||
this->input_quant_param_.resize(index + 1);
|
||||
}
|
||||
this->input_quant_param_.at(index) = input_quant_param;
|
||||
}
|
||||
|
||||
|
@ -495,7 +497,7 @@ std::shared_ptr<PrimitiveC> GetTupleGetItemPrim() {
|
|||
}
|
||||
|
||||
template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
|
||||
std::shared_ptr<PrimitiveC> NewPrimitiveC(const Primitive &prim, const std::vector<AnfNodePtr> &inputs,
|
||||
std::shared_ptr<PrimitiveC> NewPrimitiveC(const mindspore::Primitive &prim, const std::vector<AnfNodePtr> &inputs,
|
||||
const schema::QuantType &quantType) {
|
||||
auto primc = std::make_shared<T>();
|
||||
if (primc == nullptr) {
|
||||
|
|
|
@ -204,7 +204,7 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
|
|||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return nullptr;
|
||||
}
|
||||
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantWeightSize,
|
||||
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->configFile, config->quantWeightSize,
|
||||
config->quantWeightChannel, config->bitNum);
|
||||
if (mQuantizer == nullptr) {
|
||||
MS_LOG(ERROR) << "New WeightQuantizer failed";
|
||||
|
|
|
@ -32,8 +32,6 @@
|
|||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "tools/anf_importer/import_from_mindir.h"
|
||||
#include "proto/onnx.pb.h"
|
||||
#include "tools/converter/quantizer/post_training_quantizer.h"
|
||||
#include "tools/converter/quantizer/quant_cast.h"
|
||||
#include "include/version.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "tools/converter/legacy_optimizer/graph/tensor_name_pass.h"
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
|
|
|
@ -6,7 +6,6 @@ include_directories(${3RD_DIR}/opencv/build/include/opencv4)
|
|||
file(GLOB QUANTIZER
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
|
||||
|
|
|
@ -39,6 +39,7 @@
|
|||
#include "tools/common/tensor_util.h"
|
||||
#include "src/common/file_utils.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
@ -380,182 +381,16 @@ STATUS Calibrator::AddQuantizedOp(const CNodePtr &node) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void Calibrator::AddImage(const string &file, size_t index) {
|
||||
if (index >= images_.size()) {
|
||||
MS_LOG(ERROR) << "images_ size: " << images_.size() << " but index: " << index;
|
||||
return;
|
||||
}
|
||||
auto exist = [](const string &file) {
|
||||
struct stat buf {};
|
||||
return stat(file.c_str(), &buf) == 0;
|
||||
};
|
||||
if (exist(file)) {
|
||||
this->images_[index].push_back(file);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "invalid image file path: " << file;
|
||||
}
|
||||
}
|
||||
|
||||
STATUS Calibrator::GenerateInputData(size_t input_index, size_t image_index,
|
||||
mindspore::tensor::MSTensor *tensor) const {
|
||||
MS_ASSERT(tensor != nullptr);
|
||||
if (input_index >= images_.size()) {
|
||||
MS_LOG(ERROR) << "images_ size: " << images_.size() << " but input_index: " << input_index;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (image_index >= images_[input_index].size()) {
|
||||
MS_LOG(ERROR) << "images_[input_index] size: " << images_[input_index].size()
|
||||
<< " but image_index: " << image_index;
|
||||
return RET_ERROR;
|
||||
}
|
||||
string path = images_[input_index][image_index];
|
||||
MS_LOG(INFO) << "read image: " << path;
|
||||
size_t size;
|
||||
char *bin_buf = ReadFile(path.c_str(), &size);
|
||||
if (bin_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "ReadFile return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto data = tensor->MutableData();
|
||||
if (data == nullptr) {
|
||||
MS_LOG(ERROR) << "Get tensor MutableData return nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (size != tensor->Size()) {
|
||||
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
|
||||
<< " input tensor size: " << tensor->Size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = memcpy_s(data, tensor->Size(), bin_buf, size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "memcpy_s error: " << ret;
|
||||
delete[] bin_buf;
|
||||
return RET_ERROR;
|
||||
}
|
||||
delete[] bin_buf;
|
||||
return RET_OK;
|
||||
return CopyInputDataToTensor(input_index, image_index, images_, tensor);
|
||||
}
|
||||
|
||||
STATUS Calibrator::CollectImages() {
|
||||
this->images_.resize(config_param_.image_paths.size());
|
||||
auto input_i = 0;
|
||||
bool multi_input = config_param_.image_paths.size() > 1;
|
||||
for (const auto &image_path : config_param_.image_paths) {
|
||||
DIR *root = opendir(image_path.c_str());
|
||||
if (root == nullptr) {
|
||||
MS_LOG(ERROR) << "invalid image path: " << image_path;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
struct dirent *image_dir = readdir(root);
|
||||
size_t count = 0;
|
||||
while (image_dir != nullptr) {
|
||||
string file_name(image_dir->d_name);
|
||||
if (file_name != "." && file_name != "..") {
|
||||
const std::string file_path = image_path + "/" + file_name;
|
||||
if (multi_input || config_param_.batch_count == 0) {
|
||||
this->AddImage(file_path, input_i);
|
||||
count++;
|
||||
} else if (count < config_param_.batch_count) {
|
||||
this->AddImage(file_path, input_i);
|
||||
count++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
image_dir = readdir(root);
|
||||
}
|
||||
std::sort(images_[input_i].begin(), images_[input_i].end());
|
||||
if (config_param_.batch_count != 0 && config_param_.batch_count < images_[input_i].size()) {
|
||||
images_[input_i].resize(config_param_.batch_count);
|
||||
}
|
||||
closedir(root);
|
||||
input_i++;
|
||||
}
|
||||
return RET_OK;
|
||||
return CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
|
||||
}
|
||||
|
||||
STATUS Calibrator::ReadConfig() {
|
||||
if (config_path_.empty() || config_path_.length() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "invalid config path!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
// check whether config file path is valid
|
||||
char *resolved_path = new (std::nothrow) char[PATH_MAX]{0};
|
||||
if (resolved_path == nullptr) {
|
||||
MS_LOG(ERROR) << "New an object failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
if (_fullpath(resolved_path, config_path_.c_str(), 1024) != nullptr) {
|
||||
config_path_ = string(resolved_path);
|
||||
}
|
||||
#else
|
||||
if (realpath(config_path_.c_str(), resolved_path) != nullptr) {
|
||||
config_path_ = string(resolved_path);
|
||||
}
|
||||
#endif
|
||||
std::ifstream fs(config_path_.c_str(), std::ifstream::in);
|
||||
if (!fs.is_open()) {
|
||||
MS_LOG(ERROR) << "config proto file %s open failed: " << config_path_;
|
||||
delete[] resolved_path;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
std::string line;
|
||||
while (std::getline(fs, line)) {
|
||||
auto index = line.find('=');
|
||||
if (index == std::string::npos) {
|
||||
MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
|
||||
delete[] resolved_path;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto key = line.substr(0, index);
|
||||
auto value = line.substr(index + 1);
|
||||
Trim(&key);
|
||||
Trim(&value);
|
||||
if (key == "image_path") {
|
||||
auto &raw_image_paths = value;
|
||||
auto ind = raw_image_paths.find(',');
|
||||
while (ind != std::string::npos) {
|
||||
auto image_path = raw_image_paths.substr(0, ind);
|
||||
Trim(&image_path);
|
||||
config_param_.image_paths.push_back(image_path);
|
||||
raw_image_paths = raw_image_paths.substr(ind + 1);
|
||||
Trim(&raw_image_paths);
|
||||
ind = raw_image_paths.find(',');
|
||||
}
|
||||
config_param_.image_paths.push_back(raw_image_paths);
|
||||
} else if (key == "batch_count") {
|
||||
config_param_.batch_count = std::stoul(value);
|
||||
} else if (key == "thread_num") {
|
||||
config_param_.thread_num = std::stoul(value);
|
||||
} else if (key == "method_x") {
|
||||
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
|
||||
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
|
||||
} else {
|
||||
config_param_.method_x = value;
|
||||
}
|
||||
} else if (key == "bias_correction") {
|
||||
std::for_each(value.begin(), value.end(), ::tolower);
|
||||
if (value == "true") {
|
||||
config_param_.bias_correction = true;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(WARNING) << "unsupported parameter";
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &path : config_param_.image_paths) {
|
||||
MS_LOG(DEBUG) << "calibration data_path: " << path;
|
||||
}
|
||||
MS_LOG(DEBUG) << "batch_count: " << config_param_.batch_count << " "
|
||||
<< "method_x: " << config_param_.method_x << " "
|
||||
<< "thread_num: " << config_param_.thread_num << " "
|
||||
<< "bias_correction: " << config_param_.bias_correction;
|
||||
|
||||
delete[] resolved_path;
|
||||
fs.close();
|
||||
return RET_OK;
|
||||
}
|
||||
STATUS Calibrator::ReadConfig() { return ParseConfigFile(config_path_, &config_param_); }
|
||||
|
||||
Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min)
|
||||
: config_path_(std::move(path)), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {}
|
||||
|
@ -621,8 +456,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c,
|
||||
bool perchanel) const {
|
||||
STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
|
||||
std::shared_ptr<PrimitiveC> primitive_c, bool perchanel) const {
|
||||
MS_ASSERT(weight != nullptr);
|
||||
MS_ASSERT(lite_primitive != nullptr);
|
||||
// perlayer
|
||||
|
@ -640,8 +475,21 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const AnfNodePtr &weight, std::share
|
|||
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max, quant_min,
|
||||
bit_num, perchanel);
|
||||
auto bit_num_t = bit_num;
|
||||
auto quant_max_t = quant_max;
|
||||
auto quant_min_t = quant_min;
|
||||
if (calibrator_->config_param_.mixed) {
|
||||
auto opname_iter = opname_bit_.find(op_name);
|
||||
if (opname_iter == opname_bit_.end()) {
|
||||
MS_LOG(WARNING) << op_name << " not in the opname_bit_ map";
|
||||
} else {
|
||||
bit_num_t = opname_iter->second;
|
||||
quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
|
||||
quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
|
||||
}
|
||||
}
|
||||
auto status = QuantFilter<int8_t>(paramValue, std::move(primitive_c), QuantType_PostTraining, quant_max_t,
|
||||
quant_min_t, bit_num_t, perchanel);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
return status;
|
||||
|
@ -921,7 +769,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|||
}
|
||||
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
|
||||
MS_LOG(DEBUG) << "this parameter do quant";
|
||||
DoWeightQuant(input_node, primitive_c, false);
|
||||
DoWeightQuant(op_name, input_node, primitive_c, false);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "this parameter no need to do quant";
|
||||
}
|
||||
|
@ -943,7 +791,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|||
op_type == PrimitiveType_FullConnection) {
|
||||
perchannel = true;
|
||||
}
|
||||
DoWeightQuant(weight, primitive_c, perchannel);
|
||||
DoWeightQuant(op_name, weight, primitive_c, perchannel);
|
||||
// do bias quant
|
||||
if (cnode->inputs().size() == 4) {
|
||||
auto bias = cnode->input(3);
|
||||
|
@ -982,18 +830,8 @@ STATUS PostTrainingQuantizer::UpdateDivergInverval() {
|
|||
* 3. save quantied node
|
||||
**/
|
||||
STATUS PostTrainingQuantizer::PreProcess() {
|
||||
if (this->calibrator_ == nullptr) {
|
||||
MS_LOG(ERROR) << "calibrator is null!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// 1. generate config param
|
||||
STATUS status = calibrator_->ReadConfig();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "read proto text failed!";
|
||||
return status;
|
||||
}
|
||||
// 2. collect image files
|
||||
status = calibrator_->CollectImages();
|
||||
auto status = calibrator_->CollectImages();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "collect images failed!";
|
||||
return status;
|
||||
|
@ -1560,55 +1398,49 @@ STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->Com
|
|||
|
||||
STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||
MS_LOG(INFO) << "start to parse config file";
|
||||
STATUS status = PreProcess();
|
||||
if (this->calibrator_ == nullptr) {
|
||||
MS_LOG(ERROR) << "calibrator is null!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
// 1. generate config param
|
||||
STATUS status = calibrator_->ReadConfig();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "read proto text failed!";
|
||||
return status;
|
||||
}
|
||||
|
||||
if (calibrator_->config_param_.mixed) {
|
||||
// get opname_bit map
|
||||
auto weight_quant_func_graph = CopyFuncGraph(func_graph);
|
||||
if (weight_quant_func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "CopyFuncGraph error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
WeightQuantizer weight_quantizer(weight_quant_func_graph, calibrator_->config_param_);
|
||||
weight_quantizer.flags = flags;
|
||||
status = weight_quantizer.DoQuantize(weight_quant_func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do mix weight quant error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
opname_bit_ = weight_quantizer.opname_bit_;
|
||||
}
|
||||
|
||||
status = PreProcess();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "do pre process failed!";
|
||||
return status;
|
||||
}
|
||||
|
||||
// anf -- fb
|
||||
auto meta_graph = Export(func_graph, true, true);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// transform
|
||||
GraphDefTransform transform;
|
||||
transform.SetGraphDef(meta_graph);
|
||||
flags.quantType = schema::QuantType_QUANT_NONE;
|
||||
status = transform.Transform(flags);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FBTransform model failed " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "start create session";
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
size_t size = builder.GetSize();
|
||||
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|
||||
if (content == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto model = lite::Model::Import(content, size);
|
||||
|
||||
Context ctx;
|
||||
ctx.thread_num_ = calibrator_->GetThreadNum();
|
||||
|
||||
fp32_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&ctx));
|
||||
fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
|
||||
if (fp32_session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "create session failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto ret = fp32_session_->CompileGraph(model);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "compile graph error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "start to update divergence's max value";
|
||||
status = DoInference();
|
||||
if (status != RET_OK) {
|
||||
|
@ -1647,49 +1479,13 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
|
||||
if (calibrator_->GetBiasCorrection()) {
|
||||
// init in8 session
|
||||
// anf -- fb
|
||||
auto int8_meta_graph = Export(func_graph, true, true);
|
||||
if (int8_meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Export to int8_meta_graph return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
// transform
|
||||
GraphDefTransform fb_transform;
|
||||
fb_transform.SetGraphDef(int8_meta_graph);
|
||||
MS_LOG(INFO) << "create quant session";
|
||||
flags.quantType = schema::QuantType_PostTraining;
|
||||
status = fb_transform.Transform(flags);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FBTransform model failed " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "start create quantized session";
|
||||
flatbuffers::FlatBufferBuilder int8_builder(1024);
|
||||
auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph);
|
||||
int8_builder.Finish(int8_offset);
|
||||
schema::FinishMetaGraphBuffer(int8_builder, int8_offset);
|
||||
size = int8_builder.GetSize();
|
||||
auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer());
|
||||
if (int8_content == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto int8_model = lite::Model::Import(int8_content, size);
|
||||
|
||||
Context int8_ctx;
|
||||
int8_ctx.thread_num_ = calibrator_->GetThreadNum();
|
||||
int8_ctx.device_list_[0].device_info_.cpu_device_info_.cpu_bind_mode_ = HIGHER_CPU;
|
||||
|
||||
int8_session_ = dynamic_cast<mindspore::lite::LiteSession *>(session::LiteSession::CreateSession(&int8_ctx));
|
||||
int8_session_ = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
|
||||
if (int8_session_ == nullptr) {
|
||||
MS_LOG(ERROR) << "create session failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
ret = int8_session_->CompileGraph(int8_model);
|
||||
if (ret != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "compile graph error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "do bias correction";
|
||||
status = BiasCorrection(func_graph);
|
||||
|
|
|
@ -28,6 +28,8 @@
|
|||
#include "tools/converter/quantizer/quantizer.h"
|
||||
#include "tools/converter/converter.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/converter/quantizer/weight_quantizer.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
class Calibrator;
|
||||
|
@ -38,19 +40,8 @@ struct MaxMin {
|
|||
float max;
|
||||
};
|
||||
|
||||
const char kMethodMaxMin[] = "MAX_MIN";
|
||||
const char kMethodKL[] = "KL";
|
||||
const char kMethodOutlier[] = "RemovalOutlier";
|
||||
constexpr int kDefaultBinNumber = 2048;
|
||||
|
||||
struct ConfigParam {
|
||||
std::vector<std::string> image_paths;
|
||||
uint32_t batch_count{100};
|
||||
std::string method_x{kMethodKL};
|
||||
uint32_t thread_num{1};
|
||||
bool bias_correction{false};
|
||||
};
|
||||
|
||||
class PostTrainingQuantizer : public Quantizer {
|
||||
public:
|
||||
PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8,
|
||||
|
@ -64,14 +55,16 @@ class PostTrainingQuantizer : public Quantizer {
|
|||
int quant_min{INT8_MIN};
|
||||
|
||||
private:
|
||||
std::map<std::string, int> opname_bit_;
|
||||
|
||||
bool per_channel_{true};
|
||||
|
||||
TypeId target_type_{kNumberTypeInt8};
|
||||
|
||||
std::unique_ptr<Calibrator> calibrator_;
|
||||
|
||||
mindspore::lite::LiteSession *fp32_session_;
|
||||
mindspore::lite::LiteSession *int8_session_;
|
||||
session::LiteSession *fp32_session_{nullptr};
|
||||
session::LiteSession *int8_session_{nullptr};
|
||||
|
||||
std::map<std::string, std::vector<float>> fp32_op_input_map; // concurency
|
||||
std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map; // concurency
|
||||
|
@ -112,7 +105,8 @@ class PostTrainingQuantizer : public Quantizer {
|
|||
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min,
|
||||
const std::shared_ptr<PrimitiveC> &) const;
|
||||
|
||||
STATUS DoWeightQuant(const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel) const;
|
||||
STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, std::shared_ptr<PrimitiveC> primitive_c,
|
||||
bool perchannel) const;
|
||||
|
||||
STATUS DoBiasQuant(const AnfNodePtr &bias, const std::shared_ptr<PrimitiveC> &primitive_c);
|
||||
STATUS Int8Inference();
|
||||
|
@ -213,13 +207,13 @@ class Calibrator {
|
|||
|
||||
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
|
||||
|
||||
PostQuantConfig config_param_;
|
||||
|
||||
private:
|
||||
std::vector<std::vector<std::string>> images_; // multi_input, echo input has multi input data
|
||||
|
||||
std::string config_path_;
|
||||
|
||||
ConfigParam config_param_;
|
||||
|
||||
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_;
|
||||
|
||||
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;
|
||||
|
@ -227,8 +221,6 @@ class Calibrator {
|
|||
size_t bit_num_;
|
||||
int quant_max_;
|
||||
int quant_min_;
|
||||
|
||||
void AddImage(const std::string &file, size_t index);
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
@ -26,6 +28,8 @@
|
|||
#include "src/common/utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "mindspore/lite/include/version.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
@ -83,10 +87,10 @@ bool QuantStrategy::CanConvOpQuantized(const CNodePtr &node) const {
|
|||
|
||||
bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
|
||||
MS_ASSERT(node != nullptr);
|
||||
if (!node->isa<CNode>()) {
|
||||
if (!node->isa<mindspore::CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = std::dynamic_pointer_cast<CNode>(node);
|
||||
auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
|
||||
auto type = NodePrimitiveType(cnode);
|
||||
static const std::vector<schema::PrimitiveType> int8OpList = {
|
||||
schema::PrimitiveType_Conv2D,
|
||||
|
@ -475,4 +479,307 @@ schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode) {
|
|||
}
|
||||
return (schema::PrimitiveType)primitive_c->Type();
|
||||
}
|
||||
|
||||
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
|
||||
if (post_quant_config == nullptr) {
|
||||
MS_LOG(ERROR) << "post_quant_config is null.";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
if (config_file.empty() || config_file.length() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "invalid config path!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
// check whether config file path is valid
|
||||
auto resolved_path = std::make_unique<char[]>(PATH_MAX);
|
||||
if (resolved_path == nullptr) {
|
||||
MS_LOG(ERROR) << "New an object failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
#ifdef _WIN32
|
||||
if (_fullpath(resolved_path.get(), config_file.c_str(), 1024) != nullptr) {
|
||||
config_file = string(resolved_path.get());
|
||||
}
|
||||
#else
|
||||
if (realpath(config_file.c_str(), resolved_path.get()) != nullptr) {
|
||||
config_file = string(resolved_path.get());
|
||||
}
|
||||
#endif
|
||||
std::ifstream fs(config_file.c_str(), std::ifstream::in);
|
||||
if (!fs.is_open()) {
|
||||
MS_LOG(ERROR) << "config file open failed: " << config_file;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
std::string line;
|
||||
while (std::getline(fs, line)) {
|
||||
Trim(&line);
|
||||
if (line.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto index = line.find('=');
|
||||
if (index == std::string::npos) {
|
||||
MS_LOG(ERROR) << "the config file is invalid, can not find '=', please check";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto key = line.substr(0, index);
|
||||
auto value = line.substr(index + 1);
|
||||
Trim(&key);
|
||||
Trim(&value);
|
||||
if (key == "image_path") {
|
||||
auto &raw_image_paths = value;
|
||||
auto ind = raw_image_paths.find(',');
|
||||
while (ind != std::string::npos) {
|
||||
auto image_path = raw_image_paths.substr(0, ind);
|
||||
Trim(&image_path);
|
||||
post_quant_config->image_paths.push_back(image_path);
|
||||
raw_image_paths = raw_image_paths.substr(ind + 1);
|
||||
Trim(&raw_image_paths);
|
||||
ind = raw_image_paths.find(',');
|
||||
}
|
||||
post_quant_config->image_paths.push_back(raw_image_paths);
|
||||
} else if (key == "batch_count") {
|
||||
post_quant_config->batch_count = std::stoul(value);
|
||||
} else if (key == "thread_num") {
|
||||
post_quant_config->thread_num = std::stoul(value);
|
||||
} else if (key == "method_x") {
|
||||
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
|
||||
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
|
||||
} else {
|
||||
post_quant_config->method_x = value;
|
||||
}
|
||||
} else if (key == "bias_correction") {
|
||||
std::for_each(value.begin(), value.end(), ::tolower);
|
||||
if (value == "true") {
|
||||
post_quant_config->bias_correction = true;
|
||||
}
|
||||
} else if (key == "mixed") {
|
||||
std::for_each(value.begin(), value.end(), ::tolower);
|
||||
if (value == "true") {
|
||||
post_quant_config->mixed = true;
|
||||
}
|
||||
} else if (key == "mean_error_threshold") {
|
||||
post_quant_config->mean_error_threshold = std::stof(value);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "unsupported parameter: " << key;
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto &path : post_quant_config->image_paths) {
|
||||
MS_LOG(DEBUG) << "calibration data_path: " << path;
|
||||
}
|
||||
MS_LOG(DEBUG) << "batch_count: " << post_quant_config->batch_count << "\n"
|
||||
<< "method_x: " << post_quant_config->method_x << "\n"
|
||||
<< "thread_num: " << post_quant_config->thread_num << "\n"
|
||||
<< "bias_correction: " << post_quant_config->bias_correction << "\n"
|
||||
<< "mixed: " << post_quant_config->mixed << "\n"
|
||||
<< "mean_error_threshold: " << post_quant_config->mean_error_threshold;
|
||||
post_quant_config->inited = true;
|
||||
fs.close();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags,
|
||||
int thread_num) {
|
||||
auto meta_graph = Export(func_graph, true, true);
|
||||
if (meta_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Export to meta_graph failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// transform
|
||||
GraphDefTransform fb_transform;
|
||||
fb_transform.SetGraphDef(meta_graph);
|
||||
auto status = fb_transform.Transform(flags);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FBTransform model failed";
|
||||
return nullptr;
|
||||
}
|
||||
meta_graph->version = Version();
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
|
||||
builder.Finish(offset);
|
||||
schema::FinishMetaGraphBuffer(builder, offset);
|
||||
auto size = builder.GetSize();
|
||||
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|
||||
if (content == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer return null";
|
||||
return nullptr;
|
||||
}
|
||||
auto model = lite::Model::Import(content, size);
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "Import model failed";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Context ctx;
|
||||
ctx.thread_num_ = thread_num;
|
||||
|
||||
auto session = session::LiteSession::CreateSession(&ctx);
|
||||
if (session == nullptr) {
|
||||
MS_LOG(ERROR) << "create session failed.";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
status = session->CompileGraph(model);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "CompileGraph error";
|
||||
return nullptr;
|
||||
}
|
||||
model->Free();
|
||||
return session;
|
||||
}
|
||||
|
||||
STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited,
                          std::vector<std::vector<std::string>> *inputs) {
  // Scan each calibration-input directory and collect the contained file paths into
  // *inputs, one vector per model input. At most count_limited files are kept per
  // directory; for a single-input model the first file is always taken even when
  // count_limited is 0.
  if (inputs == nullptr) {
    MS_LOG(ERROR) << "inputs is null";
    return RET_ERROR;
  }
  auto AddImage = [&inputs](const std::string &file, size_t index) {
    if (index >= inputs->size()) {
      MS_LOG(ERROR) << "images_ size: " << inputs->size() << " but input index: " << index;
      return;
    }
    struct stat buf {};
    // Only record paths that actually exist on disk.
    if (stat(file.c_str(), &buf) == 0) {
      inputs->at(index).push_back(file);
    } else {
      MS_LOG(WARNING) << "invalid image file path: " << file;
    }
  };

  inputs->resize(input_dirs.size());
  size_t input_i = 0;  // fix: was `auto` (int) but is compared against size_t sizes
  const bool multi_input = input_dirs.size() > 1;
  for (const auto &image_path : input_dirs) {
    DIR *root = opendir(image_path.c_str());
    if (root == nullptr) {
      MS_LOG(ERROR) << "invalid image path: " << image_path;
      return RET_PARAM_INVALID;
    }
    size_t count = 0;
    for (struct dirent *image_dir = readdir(root); image_dir != nullptr; image_dir = readdir(root)) {
      const std::string file_name(image_dir->d_name);
      if (file_name == "." || file_name == "..") {
        continue;
      }
      // Multi-input models take every file; otherwise the first file is always taken
      // and the rest only while under the count limit. (The two identical add-branches
      // of the original are merged into one condition.)
      if (multi_input || count == 0 || count < count_limited) {
        AddImage(image_path + "/" + file_name, input_i);
        count++;
      } else {
        break;
      }
    }
    // Keep the file order deterministic regardless of readdir() ordering.
    std::sort(inputs->at(input_i).begin(), inputs->at(input_i).end());
    if (count_limited != 0 && count_limited < inputs->at(input_i).size()) {
      inputs->at(input_i).resize(count_limited);
    }
    closedir(root);
    input_i++;
  }
  return RET_OK;
}
|
||||
|
||||
STATUS CopyInputDataToTensor(size_t input_index, size_t image_index,
                             const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor) {
  // Read the calibration file images[input_index][image_index] from disk and copy its
  // raw bytes into tensor. The file size must exactly match the tensor's byte size.
  MS_ASSERT(tensor != nullptr);
  if (input_index >= images.size()) {
    MS_LOG(ERROR) << "images_ size: " << images.size() << " but input_index: " << input_index;
    return RET_ERROR;
  }
  if (image_index >= images[input_index].size()) {
    MS_LOG(ERROR) << "images_[input_index] size: " << images[input_index].size() << " but image_index: " << image_index;
    return RET_ERROR;
  }
  std::string path = images[input_index][image_index];
  MS_LOG(INFO) << "read image: " << path;
  size_t size;
  char *bin_buf = ReadFile(path.c_str(), &size);
  if (bin_buf == nullptr) {
    MS_LOG(ERROR) << "ReadFile return nullptr";
    return RET_NULL_PTR;
  }
  auto data = tensor->MutableData();
  if (data == nullptr) {
    MS_LOG(ERROR) << "Get tensor MutableData return nullptr";
    delete[] bin_buf;  // fix: bin_buf was leaked on this error path
    return RET_NULL_PTR;
  }
  if (size != tensor->Size()) {
    MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
                  << " input tensor size: " << tensor->Size();
    delete[] bin_buf;  // fix: bin_buf was leaked on this error path
    return RET_ERROR;
  }
  auto ret = memcpy_s(data, tensor->Size(), bin_buf, size);
  // The buffer is no longer needed whether or not the copy succeeded.
  delete[] bin_buf;
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy_s error: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}
|
||||
|
||||
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) {
  // Deep-copy func_graph. Parameter default tensors (ParamValueLite) are cloned with
  // their own data buffers so mutating the copy's weights cannot affect the original,
  // and every parameter gets a fresh AbstractTensor. Returns nullptr on failure.
  Cloner cloner({func_graph}, true, true, true, std::make_shared<TraceCopy>(), nullptr);
  auto new_func_graph = cloner[func_graph];

  // Index the original cnodes by full name so copied nodes can be matched back.
  std::map<std::string, CNodePtr> old_cnode_map;
  for (const auto &cnode : func_graph->GetOrderedCnodes()) {
    old_cnode_map[cnode->fullname_with_scope()] = cnode;
  }

  for (auto &cnode : new_func_graph->GetOrderedCnodes()) {
    auto cnode_name = cnode->fullname_with_scope();
    auto old_cnode_iter = old_cnode_map.find(cnode_name);
    if (old_cnode_iter == old_cnode_map.end()) {
      MS_LOG(ERROR) << "can not find node: " << cnode_name;
      return nullptr;
    }
    auto inputs = cnode->inputs();
    for (size_t i = 0; i < inputs.size(); i++) {
      auto input_node = inputs[i];
      if (!input_node->isa<Parameter>()) {
        continue;
      }
      auto param_node = input_node->cast<ParameterPtr>();
      if (param_node->has_default()) {
        ParamValueLitePtr old_param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
        auto new_param_value = std::make_shared<ParamValueLite>();

        auto copied_data = malloc(old_param_value->tensor_size());
        if (copied_data == nullptr) {
          MS_LOG(ERROR) << "malloc data error, size: " << old_param_value->tensor_size();
          return nullptr;
        }
        // fix: use bounds-checked memcpy_s (consistent with the rest of this file) and
        // release the freshly-allocated buffer if the copy fails.
        auto ret = memcpy_s(copied_data, old_param_value->tensor_size(), old_param_value->tensor_addr(),
                            old_param_value->tensor_size());
        if (ret != EOK) {
          MS_LOG(ERROR) << "memcpy_s error: " << ret;
          free(copied_data);
          return nullptr;
        }

        new_param_value->set_tensor_size(old_param_value->tensor_size());
        new_param_value->set_tensor_addr(copied_data);
        new_param_value->set_tensor_shape(old_param_value->tensor_shape());
        new_param_value->set_format(old_param_value->format());
        new_param_value->set_tensor_type(old_param_value->tensor_type());

        param_node->set_default_param(new_param_value);
      }

      // Give the copied parameter its own abstract so later dtype changes stay local.
      auto old_abstract_base = param_node->abstract();
      if (!utils::isa<abstract::AbstractTensorPtr>(old_abstract_base)) {
        MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
        return nullptr;
      }
      auto old_abstract = utils::cast<abstract::AbstractTensorPtr>(old_abstract_base);
      auto new_abstract = std::make_shared<abstract::AbstractTensor>(old_abstract->element()->GetTypeTrack(),
                                                                     old_abstract->GetShapeTrack());
      param_node->set_abstract(new_abstract);
    }  // end inputs loop
  }    // end cnodes loop
  return new_func_graph;
}
|
||||
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZER_UTIL_H
|
||||
|
||||
#include <dirent.h>
|
||||
#include <sys/stat.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <cmath>
|
||||
|
@ -35,11 +37,29 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
#include "src/lite_session.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
static constexpr size_t UINT8_QUANTIZATION = 8;
|
||||
static constexpr size_t WEIGHT_INDEX = 1;
|
||||
|
||||
const char kMethodMaxMin[] = "MAX_MIN";
|
||||
const char kMethodKL[] = "KL";
|
||||
const char kMethodOutlier[] = "RemovalOutlier";
|
||||
|
||||
// Configuration for post-training quantization, filled in by ParseConfigFile.
struct PostQuantConfig {
  std::vector<std::string> image_paths;  // calibration-input directories, one per model input (see CollectCalibInputs)
  uint32_t batch_count{100};             // max calibration files used per input (count limit for CollectCalibInputs)
  std::string method_x{kMethodKL};       // quant method name: kMethodMaxMin / kMethodKL / kMethodOutlier
  uint32_t thread_num{1};                // thread count for the internally created inference sessions
  bool bias_correction{false};           // presumably enables bias correction — not referenced in this file; verify against caller
  bool mixed{false};                     // enable mixed-bit weight quantization (DoMiexedQuant)
  float mean_error_threshold{0.04};      // max tolerated mean output error when searching mixed bit widths
  bool inited{false};                    // set once a config file has been successfully parsed
};
|
||||
|
||||
/**
|
||||
* 1. when op's weight size > mWeightSize just skip
|
||||
* 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization
|
||||
|
@ -320,6 +340,21 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
// utils
|
||||
|
||||
schema::PrimitiveType NodePrimitiveType(const CNodePtr &cnode);
|
||||
|
||||
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config);
|
||||
|
||||
session::LiteSession *CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags,
|
||||
int thread_num);
|
||||
|
||||
STATUS CollectCalibInputs(const std::vector<std::string> &input_dirs, size_t count_limited,
|
||||
std::vector<std::vector<std::string>> *inputs);
|
||||
|
||||
STATUS CopyInputDataToTensor(size_t input_index, size_t image_index,
|
||||
const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor);
|
||||
|
||||
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <list>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "src/common/common.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
|
@ -36,6 +37,7 @@ bool WeightQuantizer::IsPosNum(const std::string &str) {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
|
||||
MS_ASSERT(config != nullptr);
|
||||
if (!WeightQuantizer::IsPosNum(config->quantWeightChannel)) {
|
||||
|
@ -57,28 +59,57 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) {
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
|
||||
|
||||
// Construct from an already-parsed post-quant config; no weight-size or
// conv-channel thresholds apply, so the strategy is created with zeros.
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config) : Quantizer(graph) {
  config_param_ = config;
  quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
}
|
||||
|
||||
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const string &weightSize,
|
||||
const std::string &convWeightChannelThreshold, const std::string &bitNum)
|
||||
: Quantizer(graph) {
|
||||
this->config_file_ = config_file;
|
||||
auto quantSize = static_cast<size_t>(std::stoull(weightSize));
|
||||
this->bitNum = static_cast<size_t>(std::stoull(bitNum));
|
||||
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
|
||||
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
|
||||
mStrategy = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
|
||||
quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1;
|
||||
quant_min = -(1 << (unsigned int)(this->bitNum - 1));
|
||||
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
|
||||
quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
|
||||
quant_min = -(1 << (unsigned int)(this->bit_num_ - 1));
|
||||
// parse type_id
|
||||
if (this->bitNum > 0 && this->bitNum <= 8) {
|
||||
if (this->bit_num_ > 0 && this->bit_num_ <= 8) {
|
||||
type_id = kNumberTypeInt8;
|
||||
} else if (this->bitNum <= 16) {
|
||||
} else if (this->bit_num_ <= 16) {
|
||||
type_id = kNumberTypeInt16;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "invalid input bits";
|
||||
}
|
||||
}
|
||||
|
||||
// Release the fp32 inference session created by DoMiexedQuant (deleting nullptr is a no-op).
WeightQuantizer::~WeightQuantizer() {
  delete fp32_session_;
  fp32_session_ = nullptr;
}
|
||||
|
||||
STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
                                    std::shared_ptr<PrimitiveC> primitive_c) {
  // After the weight data has been quantized, update the associated metadata:
  // the tensor dtype, the parameter abstract's element type, and the primitive's
  // quant type. Returns RET_ERROR if the parameter has no tensor abstract.
  // set dtype
  param_value->set_tensor_type(type_id);
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    // fix: typo "anstract" in the original error message
    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  abstract_tensor->element()->set_type(TypeIdToType(type_id));
  primitive_c->set_quant_type(schema::QuantType_WeightQuant);

  return RET_OK;
}
|
||||
|
||||
STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
|
||||
for (auto &cnode : nodes) {
|
||||
if (!mStrategy->CanConvOpQuantized(cnode)) {
|
||||
if (!quant_strategy_->CanConvOpQuantized(cnode)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -108,36 +139,28 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
|
|||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
} else if (type_id == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
// set dtype
|
||||
param_value->set_tensor_type(type_id);
|
||||
auto abstractBase = param_node->abstract();
|
||||
if (abstractBase == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
||||
abstractTensor->element()->set_type(TypeIdToType(type_id));
|
||||
primitive_c->set_quant_type(schema::QuantType_WeightQuant);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
|
||||
for (auto &node : nodes) {
|
||||
if (!mStrategy->CanMulOpQuantized(node)) {
|
||||
if (!quant_strategy_->CanMulOpQuantized(node)) {
|
||||
continue;
|
||||
}
|
||||
auto already_quant = false;
|
||||
|
@ -186,38 +209,271 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
|
|||
|
||||
auto status = RET_ERROR;
|
||||
if (type_id == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
} else if (type_id == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true);
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
param_value->set_tensor_type(type_id);
|
||||
// set dtype
|
||||
auto abstractBase = param_node->abstract();
|
||||
if (abstractBase == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
||||
abstractTensor->element()->set_type(TypeIdToType(type_id));
|
||||
primitive_c->set_quant_type(schema::QuantType_WeightQuant);
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
|
||||
MS_ASSERT(funcGraph != nullptr);
|
||||
constexpr float relative_tolerance = 1e-5;
|
||||
constexpr float abs_tolerance = 1e-4;
|
||||
|
||||
template <typename T>
|
||||
float CompareOutputData(const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &expected_tensor,
|
||||
const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &compare_tensor) {
|
||||
auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); };
|
||||
|
||||
float total_mean_error = 0.0f;
|
||||
int tensor_cnt = expected_tensor.size();
|
||||
|
||||
if (tensor_cnt <= 0) {
|
||||
MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
for (const auto &exp_tensor_pair : expected_tensor) {
|
||||
float mean_error = 0.0f;
|
||||
int error_cnt = 0;
|
||||
|
||||
auto exp_tensor_name = exp_tensor_pair.first;
|
||||
auto exp_tensor = exp_tensor_pair.second;
|
||||
auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name);
|
||||
if (cmp_tensor_find_iter == compare_tensor.end()) {
|
||||
MS_LOG(ERROR) << "can not find: " << exp_tensor_name;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto cmp_tensor = cmp_tensor_find_iter->second;
|
||||
|
||||
auto exp_tensor_shape = exp_tensor->shape();
|
||||
auto cmp_tensor_shape = cmp_tensor->shape();
|
||||
if (exp_tensor_shape != cmp_tensor_shape) {
|
||||
MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum()
|
||||
<< " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto exp_data = static_cast<T *>(exp_tensor->MutableData());
|
||||
auto cmp_data = static_cast<T *>(cmp_tensor->MutableData());
|
||||
auto elem_cnt = exp_tensor->ElementsNum();
|
||||
for (int i = 0; i < elem_cnt; i++) {
|
||||
if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) {
|
||||
MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]);
|
||||
auto abs_error = std::fabs(exp_data[i] - cmp_data[i]);
|
||||
if (abs_error > tolerance) {
|
||||
if (fabs(exp_data[i] == 0)) {
|
||||
if (abs_error > 1e-5) {
|
||||
mean_error += abs_error;
|
||||
error_cnt++;
|
||||
} else {
|
||||
// it is ok, very close to 0
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN);
|
||||
error_cnt++;
|
||||
}
|
||||
} else {
|
||||
// it is ok, no error
|
||||
continue;
|
||||
}
|
||||
} // end one tensor data loop
|
||||
total_mean_error += mean_error / elem_cnt;
|
||||
} // end tensor loop
|
||||
return total_mean_error / tensor_cnt;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
  // Mixed-bit weight quantization: for every quantizable conv/matmul weight, try
  // 2..8-bit quantization and keep the smallest bit width whose mean output error
  // (quantized model vs the fp32 model, over the calibration images) stays below
  // config_param_.mean_error_threshold. Chosen widths are recorded in opname_bit_.

  // 0.1 Create Fp32 Session
  flags.quantType = schema::QuantType_QUANT_NONE;
  fp32_session_ = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
  if (fp32_session_ == nullptr) {
    MS_LOG(ERROR) << "CreateSession fail";  // fix: typo "CreateSessoin"
    return RET_ERROR;
  }
  auto fp32_inputs = fp32_session_->GetInputs();
  // 0.2 Parse input calib files
  auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "CollectCalibInputs fail";
    return RET_ERROR;
  }

  // Walk cnodes in reverse topological order.
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto iter = cnodes.end(); iter != cnodes.begin();) {
    auto cnode = *(--iter);
    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
    if (primitive_c == nullptr) {
      MS_LOG(ERROR) << "primitive_c is null.";
      return RET_ERROR;
    }
    auto op_name = cnode->fullname_with_scope();
    MS_LOG(DEBUG) << "process node: " << op_name
                  << " type: " << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive_c->Type());
    if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
      // The weight is expected to be the second input, a fp32 parameter with data.
      auto input_node = cnode->input(2);
      if (!input_node->isa<Parameter>()) {
        MS_LOG(WARNING) << op_name << " the second input is not parameter";
        continue;
      }
      auto param_node = input_node->cast<ParameterPtr>();
      if (!param_node->has_default()) {
        MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
        continue;
      }
      auto param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
      if (param_value == nullptr) {
        MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
        continue;
      }
      if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
        MS_LOG(WARNING) << op_name << " the second input type is not float";
        continue;
      }
      // copy origin data in case to recover
      // fix: use a std::vector instead of malloc/free — the original never checked
      // malloc for nullptr and leaked the buffer on every early-return path below.
      auto *raw_data = static_cast<float *>(param_value->tensor_addr());
      auto elem_count = param_value->tensor_shape_size();
      std::vector<float> origin_data(elem_count);
      auto ret = memcpy_s(origin_data.data(), sizeof(float) * elem_count, raw_data, param_value->tensor_size());
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy fail: "
                      << " dst size: " << sizeof(float) * elem_count << " src size: " << param_value->tensor_size();
        return RET_ERROR;
      }
      // 1. try quant
      for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
        type_id = TypeId::kNumberTypeInt8;
        int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
        int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));

        if (type_id == TypeId::kNumberTypeInt8) {
          status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
                                       quant_min_t, bit_num_t, true);
        } else if (type_id == TypeId::kNumberTypeInt16) {
          status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
                                        quant_min_t, bit_num_t, true);
        } else {
          MS_LOG(ERROR) << "unexpected type_id: " << type_id;
          return RET_ERROR;
        }
        if (status != RET_OK) {
          MS_LOG(ERROR) << "quant filter fail.";
          return RET_ERROR;
        }
        status = SetAbstract(param_value, param_node, primitive_c);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "SetAbstract failed : " << status;
          return RET_ERROR;
        }
        // 2. evaluate the quant
        // 2.1 create quant session, get input, output tensor
        flags.quantType = schema::QuantType_WeightQuant;
        auto quant_session =
          std::unique_ptr<session::LiteSession>(CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num));
        if (quant_session == nullptr) {
          MS_LOG(ERROR) << "create session error: " << status;
          return RET_ERROR;
        }
        auto quant_inputs = quant_session->GetInputs();

        auto mean_error = 0.0f;
        if (fp32_inputs.size() != images_.size()) {
          MS_LOG(ERROR) << "model's input tensor cnt: " << fp32_inputs.size() << " != " << images_.size();
          return RET_ERROR;
        }
        auto image_cnt = images_.at(0).size();
        for (size_t i = 0; i < image_cnt; i++) {
          // set multi-input data
          for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) {
            status = CopyInputDataToTensor(input_index, i, images_, fp32_inputs[input_index]);
            if (status != RET_OK) {
              MS_LOG(ERROR) << "generate input data from images failed!";
              return RET_ERROR;
            }
            status = CopyInputDataToTensor(input_index, i, images_, quant_inputs[input_index]);
            if (status != RET_OK) {
              MS_LOG(ERROR) << "generate input data from images failed!";
              return RET_ERROR;
            }
          }
          // Run the fp32 graph on a worker thread concurrently with the quant graph.
          std::future<STATUS> fp32_inference = std::async(
            std::launch::async, [](session::LiteSession *fp32_session) -> STATUS { return fp32_session->RunGraph(); },
            fp32_session_);

          status = quant_session->RunGraph();
          if (status != RET_OK) {
            MS_LOG(ERROR) << "quant session run error";
            return RET_ERROR;
          }
          status = fp32_inference.get();
          if (status != RET_OK) {
            MS_LOG(ERROR) << "fp32 session run error";
            return RET_ERROR;
          }
          // 3. compare between quant and fp32
          auto fp32_outputs = fp32_session_->GetOutputs();
          auto quant_outputs = quant_session->GetOutputs();
          mean_error += CompareOutputData<float>(fp32_outputs, quant_outputs);
        }  // end_for: calib data loop
        mean_error = mean_error / image_cnt;

        if (mean_error <= config_param_.mean_error_threshold) {
          MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error;
          opname_bit_[op_name] = bit_num_t;
          break;
        } else if (bit_num_t != 8) {
          // Error too large: restore the original fp32 weights and retry with one more bit.
          // NOTE(review): this writes back through raw_data, assuming QuantFilter
          // quantizes in place without reallocating the tensor buffer — confirm.
          param_value->set_tensor_size(sizeof(float) * elem_count);
          ret = memcpy_s(raw_data, param_value->tensor_size(), origin_data.data(), sizeof(float) * elem_count);
          if (ret != EOK) {
            MS_LOG(ERROR) << "memcpy fail: "
                          << " src size: " << sizeof(float) * elem_count << " dst size: " << param_value->tensor_size();
            return RET_ERROR;
          }
        } else {
          // 8 bits is the ceiling: accept it even though the error exceeds the threshold.
          MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error;
          opname_bit_[op_name] = bit_num_t;
        }
      }  // end bit loop
    }    // if: conv and matmul
  }      // end loop: all cnode
  return RET_OK;
}
|
||||
|
||||
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
STATUS ret;
|
||||
auto cnodes = funcGraph->GetOrderedCnodes();
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
|
||||
if (!config_file_.empty()) {
|
||||
ret = ParseConfigFile(config_file_, &config_param_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ReadConfig error.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
if (config_param_.mixed) {
|
||||
MS_LOG(INFO) << "Do mixed bit quantization";
|
||||
return DoMiexedQuant(func_graph);
|
||||
}
|
||||
|
||||
ret = DoConvQuantize(cnodes);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
|
||||
|
|
|
@ -17,9 +17,12 @@
|
|||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H
|
||||
|
||||
#include <future>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <list>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tools/converter/quantizer/quantizer.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
@ -27,27 +30,37 @@
|
|||
#include "include/model.h"
|
||||
#include "base/base.h"
|
||||
#include "abstract/dshape.h"
|
||||
#include "src/lite_session.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
class WeightQuantizer : public Quantizer {
|
||||
public:
|
||||
WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize, const std::string &covWeightChannelThreshold,
|
||||
const std::string &bitNum);
|
||||
WeightQuantizer(FuncGraphPtr graph, const std::string &config_file, const std::string &weightSize,
|
||||
const std::string &covWeightChannelThreshold, const std::string &bitNum);
|
||||
WeightQuantizer(FuncGraphPtr graph, const PostQuantConfig &config);
|
||||
~WeightQuantizer();
|
||||
|
||||
~WeightQuantizer() = default;
|
||||
|
||||
STATUS DoQuantize(FuncGraphPtr funcGraph) override;
|
||||
STATUS DoQuantize(FuncGraphPtr func_graph) override;
|
||||
STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
|
||||
STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
|
||||
static STATUS WeightQuantInputCheck(const converter::Flags *config);
|
||||
static bool IsPosNum(const std::string &str);
|
||||
|
||||
int quant_max;
|
||||
int quant_min;
|
||||
TypeId type_id{kTypeUnknown};
|
||||
std::map<std::string, int> opname_bit_;
|
||||
|
||||
private:
|
||||
std::unique_ptr<QuantStrategy> mStrategy;
|
||||
size_t bitNum;
|
||||
std::unique_ptr<QuantStrategy> quant_strategy_;
|
||||
size_t bit_num_;
|
||||
std::string config_file_;
|
||||
PostQuantConfig config_param_;
|
||||
std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...]
|
||||
session::LiteSession *fp32_session_ = nullptr;
|
||||
|
||||
STATUS DoMiexedQuant(FuncGraphPtr);
|
||||
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue