forked from mindspore-Ecosystem/mindspore
!5157 post quantization code review
Merge pull request !5157 from xutianchun/quant_code_review
This commit is contained in:
commit
4db3abe2fe
|
@ -15,9 +15,9 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "tools/converter/converter.h"
|
#include "tools/converter/converter.h"
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <memory>
|
|
||||||
#include "tools/converter/converter_flags.h"
|
#include "tools/converter/converter_flags.h"
|
||||||
#include "src/common/common.h"
|
#include "src/common/common.h"
|
||||||
#include "src/common/file_utils.h"
|
#include "src/common/file_utils.h"
|
||||||
|
@ -141,31 +141,11 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
|
||||||
return meta_graph;
|
return meta_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) {
|
void Converter::CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags) {
|
||||||
auto type = flags->quantType;
|
auto type = flags->quantType;
|
||||||
switch (type) {
|
if (type == mindspore::schema::QuantType_PostTraining) {
|
||||||
case mindspore::schema::QuantType_AwareTraining: {
|
MS_LOG(INFO) << "create post training quantizer.";
|
||||||
// mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean));
|
mQuantizer.reset(new quant::PostTrainingQuantizer(func_graph, flags->configFile, 8));
|
||||||
break;
|
|
||||||
}
|
|
||||||
// case mindspore::schema::QuantType_WeightQuant: {
|
|
||||||
// MS_LOG(INFO) << "create WeightQuantizer!";
|
|
||||||
// mQuantizer.reset(
|
|
||||||
// new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold,
|
|
||||||
// flags->bitNum));
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
case mindspore::schema::QuantType_PostTraining: {
|
|
||||||
MS_LOG(INFO) << "create PostTrainningQuantizer!";
|
|
||||||
mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case mindspore::schema::QuantType_QUANT_NONE:
|
|
||||||
MS_LOG(INFO) << "Not do quantization for model!";
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
MS_LOG(INFO) << "will support quntizer type " << flags->quantTypeIn.c_str() << " in the future!";
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int RunConverter(int argc, const char **argv) {
|
int RunConverter(int argc, const char **argv) {
|
||||||
|
|
|
@ -34,7 +34,7 @@ class Converter {
|
||||||
Converter();
|
Converter();
|
||||||
virtual ~Converter();
|
virtual ~Converter();
|
||||||
virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags);
|
virtual schema::MetaGraphT *Convert(const lite::converter::Flags *flags);
|
||||||
void CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags);
|
void CreateQuantizer(FuncGraphPtr func_graph, const converter::Flags *flags);
|
||||||
void FreeFuncGraph(const FuncGraphPtr &func_graph);
|
void FreeFuncGraph(const FuncGraphPtr &func_graph);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
@ -80,7 +80,7 @@ struct DivergInfo {
|
||||||
this->interval = max_value / static_cast<float>(bin_num);
|
this->interval = max_value / static_cast<float>(bin_num);
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS UpdateHistogram(const std::vector<float> &data, const std::vector<int> &shape) {
|
STATUS UpdateHistogram(const std::vector<float> &data) {
|
||||||
for (auto value : data) {
|
for (auto value : data) {
|
||||||
if (value == 0) {
|
if (value == 0) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -235,7 +235,7 @@ struct DivergInfo {
|
||||||
return std::make_pair(this->cnode, zero_point);
|
return std::make_pair(this->cnode, zero_point);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
std::unordered_map<CNodePtr, float> Calibrator::GetResult(
|
std::unordered_map<CNodePtr, float> Calibrator::GetScale(
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||||
std::unordered_map<CNodePtr, float> result;
|
std::unordered_map<CNodePtr, float> result;
|
||||||
for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) {
|
for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) {
|
||||||
|
@ -246,9 +246,9 @@ std::unordered_map<CNodePtr, float> Calibrator::GetResult(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
std::unordered_map<CNodePtr, int32_t> Calibrator::GetZeropoint(
|
std::unordered_map<CNodePtr, int32_t> Calibrator::GetZeropoint(
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *mDivergInfo) {
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||||
std::unordered_map<CNodePtr, int32_t> result;
|
std::unordered_map<CNodePtr, int32_t> result;
|
||||||
for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) {
|
for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) {
|
||||||
DivergInfo *info = iter->second.get();
|
DivergInfo *info = iter->second.get();
|
||||||
auto zeropoint = info->GetZeropoint();
|
auto zeropoint = info->GetZeropoint();
|
||||||
result.insert(zeropoint);
|
result.insert(zeropoint);
|
||||||
|
@ -257,9 +257,9 @@ std::unordered_map<CNodePtr, int32_t> Calibrator::GetZeropoint(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<CNodePtr, MaxMin> Calibrator::GetMinMax(
|
std::map<CNodePtr, MaxMin> Calibrator::GetMinMax(
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *mDivergInfo) {
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||||
std::map<CNodePtr, MaxMin> result;
|
std::map<CNodePtr, MaxMin> result;
|
||||||
for (auto iter = mDivergInfo->begin(); iter != mDivergInfo->end(); iter++) {
|
for (auto iter = diverg_info->begin(); iter != diverg_info->end(); iter++) {
|
||||||
DivergInfo *info = iter->second.get();
|
DivergInfo *info = iter->second.get();
|
||||||
mindspore::lite::quant::MaxMin input_maxmin{};
|
mindspore::lite::quant::MaxMin input_maxmin{};
|
||||||
input_maxmin.min = info->min;
|
input_maxmin.min = info->min;
|
||||||
|
@ -284,10 +284,10 @@ std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *Calibrator::GetOut
|
||||||
return &this->output_diverg_info_;
|
return &this->output_diverg_info_;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS Calibrator::RecordMaxValue(std::string opName, vector<float> data,
|
STATUS Calibrator::RecordMaxValue(const std::string &op_name, const vector<float> &data,
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *mDivergInfo) {
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||||
auto got = (*mDivergInfo).find(opName);
|
auto got = (*diverg_info).find(op_name);
|
||||||
if (got != (*mDivergInfo).end()) {
|
if (got != (*diverg_info).end()) {
|
||||||
((*got).second)->RecordMaxValue(data);
|
((*got).second)->RecordMaxValue(data);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -332,11 +332,11 @@ STATUS Calibrator::UpdateDivergInverval(std::unordered_map<std::string, std::uni
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS Calibrator::UpdateDataFrequency(std::string op_name, vector<float> data, vector<int> shape,
|
STATUS Calibrator::UpdateDataFrequency(const std::string &op_name, const vector<float> &data,
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||||
auto got = (*diverg_info).find(op_name);
|
auto got = (*diverg_info).find(op_name);
|
||||||
if (got != (*diverg_info).end()) {
|
if (got != (*diverg_info).end()) {
|
||||||
((*got).second)->UpdateHistogram(data, shape);
|
((*got).second)->UpdateHistogram(data);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -347,10 +347,10 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
string node_name = node->fullname_with_scope();
|
string node_name = node->fullname_with_scope();
|
||||||
std::unique_ptr<DivergInfo> input_diverg =
|
std::unique_ptr<DivergInfo> input_diverg = std::unique_ptr<DivergInfo>(
|
||||||
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
|
new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x));
|
||||||
std::unique_ptr<DivergInfo> output_diverg =
|
std::unique_ptr<DivergInfo> output_diverg = std::unique_ptr<DivergInfo>(
|
||||||
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
|
new DivergInfo(node, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, config_param_.method_x));
|
||||||
|
|
||||||
input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg)));
|
input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg)));
|
||||||
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg)));
|
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg)));
|
||||||
|
@ -359,29 +359,33 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
|
||||||
|
|
||||||
void Calibrator::AddImage(const string file) {
|
void Calibrator::AddImage(const string file) {
|
||||||
auto exist = [](const string file) {
|
auto exist = [](const string file) {
|
||||||
struct stat buf;
|
struct stat buf{};
|
||||||
return stat(file.c_str(), &buf) == 0;
|
return stat(file.c_str(), &buf) == 0;
|
||||||
};
|
};
|
||||||
if (exist(file)) {
|
if (exist(file)) {
|
||||||
MS_LOG(INFO) << "load image: " << file;
|
MS_LOG(INFO) << "load image: " << file;
|
||||||
this->images_.push_back(file);
|
this->images_.push_back(file);
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(WARNING) << "Invaild image file path: " << file;
|
MS_LOG(WARNING) << "invalid image file path: " << file;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTensor *tensor) const {
|
STATUS Calibrator::GenerateInputData(int index, mindspore::tensor::MSTensor *tensor) const {
|
||||||
string path = images_[index];
|
string path = images_[index];
|
||||||
MS_LOG(INFO) << "read image: " << path;
|
MS_LOG(INFO) << "read image: " << path;
|
||||||
size_t size;
|
size_t size;
|
||||||
char *binBuf = ReadFile(path.c_str(), &size);
|
char *bin_buf = ReadFile(path.c_str(), &size);
|
||||||
auto data = tensor->MutableData();
|
auto data = tensor->MutableData();
|
||||||
if (size != tensor->Size()) {
|
if (size != tensor->Size()) {
|
||||||
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
|
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
|
||||||
<< " input tensor size: " << tensor->Size();
|
<< " input tensor size: " << tensor->Size();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
memcpy(data, binBuf, size);
|
auto ret = memcpy_s(data, tensor->Size(), bin_buf, size);
|
||||||
|
if (ret != EOK) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s error: " << ret;
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -467,7 +471,7 @@ STATUS Calibrator::ReadConfig() {
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " "
|
MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " "
|
||||||
<< "batch_count: " << config_param_.batch_count << " "
|
<< "batch_count: " << config_param_.batch_count << " "
|
||||||
<< "mothod_x: " << config_param_.method_x << " "
|
<< "method_x: " << config_param_.method_x << " "
|
||||||
<< "thread_num: " << config_param_.thread_num;
|
<< "thread_num: " << config_param_.thread_num;
|
||||||
|
|
||||||
delete[] resolved_path;
|
delete[] resolved_path;
|
||||||
|
@ -475,8 +479,8 @@ STATUS Calibrator::ReadConfig() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin)
|
Calibrator::Calibrator(string path, size_t bit_num, int quant_max, int quant_min)
|
||||||
: config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {}
|
: config_path_(path), bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {}
|
||||||
|
|
||||||
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type,
|
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type,
|
||||||
bool per_channel)
|
bool per_channel)
|
||||||
|
@ -669,11 +673,11 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
|
||||||
|
|
||||||
STATUS PostTrainingQuantizer::QuantNode() {
|
STATUS PostTrainingQuantizer::QuantNode() {
|
||||||
auto input_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo());
|
auto input_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetInputDivergInfo());
|
||||||
auto input_scale = this->calibrator_->GetResult(this->calibrator_->GetInputDivergInfo());
|
auto input_scale = this->calibrator_->GetScale(this->calibrator_->GetInputDivergInfo());
|
||||||
auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo());
|
auto input_zero_point = this->calibrator_->GetZeropoint(this->calibrator_->GetInputDivergInfo());
|
||||||
|
|
||||||
auto output_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetOutputDivergInfo());
|
auto output_min_max = this->calibrator_->GetMinMax(this->calibrator_->GetOutputDivergInfo());
|
||||||
auto output_scale = this->calibrator_->GetResult(this->calibrator_->GetOutputDivergInfo());
|
auto output_scale = this->calibrator_->GetScale(this->calibrator_->GetOutputDivergInfo());
|
||||||
auto output_zeropoint = this->calibrator_->GetZeropoint(this->calibrator_->GetOutputDivergInfo());
|
auto output_zeropoint = this->calibrator_->GetZeropoint(this->calibrator_->GetOutputDivergInfo());
|
||||||
|
|
||||||
auto cnodes = funcGraph->GetOrderedCnodes();
|
auto cnodes = funcGraph->GetOrderedCnodes();
|
||||||
|
@ -803,7 +807,7 @@ STATUS PostTrainingQuantizer::PreProcess() {
|
||||||
// from user input
|
// from user input
|
||||||
QuantStrategy strategy(10);
|
QuantStrategy strategy(10);
|
||||||
auto cnodes = funcGraph->GetOrderedCnodes();
|
auto cnodes = funcGraph->GetOrderedCnodes();
|
||||||
for (auto cnode : cnodes) {
|
for (auto &cnode : cnodes) {
|
||||||
AnfNodePtr anf = std::dynamic_pointer_cast<AnfNode>(cnode);
|
AnfNodePtr anf = std::dynamic_pointer_cast<AnfNode>(cnode);
|
||||||
if (strategy.CanOpPostQuantized(anf)) {
|
if (strategy.CanOpPostQuantized(anf)) {
|
||||||
MS_LOG(INFO) << "node: " << cnode->fullname_with_scope() << " will be quantized";
|
MS_LOG(INFO) << "node: " << cnode->fullname_with_scope() << " will be quantized";
|
||||||
|
@ -813,16 +817,15 @@ STATUS PostTrainingQuantizer::PreProcess() {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &nodeName,
|
STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &node_name,
|
||||||
const std::vector<mindspore::tensor::MSTensor *> &tensorVec) const {
|
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const {
|
||||||
if (tensorVec.size() < 1) {
|
if (tensor_vec.size() < 1) {
|
||||||
MS_LOG(ERROR) << "node: " << nodeName << " input tensors is 0";
|
MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto *tensor = tensorVec[0];
|
auto *tensor = tensor_vec[0];
|
||||||
if (tensor->data_type() != kNumberTypeFloat32) {
|
if (tensor->data_type() != kNumberTypeFloat32) {
|
||||||
//&& tensor->RefCount() != MSCONST_WEIGHT_REFCOUNT
|
MS_LOG(DEBUG) << "node: " << node_name << " will not quantize"
|
||||||
MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize"
|
|
||||||
<< " tensor data_type: " << tensor->data_type();
|
<< " tensor data_type: " << tensor->data_type();
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -856,8 +859,8 @@ STATUS PostTrainingQuantizer::DoInference() {
|
||||||
}
|
}
|
||||||
auto tensor = beforeInputs[0];
|
auto tensor = beforeInputs[0];
|
||||||
const float *tData = static_cast<const float *>(tensor->MutableData());
|
const float *tData = static_cast<const float *>(tensor->MutableData());
|
||||||
size_t shapeSize = tensor->ElementsNum();
|
size_t elem_count = tensor->ElementsNum();
|
||||||
vector<float> data(tData, tData + shapeSize);
|
vector<float> data(tData, tData + elem_count);
|
||||||
this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetInputDivergInfo());
|
this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetInputDivergInfo());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
@ -871,8 +874,8 @@ STATUS PostTrainingQuantizer::DoInference() {
|
||||||
}
|
}
|
||||||
auto tensor = afterOutputs[0];
|
auto tensor = afterOutputs[0];
|
||||||
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
|
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
|
||||||
size_t shape_size = tensor->ElementsNum();
|
size_t elem_count = tensor->ElementsNum();
|
||||||
vector<float> data(tensor_data, tensor_data + shape_size);
|
vector<float> data(tensor_data, tensor_data + elem_count);
|
||||||
this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetOutputDivergInfo());
|
this->calibrator_->RecordMaxValue(callParam.name_callback_param, data, this->calibrator_->GetOutputDivergInfo());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
@ -910,7 +913,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
|
||||||
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
|
const float *tensor_data = static_cast<const float *>(tensor->MutableData());
|
||||||
size_t shape_size = tensor->ElementsNum();
|
size_t shape_size = tensor->ElementsNum();
|
||||||
vector<float> data(tensor_data, tensor_data + shape_size);
|
vector<float> data(tensor_data, tensor_data + shape_size);
|
||||||
this->calibrator_->UpdateDataFrequency(callParam.name_callback_param, data, tensor->shape(),
|
this->calibrator_->UpdateDataFrequency(callParam.name_callback_param, data,
|
||||||
this->calibrator_->GetInputDivergInfo());
|
this->calibrator_->GetInputDivergInfo());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
@ -926,7 +929,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
|
||||||
const float *tenosr_data = static_cast<const float *>(tensor->MutableData());
|
const float *tenosr_data = static_cast<const float *>(tensor->MutableData());
|
||||||
size_t shape_size = tensor->ElementsNum();
|
size_t shape_size = tensor->ElementsNum();
|
||||||
vector<float> data(tenosr_data, tenosr_data + shape_size);
|
vector<float> data(tenosr_data, tenosr_data + shape_size);
|
||||||
this->calibrator_->UpdateDataFrequency(call_param.name_callback_param, data, tensor->shape(),
|
this->calibrator_->UpdateDataFrequency(call_param.name_callback_param, data,
|
||||||
this->calibrator_->GetOutputDivergInfo());
|
this->calibrator_->GetOutputDivergInfo());
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
|
@ -39,14 +39,9 @@ struct MaxMin {
|
||||||
float max;
|
float max;
|
||||||
};
|
};
|
||||||
|
|
||||||
enum ImageFormat {
|
|
||||||
RGB = 0,
|
|
||||||
GRAY = 1,
|
|
||||||
BGR = 2,
|
|
||||||
};
|
|
||||||
|
|
||||||
const char kMethodMaxMin[] = "MAX_MIN";
|
const char kMethodMaxMin[] = "MAX_MIN";
|
||||||
const char kMethodKL[] = "KL";
|
const char kMethodKL[] = "KL";
|
||||||
|
constexpr int kDefaultBinNumber = 2048;
|
||||||
|
|
||||||
struct ConfigParam {
|
struct ConfigParam {
|
||||||
// ImageFormat imageFormat;
|
// ImageFormat imageFormat;
|
||||||
|
@ -78,7 +73,8 @@ class PostTrainingQuantizer : public Quantizer {
|
||||||
|
|
||||||
STATUS PreProcess();
|
STATUS PreProcess();
|
||||||
|
|
||||||
STATUS CheckTensorVec(const std::string &nodeName, const std::vector<mindspore::tensor::MSTensor *> &tensorVec) const;
|
STATUS CheckTensorVec(const std::string &node_name,
|
||||||
|
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) const;
|
||||||
|
|
||||||
STATUS DoInference();
|
STATUS DoInference();
|
||||||
|
|
||||||
|
@ -105,7 +101,7 @@ struct DivergInfo;
|
||||||
|
|
||||||
class Calibrator {
|
class Calibrator {
|
||||||
public:
|
public:
|
||||||
explicit Calibrator(std::string path, size_t quant_size, int quant_max, int quant_msin);
|
explicit Calibrator(std::string path, size_t bit_num, int quant_max, int quant_min);
|
||||||
|
|
||||||
~Calibrator() = default;
|
~Calibrator() = default;
|
||||||
|
|
||||||
|
@ -123,18 +119,18 @@ class Calibrator {
|
||||||
|
|
||||||
STATUS AddQuantizedOp(CNodePtr node);
|
STATUS AddQuantizedOp(CNodePtr node);
|
||||||
|
|
||||||
STATUS RecordMaxValue(std::string opName, std::vector<float> data,
|
STATUS RecordMaxValue(const std::string &op_name, const std::vector<float> &data,
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
||||||
|
|
||||||
STATUS UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
STATUS UpdateDivergInverval(std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
||||||
|
|
||||||
STATUS UpdateDataFrequency(std::string op_name, std::vector<float> data, std::vector<int> shape,
|
STATUS UpdateDataFrequency(const std::string& op_name, const std::vector<float>& data,
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
||||||
void Dump();
|
void Dump();
|
||||||
|
|
||||||
STATUS ComputeThreshold();
|
STATUS ComputeThreshold();
|
||||||
|
|
||||||
std::unordered_map<CNodePtr, float> GetResult(
|
std::unordered_map<CNodePtr, float> GetScale(
|
||||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
||||||
|
|
||||||
std::unordered_map<CNodePtr, int32_t> GetZeropoint(
|
std::unordered_map<CNodePtr, int32_t> GetZeropoint(
|
||||||
|
|
|
@ -349,16 +349,12 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
||||||
quant_datas[index] = quant_data;
|
quant_datas[index] = quant_data;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto ret = memcpy_s(const_cast<float *>(raw_datas), weight->tensor_size(), quant_datas.data(),
|
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
|
||||||
elem_count * sizeof(int8_t));
|
elem_count * sizeof(int8_t));
|
||||||
if (ret != EOK) {
|
if (ret != EOK) {
|
||||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (quantType == QuantType_WeightQuant) {
|
|
||||||
PostBitPack(const_cast<float *>(raw_datas), elem_count, bitNum);
|
|
||||||
}
|
|
||||||
|
|
||||||
weight->set_tensor_size(elem_count * sizeof(int8_t));
|
weight->set_tensor_size(elem_count * sizeof(int8_t));
|
||||||
} else {
|
} else {
|
||||||
// channel at first
|
// channel at first
|
||||||
|
@ -407,9 +403,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
||||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (quantType == QuantType_WeightQuant) {
|
|
||||||
PostBitPack(const_cast<float *>(raw_datas), elem_count, bitNum);
|
|
||||||
}
|
|
||||||
weight->set_tensor_size(elem_count * sizeof(int8_t));
|
weight->set_tensor_size(elem_count * sizeof(int8_t));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -441,9 +434,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
||||||
MS_LOG(ERROR) << "memcpy error: " << ret;
|
MS_LOG(ERROR) << "memcpy error: " << ret;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
if (quantType == QuantType_WeightQuant) {
|
|
||||||
PostBitPack(raw_datas, elem_count, bitNum);
|
|
||||||
}
|
|
||||||
weight->set_tensor_size(elem_count * sizeof(int8_t));
|
weight->set_tensor_size(elem_count * sizeof(int8_t));
|
||||||
}
|
}
|
||||||
if (quant_params.empty()) {
|
if (quant_params.empty()) {
|
||||||
|
|
|
@ -51,7 +51,7 @@ class Quantizer {
|
||||||
|
|
||||||
virtual STATUS DetermineNodeQuantType();
|
virtual STATUS DetermineNodeQuantType();
|
||||||
|
|
||||||
virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0;
|
virtual STATUS DoQuantize(FuncGraphPtr func_graph) = 0;
|
||||||
|
|
||||||
mindspore::lite::converter::Flags flags;
|
mindspore::lite::converter::Flags flags;
|
||||||
protected:
|
protected:
|
||||||
|
|
Loading…
Reference in New Issue