Calibrator support setting symmetry

yeyunpeng2020 2021-11-23 20:05:58 +08:00
parent 77bf581af5
commit dd1f9eec30
7 changed files with 167 additions and 165 deletions

tools/common/statistic_utils.h

@@ -161,5 +161,28 @@ inline float GetCosSimilarity(const void *vector_a, const void *vector_b, size_t
     return 0;
   }
 }
+
+template <typename T>
+float KLDivergence(std::vector<T> p, std::vector<T> q) {
+  auto sum = 0.0f;
+  std::for_each(p.begin(), p.end(), [&sum](T item) { sum += item; });
+  std::for_each(p.begin(), p.end(), [sum](T &item) { item /= sum; });
+  sum = 0.0f;
+  std::for_each(q.begin(), q.end(), [&sum](T item) { sum += item; });
+  std::for_each(q.begin(), q.end(), [sum](T &item) { item /= sum; });
+  float result = 0.0f;
+  const int size = p.size();
+  for (int i = 0; i < size; ++i) {
+    if (p[i] != 0) {
+      if (q[i] == 0) {
+        result += 1.0f;
+      } else {
+        result += (p[i] * std::log((p[i]) / (q[i])));
+      }
+    }
+  }
+  return result;
+}
+
 }  // namespace mindspore::lite
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_STATISTIC_UTILS_H_
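For reference, the new helper normalizes both inputs in place and then accumulates p[i]*log(p[i]/q[i]) over the non-empty bins of p, charging a flat penalty of 1.0 whenever a bin is populated in p but empty in q. A minimal usage sketch (the main() harness is hypothetical; only KLDivergence comes from the header above):

  #include <iostream>
  #include <vector>
  #include "tools/common/statistic_utils.h"

  int main() {
    // Two un-normalized histograms; KLDivergence normalizes them internally.
    std::vector<float> p = {1.0f, 1.0f, 2.0f};  // -> {0.25, 0.25, 0.5}
    std::vector<float> q = {2.0f, 1.0f, 1.0f};  // -> {0.5, 0.25, 0.25}
    // KL(p||q) = 0.25*ln(0.5) + 0.25*ln(1.0) + 0.5*ln(2.0) ~= 0.173
    std::cout << mindspore::lite::KLDivergence(p, q) << std::endl;
    return 0;
  }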

tools/converter/quantizer/calibrator.cc

@@ -27,12 +27,7 @@ constexpr int kDefaultBinNumber = 2048;
 }
 int Calibrator::RecordMaxMinValue(const std::vector<float> &data,
                                   const std::unique_ptr<DataDistribution> &diverg_info) {
-  auto ret = diverg_info->RecordMaxMinValue(data);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Record max min value failed.";
-    return ret;
-  }
-  ret = diverg_info->RecordMaxMinValueArray(data);
+  auto ret = diverg_info->RecordMaxMinValueArray(data);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Record max min value array failed.";
     return ret;
@@ -119,7 +114,7 @@ int Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
   auto size = cnode->inputs().size();
   for (size_t i = 1; i < size; i++) {
     std::unique_ptr<DataDistribution> input_diverg = std::make_unique<DataDistribution>(
-      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
+      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_);
     MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
     inputs_diverg_info_[node_name].insert({i - 1, std::move(input_diverg)});
   }
@@ -131,13 +126,13 @@ int Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
     MS_ASSERT(elements.size() > 1);
     for (size_t i = 0; i < elements.size(); i++) {
       std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
-        cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
+        cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_);
       MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
       outputs_diverg_info_[node_name].insert({i, std::move(output_diverg)});
     }
   } else {
     std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
-      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
+      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, activation_quant_method_);
     MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
     outputs_diverg_info_[node_name].insert({0, std::move(output_diverg)});
   }
@@ -149,11 +144,40 @@ int Calibrator::GenerateInputData(const std::string &input_name, size_t image_in
   return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
 }
-
-std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *Calibrator::GetInputDivergInfo() {
-  return &this->inputs_diverg_info_;
-}
-
-std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *Calibrator::GetOutputDivergInfo() {
-  return &this->outputs_diverg_info_;
+
+int Calibrator::CollectDataDistribution(
+  const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
+  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
+  CollectType collect_type) {
+  if (diverg_info_map->find(node_name) == diverg_info_map->end()) {
+    return RET_OK;
+  }
+  for (size_t i = 0; i < tensors.size(); i++) {
+    auto tensor = tensors[i];
+    if (tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
+      continue;
+    }
+    const auto *tensor_data = static_cast<const float *>(tensor->data());
+    if (tensor_data == nullptr) {
+      MS_LOG(ERROR) << tensor->tensor_name() << " tensor_data is nullptr.";
+      return RET_ERROR;
+    }
+    size_t elem_count = tensor->ElementsNum();
+    MS_CHECK_GT(elem_count, 0, RET_ERROR);
+    std::vector<float> data(tensor_data, tensor_data + elem_count);
+    if (collect_type == MIN_MAX) {
+      auto ret = RecordMaxMinValue(data, (*diverg_info_map)[node_name][i]);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << tensor->tensor_name() << " record max min value failed.";
+        return RET_ERROR;
+      }
+    } else if (collect_type == KL_BIN) {
+      auto ret = UpdateDataFrequency(data, (*diverg_info_map)[node_name][i]);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << tensor->tensor_name() << " update data frequency failed.";
+        return RET_ERROR;
+      }
+    }
+  }
+  return RET_OK;
 }
 }  // namespace mindspore::lite::quant
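CollectDataDistribution, previously a FullQuantQuantizer method (its deletion appears further down), now lives on the Calibrator. It runs once per node per calibration batch, in two sweeps: first with MIN_MAX so every DataDistribution learns its real value range, then with KL_BIN to fill the histograms for the threshold search. A rough driver sketch, with the loop structure and RunBatch invented for illustration:

  // Hypothetical calibration driver; RunBatch stands in for one fp32 inference
  // whose callbacks invoke CollectDataDistribution for every quantized node.
  for (size_t i = 0; i < calibrator->GetBatchNum(); ++i) {
    RunBatch(i, MIN_MAX);  // pass 1: record per-tensor min/max
  }
  // Histogram bins are sized from the recorded ranges before the second sweep.
  for (size_t i = 0; i < calibrator->GetBatchNum(); ++i) {
    RunBatch(i, KL_BIN);   // pass 2: accumulate per-bin frequencies
  }
  calibrator->ComputeThreshold();  // KL search picks best_T_ per tensor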

tools/converter/quantizer/calibrator.h

@@ -26,23 +26,25 @@
 #include "tools/converter/quantizer/data_distribution.h"
 namespace mindspore::lite::quant {
+enum CollectType {
+  MIN_MAX,
+  KL_BIN,
+};
 class Calibrator {
  public:
-  explicit Calibrator(size_t bit_num, int quant_max, int quant_min)
-      : bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {}
+  Calibrator(size_t bit_num, int quant_max, int quant_min, ActivationQuantizedMethod activation_quant_method,
+             const preprocess::DataPreProcessParam &data_pre_process_param, bool symmetry = true)
+      : bit_num_(bit_num),
+        quant_max_(quant_max),
+        quant_min_(quant_min),
+        symmetry_(symmetry),
+        activation_quant_method_(activation_quant_method),
+        data_pre_process_param_(data_pre_process_param) {}
   ~Calibrator() = default;
   int GenerateInputData(const std::string &input_name, size_t image_index, mindspore::tensor::MSTensor *tensor) const;
+  size_t GetBatchNum() const { return data_pre_process_param_.calibrate_size; }
-  uint32_t GetThreadNum() const { return thread_; }
-  bool GetBiasCorrection() const { return full_quant_param_.bias_correction; }
+  size_t GetInputNum() const { return data_pre_process_param_.calibrate_path_vector.size(); }
   int AddQuantizedOp(const CNodePtr &cnode);
   int RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DataDistribution> &diverg_info);
@@ -53,25 +55,34 @@ class Calibrator {
   int ComputeThreshold();
-  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetInputDivergInfo();
-  size_t GetBatchNum() const { return data_pre_process_param_.calibrate_size; }
-  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetOutputDivergInfo();
-  size_t GetInputNum() const { return data_pre_process_param_.calibrate_path_vector.size(); }
-  FullQuantParam full_quant_param_;
+  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetInputDivergInfo() {
+    return &this->inputs_diverg_info_;
+  }
-  preprocess::DataPreProcessParam data_pre_process_param_;
+  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetOutputDivergInfo() {
+    return &this->outputs_diverg_info_;
+  }
-  int thread_ = 4;
+  int CollectDataDistribution(
+    const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
+    std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
+    CollectType collect_type);
  private:
   // {node_name,{tensor_index,DataDistribution}}
   std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> inputs_diverg_info_;
   // {node_name,{tensor_index,DataDistribution}}
   std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> outputs_diverg_info_;
   size_t bit_num_;
   int quant_max_;
   int quant_min_;
+  bool symmetry_;
+  ActivationQuantizedMethod activation_quant_method_;
+  preprocess::DataPreProcessParam data_pre_process_param_;
 };
 }  // namespace mindspore::lite::quant
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__CALIBRATOR_H
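Everything the calibrator depends on is now injected at construction time, including the new symmetry switch, instead of being patched onto public members after the fact. A hypothetical call site, with narrow-range int8 values filled in purely for illustration:

  // Hypothetical: narrow-range int8, KL thresholds, symmetric activation ranges.
  auto calibrator = std::make_unique<Calibrator>(
    /*bit_num=*/8, /*quant_max=*/127, /*quant_min=*/-127,
    /*activation_quant_method=*/KL, data_pre_process_param, /*symmetry=*/true);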

tools/converter/quantizer/data_distribution.cc

@@ -19,15 +19,9 @@
 #include <vector>
 #include <utility>
 #include <set>
-namespace mindspore::lite::quant {
-int DataDistribution::RecordMaxMinValue(const std::vector<float> &data) {
-  for (float val : data) {
-    max_ = std::max(val, max_);
-    min_ = std::min(val, min_);
-  }
-  return RET_OK;
-}
+#include "tools/common/statistic_utils.h"
 namespace mindspore::lite::quant {
 int DataDistribution::RecordMaxMinValueArray(const std::vector<float> &data) {
   if (data.empty()) {
     return RET_ERROR;
@@ -38,13 +32,15 @@ int DataDistribution::RecordMaxMinValueArray(const std::vector<float> &data) {
     max_num = std::max(val, max_num);
     min_num = std::min(val, min_num);
   }
+  real_max_ = std::max(max_num, real_max_);
+  real_min_ = std::min(min_num, real_min_);
   this->max_datas_.emplace_back(max_num);
   this->min_datas_.emplace_back(min_num);
   return RET_OK;
 }
 void DataDistribution::UpdateInterval() {
-  auto max_value = std::max(fabs(this->max_), fabs(this->min_));
+  auto max_value = std::max(fabs(this->real_max_), fabs(this->real_min_));
   MS_ASSERT(bin_num_ != 0);
   this->interval_ = max_value / static_cast<float>(bin_num_);
 }
@@ -139,15 +135,7 @@ void DataDistribution::HandleBinForKL(int quant_bint_nums, int bin_index, std::v
 }
 int DataDistribution::ComputeThreshold() {
-  if (activation_quant_method_ == MAX_MIN) {
-    this->best_T_ = std::max(fabs(this->max_), fabs(this->min_));
-    MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T_;
-    return RET_OK;
-  }
-  if (activation_quant_method_ == REMOVAL_OUTLIER && !this->min_datas_.empty()) {
-    this->percent_result_ = OutlierMethod(min_datas_, max_datas_);
-    this->best_T_ = std::max(std::fabs(percent_result_.first), std::fabs(percent_result_.second));
+  if (activation_quant_method_ != KL) {
     return RET_OK;
   }
@@ -163,28 +151,7 @@ int DataDistribution::ComputeThreshold() {
     after_threshold_sum -= this->histogram_[i];
     // handle bins for computing KL.
     HandleBinForKL(INT8_MAX + 1, i, &quantized_histogram, &expanded_histogram);
-    auto KLDivergence = [](std::vector<float> p, std::vector<float> q) {
-      auto sum = 0.0f;
-      std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; });
-      std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; });
-      sum = 0.0f;
-      std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; });
-      std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; });
-      float result = 0.0f;
-      const int size = p.size();
-      for (int i = 0; i < size; ++i) {
-        if (p[i] != 0) {
-          if (q[i] == 0) {
-            result += 1.0f;
-          } else {
-            result += (p[i] * std::log((p[i]) / (q[i])));
-          }
-        }
-      }
-      return result;
-    };
-    const float kl = KLDivergence(reference_histogram, expanded_histogram);
+    const float kl = lite::KLDivergence(reference_histogram, expanded_histogram);
     if (kl < min_kl) {
       min_kl = kl;
       threshold = i;
@@ -192,39 +159,53 @@ int DataDistribution::ComputeThreshold() {
   }
   this->best_T_ = (static_cast<float>(threshold) + 0.5f) * this->interval_;
   MS_LOG(DEBUG) << cnode_->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T_
-                << " max: " << std::max(fabs(this->max_), fabs(this->min_));
+                << " max: " << std::max(fabs(this->real_max_), fabs(this->real_min_));
   return RET_OK;
 }
-float DataDistribution::GetScale() {
-  float max_value = this->best_T_;
-  float min_value = -max_value;
+float DataDistribution::CalculateMinMaxScale() { return CalculateScale(this->real_min_, this->real_max_); }
-  if (this->activation_quant_method_ == REMOVAL_OUTLIER) {
-    min_value = percent_result_.first;
-    max_value = percent_result_.second;
+float DataDistribution::CalculateRemovalOutlierScale() {
+  this->percent_result_ = OutlierMethod(min_datas_, max_datas_);
+  return CalculateScale(percent_result_.first, percent_result_.second);
 }
+float DataDistribution::CalculateScale(float min_value, float max_value) {
+  if (symmetry_) {
+    auto abs_max = std::max(fabs(min_value), fabs(max_value));
+    min_value = -abs_max;
+    max_value = abs_max;
+  }
+  this->encode_min_ = min_value;
+  this->encode_max_ = max_value;
+  // Optimize Handle 0.
+  MS_ASSERT(quant_max_ - quant_min_ > 0);
+  return (max_value - min_value) / (quant_max_ - quant_min_);
+}
-  MS_CHECK_TRUE_MSG(quant_max_ - quant_min_ > 0, 0, "quant_max_ - quant_min_ <= 0");
-  this->scale_ = (max_value - min_value) / (quant_max_ - quant_min_);
-  MS_ASSERT(fabs(this->scale_) <= 0.0f);
+float DataDistribution::CalculateKLScale() { return CalculateScale(this->best_T_, this->real_max_); }
+float DataDistribution::GetScale() {
+  switch (this->activation_quant_method_) {
+    case MAX_MIN:
+      this->scale_ = CalculateMinMaxScale();
+      break;
+    case KL:
+      this->scale_ = CalculateKLScale();
+      break;
+    case REMOVAL_OUTLIER:
+      this->scale_ = CalculateRemovalOutlierScale();
+      break;
+    default:
+      MS_LOG(ERROR) << "Unsupported activation quant method " << this->activation_quant_method_;
+      return 0;
+  }
   return this->scale_;
 }
+// Support for asymmetry in the future
 int32_t DataDistribution::GetZeroPoint() {
-  int zero_point = 0;
-  if (quant_min_ == 0 && quant_max_ == UINT8_MAX) {
-    zero_point = INT8_MAX + 1;
-  } else if (quant_min_ == INT_LEAST8_MIN + 1 && quant_max_ == INT8_MAX) {
-    zero_point = 0;
-  } else {
-    MS_LOG(WARNING) << "unexpected quant range, quant_min_: " << quant_min_ << " quant_max_: " << quant_max_;
-  }
-  if (this->activation_quant_method_ == REMOVAL_OUTLIER) {
-    MS_CHECK_TRUE_MSG(fabs(scale_) <= 0.0f, 1, "fabs(scale) > 0.0f");
-    zero_point = std::round(quant_max_ - percent_result_.second / scale_);
-  }
+  int zero_point = std::round(quant_min_ - encode_min_ / scale_);
   return zero_point;
 }
 }  // namespace mindspore::lite::quant
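The net effect is easiest to see with numbers. The sketch below re-implements the scale/zero-point arithmetic in isolation (illustration only, not the library API) for a tensor whose calibrated range is [-2, 6], mapped onto narrow-range int8 [-127, 127]:

  #include <algorithm>
  #include <cmath>
  #include <cstdio>

  struct QuantParams { float scale; int zero_point; };

  // Mirrors the CalculateScale + GetZeroPoint logic above, outside the class.
  QuantParams Calculate(float min_v, float max_v, int q_min, int q_max, bool symmetry) {
    if (symmetry) {
      // Widen the encode range to +/- max(|min|, |max|).
      float abs_max = std::max(std::fabs(min_v), std::fabs(max_v));
      min_v = -abs_max;
      max_v = abs_max;
    }
    float scale = (max_v - min_v) / static_cast<float>(q_max - q_min);
    int zero_point = static_cast<int>(std::round(q_min - min_v / scale));
    return {scale, zero_point};
  }

  int main() {
    auto sym = Calculate(-2.0f, 6.0f, -127, 127, true);    // scale = 12/254 ~ 0.0472, zp = 0
    auto asym = Calculate(-2.0f, 6.0f, -127, 127, false);  // scale = 8/254 ~ 0.0315, zp = -64
    std::printf("sym:  scale=%f zp=%d\n", sym.scale, sym.zero_point);
    std::printf("asym: scale=%f zp=%d\n", asym.scale, asym.zero_point);
    return 0;
  }

Symmetric mode spends some integer codes on the unused side of the range but pins the zero point at 0, which lets integer kernels skip zero-point handling; because encode_min_ already reflects the chosen mode, the single formula in GetZeroPoint covers both cases.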

tools/converter/quantizer/data_distribution.h

@@ -25,21 +25,20 @@ class DataDistribution {
  public:
   DataDistribution() = default;
   DataDistribution(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min,
-                   ActivationQuantizedMethod activation_quant_method) {
+                   ActivationQuantizedMethod activation_quant_method, bool symmetry = true) {
     this->activation_quant_method_ = activation_quant_method;
     this->cnode_ = std::move(cnode);
     this->bin_num_ = bins;
     this->bit_num_ = bits;
     histogram_.resize(bin_num_);
-    max_ = -FLT_MAX;
-    min_ = FLT_MAX;
+    real_max_ = -FLT_MAX;
+    real_min_ = FLT_MAX;
     this->quant_max_ = quant_max;
     this->quant_min_ = quant_min;
     std::fill(histogram_.begin(), histogram_.end(), 1.0e-7);
+    symmetry_ = symmetry;
   }
-  int RecordMaxMinValue(const std::vector<float> &data);
   int RecordMaxMinValueArray(const std::vector<float> &data);
   void UpdateInterval();
@@ -57,21 +56,33 @@ class DataDistribution {
   int32_t GetZeroPoint();
-  float GetMax() { return this->max_; }
+  float GetRealMax() { return this->real_max_; }
-  float GetMin() { return this->min_; }
+  float GetRealMin() { return this->real_min_; }
+  float GetEncodeMin() { return this->encode_min_; }
+  float GetEncodeMax() { return this->encode_max_; }
   CNodePtr GetCNode() { return this->cnode_; }
+ private:
+  float CalculateMinMaxScale();
+  float CalculateRemovalOutlierScale();
+  float CalculateKLScale();
+  float CalculateScale(float min_value, float max_value);
  private:
   std::vector<float> histogram_;
   CNodePtr cnode_;
   int bin_num_ = 0;
   float interval_ = 0;
-  float max_ = 0.0f;
-  float min_ = 0.0f;
+  float real_max_ = 0.0f;
+  float real_min_ = 0.0f;
   float best_T_ = 0.0f;
   size_t bit_num_ = 0;
+  float encode_min_ = 0.0f;
+  float encode_max_ = 0.0f;
   int quant_max_ = 255;
   int quant_min_ = 0;
   ActivationQuantizedMethod activation_quant_method_ = MAX_MIN;
@@ -79,6 +90,7 @@ class DataDistribution {
   std::vector<float> max_datas_;
   std::pair<float, float> percent_result_{0.0, 0.0};
   float scale_ = 0;
+  bool symmetry_ = true;
 };
 }  // namespace mindspore::lite::quant
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DATA_DISTRIBUTION_H

tools/converter/quantizer/full_quant_quantizer.cc

@@ -140,11 +140,6 @@ FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId t
   } else {
     MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
   }
-  calibrator_ = std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_);
-  if (calibrator_ == nullptr) {
-    MS_LOG(ERROR) << "create calibrator failed!";
-    return;
-  }
 }
FullQuantQuantizer::~FullQuantQuantizer() {
@@ -173,8 +168,8 @@ int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const s
     quant_param.scale = scale;
   }
   quant_param.zeroPoint = info->GetZeroPoint();
-  quant_param.max = info->GetMax();
-  quant_param.min = info->GetMin();
+  quant_param.max = info->GetEncodeMax();
+  quant_param.min = info->GetEncodeMin();
   quant_param.numBits = bit_num_;
   quant_param.narrowRange = true;
   quant_param.inited = true;
@@ -196,7 +191,6 @@ int FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodeP
                                       const PrimitivePtr &primitive, bool per_channel, int input_index) const {
   MS_ASSERT(weight != nullptr);
   MS_ASSERT(primitive != nullptr);
-  // perlayer
   if (!weight->isa<Parameter>()) {
     MS_LOG(ERROR) << "not a parameter";
     return RET_PARAM_INVALID;
@@ -514,6 +508,13 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
 }
 int FullQuantQuantizer::PreProcess() {
+  calibrator_ =
+    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags.fullQuantParam.activation_quant_method,
+                                 this->flags.dataPreProcessParam);
+  if (calibrator_ == nullptr) {
+    MS_LOG(ERROR) << "create calibrator failed!";
+    return RET_NULL_PTR;
+  }
   auto cnodes = funcGraph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
     auto anode = cnode->cast<AnfNodePtr>();
@@ -557,43 +558,6 @@ int FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
   return RET_OK;
 }
-int FullQuantQuantizer::CollectDataDistribution(
-  const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
-  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
-  CollectType collect_type) {
-  if (diverg_info_map->find(node_name) == diverg_info_map->end()) {
-    return true;
-  }
-  for (size_t i = 0; i < tensors.size(); i++) {
-    auto tensor = tensors[i];
-    if (tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
-      continue;
-    }
-    const auto *tensor_data = static_cast<const float *>(tensor->data());
-    if (tensor_data == nullptr) {
-      MS_LOG(ERROR) << tensor->tensor_name() << " tensor_data is nullptr.";
-      return RET_ERROR;
-    }
-    size_t elem_count = tensor->ElementsNum();
-    MS_CHECK_GT(elem_count, 0, RET_ERROR);
-    vector<float> data(tensor_data, tensor_data + elem_count);
-    if (collect_type == MIN_MAX) {
-      auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[node_name][i]);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << tensor->tensor_name() << " record max min value failed.";
-        return RET_ERROR;
-      }
-    } else if (collect_type == KL_BIN) {
-      auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[node_name][i]);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << tensor->tensor_name() << " update data frequency failed.";
-        return RET_ERROR;
-      }
-    }
-  }
-  return RET_OK;
-}
 int FullQuantQuantizer::DoInference(CollectType collect_type) {
   // get input tensor
   vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
@@ -614,7 +578,7 @@ int FullQuantQuantizer::DoInference(CollectType collect_type) {
                                     const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
                                     const CallBackParam &callParam) -> bool {
     auto diverg_info_map = calibrator_->GetInputDivergInfo();
-    auto ret = CollectDataDistribution(callParam.node_name, beforeInputs, diverg_info_map, collect_type);
+    auto ret = calibrator_->CollectDataDistribution(callParam.node_name, beforeInputs, diverg_info_map, collect_type);
     if (ret != RET_OK) {
       MS_LOG(ERROR) << "CollectDataDistribution failed.";
       return false;
@@ -626,7 +590,7 @@ int FullQuantQuantizer::DoInference(CollectType collect_type) {
                                    const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
                                    const CallBackParam &callParam) -> bool {
     auto diverg_info_map = calibrator_->GetOutputDivergInfo();
-    auto ret = CollectDataDistribution(callParam.node_name, afterOutputs, diverg_info_map, collect_type);
+    auto ret = calibrator_->CollectDataDistribution(callParam.node_name, afterOutputs, diverg_info_map, collect_type);
     if (ret != RET_OK) {
      MS_LOG(ERROR) << "CollectDataDistribution failed.";
      return false;
@@ -813,10 +777,6 @@ int FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeTh
 int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   MS_LOG(INFO) << "start to parse config file";
   CHECK_NULL_RETURN(this->calibrator_);
-  calibrator_->full_quant_param_ = flags.fullQuantParam;
-  calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
-  calibrator_->thread_ = flags.commonQuantParam.thread_num;
   if (flags.dataPreProcessParam.calibrate_path.empty()) {
     MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
     return RET_INPUT_PARAM_INVALID;
@@ -831,7 +791,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   // anf -- fb
   flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
   MS_LOG(INFO) << "start create session";
-  auto sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
+  auto sm = CreateSessionByFuncGraph(func_graph, flags, this->flags.commonQuantParam.thread_num);
   fp32_session_ = sm.session;
   fp32_model_ = sm.model;
   if (fp32_session_ == nullptr || fp32_model_ == nullptr) {
@@ -878,11 +838,11 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     return RET_ERROR;
   }
   SessionModel int8_sm;
-  if (calibrator_->GetBiasCorrection()) {
+  if (this->flags.fullQuantParam.bias_correction) {
     // init in8 session
     MS_LOG(INFO) << "create quant session";
     flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
-    int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
+    int8_sm = CreateSessionByFuncGraph(func_graph, flags, this->flags.commonQuantParam.thread_num);
     int8_session_ = int8_sm.session;
     int8_model_ = int8_sm.model;
     if (int8_session_ == nullptr || int8_model_ == nullptr) {
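DoInference wires the collection into the forward pass through the session's before/after callbacks, so activations are sampled in place with no extra graph instrumentation. A compressed sketch of the pattern (types are the MindSpore Lite ones used above; the body is abbreviated and the RunGraph call is indicative only):

  // Sketch: sample each node's float32 inputs during an ordinary fp32 run.
  auto before_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                             const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                             const CallBackParam &call_param) -> bool {
    auto ret = calibrator_->CollectDataDistribution(call_param.node_name, before_inputs,
                                                    calibrator_->GetInputDivergInfo(), collect_type);
    return ret == RET_OK;  // returning false aborts the graph walk
  };
  // fp32_session_->RunGraph(before_callback, after_callback);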

tools/converter/quantizer/full_quant_quantizer.h

@@ -41,10 +41,6 @@ enum OperationType {
   STORE,
   FETCH,
 };
-enum CollectType {
-  MIN_MAX,
-  KL_BIN,
-};
 class FullQuantQuantizer : public Quantizer {
  public:
  FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8);
@@ -78,11 +74,6 @@ class FullQuantQuantizer : public Quantizer {
   int DoParameterNodeQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index);
-  int CollectDataDistribution(
-    const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
-    std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
-    CollectType collect_type);
   int DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive);
   int Int8Inference();
   int BiasCorrection(const FuncGraphPtr &func_graph);