!26657 abstract QuantStrategy && optimize full quant prepare

Merge pull request !26657 from yeyunpeng2020/quant_bak
This commit is contained in:
i-robot 2021-11-23 06:24:27 +00:00 committed by Gitee
commit 8d38efb2c4
16 changed files with 531 additions and 567 deletions

View File

@ -20,8 +20,18 @@
namespace mindspore {
namespace lite {
// `symmetric` == true -> q range is [-127, 127];
// abs_max = max(abs(r_min), abs(r_max)); r_min = -abs_max and r_max = abs_max.
// `symmetric` == false -> q range is [-128, 127]; r_min and r_max keep their original values.
// `narrow_range` is used to adjust q_min; it is always forced to true when `symmetric` is true.
int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
bool narrow_range) {
bool symmetric, bool narrow_range) {
if (symmetric) {
auto abs_max = std::max(std::abs(real_min), std::abs(real_max));
real_min = -abs_max;
real_max = abs_max;
narrow_range = true;
}
int quant_max = QuantMax(num_bits);
int quant_min = QuantMin(num_bits, false, narrow_range);
return CalQuantizationParams(quant_param, real_min, real_max, num_bits, narrow_range, quant_min, quant_max);
@ -100,8 +110,8 @@ int GetBucketIndex(const std::vector<int> &dims, int preferred_dim, int data_ind
return (data_index / stride) % bucket_count;
}
int GetAllChannelMinMmax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
std::map<int, MinMax> *per_channel_min_max) {
int GetAllChannelMinMax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
std::map<int, MinMax> *per_channel_min_max) {
// the key is bucket_index
std::map<int, std::vector<float>> sorted_data;
for (int i = 0; i < elem_count; ++i) {
@ -172,10 +182,11 @@ int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vec
average_dequants[bucket_index] = total_dequants[bucket_index] / bucket_volume;
}
constexpr int pow_exponent = 2;
for (size_t data_index = 0; data_index < elem_count; data_index++) {
auto bucket_index = GetBucketIndex(dims, preferred_dim, data_index);
var_raws[bucket_index] += std::pow(raw_datas[data_index] - average_raws[bucket_index], 2);
var_dequants[bucket_index] += std::pow(dequant_datas[data_index] - average_dequants[bucket_index], 2);
var_raws[bucket_index] += std::pow(raw_datas[data_index] - average_raws[bucket_index], pow_exponent);
var_dequants[bucket_index] += std::pow(dequant_datas[data_index] - average_dequants[bucket_index], pow_exponent);
}
for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
var_raws[bucket_index] = std::sqrt(var_raws[bucket_index] / bucket_volume);

View File

@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
#define MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
#include <float.h>
#include <cfloat>
#include <cmath>
#include <climits>
#include <limits>
@ -66,14 +66,14 @@ int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, dou
bool narrow_range, int quant_min, int quant_max);
int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
bool narrow_range);
bool symmetric, bool narrow_range = false);
template <typename T>
T QuantizeData(float origin_data, const schema::QuantParamT *quantParam, int quant_max, int quant_min) {
MS_ASSERT(quantParam != nullptr);
MS_ASSERT(quantParam->inited);
const auto scale = quantParam->scale;
const int zero_point = quantParam->zeroPoint;
T QuantizeData(float origin_data, const schema::QuantParamT *quant_param, int quant_max, int quant_min) {
MS_ASSERT(quant_param != nullptr);
MS_ASSERT(quant_param->inited);
const auto scale = quant_param->scale;
const int zero_point = quant_param->zeroPoint;
if (scale <= SCALE_THREASHOLD) {
return 0;
}
@ -101,8 +101,8 @@ T QuantizeData(const float origin_data, const schema::QuantParamT *quant_param)
template <typename T>
int DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
const int &quant_max, const int &quant_min, const size_t &bit_num, const bool &k_means,
std::vector<T> *quant_datas) {
const int &quant_max, const int &quant_min, const size_t &bit_num, std::vector<T> *quant_datas,
bool narrow_range = false, bool k_means = false) {
if (k_means) {
MS_LOG(ERROR) << "Unsupported K-means.";
return RET_ERROR;
@ -115,7 +115,7 @@ int DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schem
}
schema::QuantParamT quant_param;
int status = CalQuantizationParams(&quant_param, min, max, bit_num, false, quant_min, quant_max);
int status = CalQuantizationParams(&quant_param, min, max, bit_num, narrow_range, quant_min, quant_max);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
@ -137,8 +137,8 @@ int GetBucketIndex(const std::vector<int> &dims, int preferred_dim, int data_ind
int CalPerChannelGain(size_t bit_num, const std::vector<int> &dims, int preferred_dim);
// Get the min max of each channel
int GetAllChannelMinMmax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
std::map<int, MinMax> *per_channel_min_max);
int GetAllChannelMinMax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
std::map<int, MinMax> *per_channel_min_max);
// Calculate the distribution difference between quant and origin
int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
@ -147,8 +147,12 @@ int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vec
template <typename T>
int DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::QuantType &quant_type,
std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas,
const std::vector<int> &dims, int preferred_dim) {
const size_t &bit_num, std::vector<T> *quant_datas, const std::vector<int> &dims,
int preferred_dim, bool narrow_range = false, bool k_means = false) {
if (k_means) {
MS_LOG(ERROR) << "Unsupported K-means.";
return RET_ERROR;
}
int ret;
auto count = std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<>());
if (static_cast<size_t>(count) != elem_count) {
@ -167,7 +171,7 @@ int DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::Q
std::vector<float> dequant_datas(quant_datas->size());
// the key is bucket_index
std::map<int, MinMax> per_channel_min_max;
ret = GetAllChannelMinMmax(raw_datas, elem_count, dims, preferred_dim, &per_channel_min_max);
ret = GetAllChannelMinMax(raw_datas, elem_count, dims, preferred_dim, &per_channel_min_max);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Get all channel min max failed.";
return ret;
@ -178,7 +182,7 @@ int DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::Q
float min = min_max_map.second.min;
float max = min_max_map.second.max;
schema::QuantParamT quant_param;
ret = CalQuantizationParams(&quant_param, min, max, bit_num, false, quant_min, quant_max);
ret = CalQuantizationParams(&quant_param, min, max, bit_num, narrow_range, quant_min, quant_max);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Cal quantization params failed.";
return ret;

View File

@ -125,11 +125,11 @@ int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tenso
STATUS ret = RET_OK;
if (channels == kPerTensor) {
ret = DoPerLayerQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
&(quant_params), quant_max, quant_min, bit_num, false, &data);
&(quant_params), quant_max, quant_min, bit_num, &data, false, false);
} else {
ret = DoPerChannelQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
schema::QuantType_QUANT_WEIGHT, &(quant_params), quant_max, quant_min, bit_num,
false, &data, dest_tensor->dims, preferred_dim);
&data, dest_tensor->dims, preferred_dim, false, false);
}
if (ret == RET_NO_CHANGE) {
MS_LOG(DEBUG) << "No Need to quant per channel";

View File

@ -375,7 +375,7 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
if (config->commonQuantParam.is_debug) {
converter::Flags new_flag = *config;
new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, thread_num, true);
origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, thread_num);
}
if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
@ -431,7 +431,7 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
}
}
if (config->commonQuantParam.is_debug) {
quant = quant::CreateSessionByFuncGraph(old_graph, *config, thread_num, true);
quant = quant::CreateSessionByFuncGraph(old_graph, *config, thread_num);
std::map<std::string, OpParameter *> op_parameters;
FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
DebugInfoManager manager;

View File

@ -25,7 +25,8 @@ namespace mindspore::lite::quant {
namespace {
constexpr int kDefaultBinNumber = 2048;
}
int Calibrator::RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
int Calibrator::RecordMaxMinValue(const std::vector<float> &data,
const std::unique_ptr<DataDistribution> &diverg_info) {
auto ret = diverg_info->RecordMaxMinValue(data);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Record max min value failed.";
@ -43,7 +44,7 @@ int Calibrator::ComputeThreshold() {
for (auto &kv : this->outputs_diverg_info_) {
auto &outputs_diverg_info = kv.second;
for (auto &diverg_info : outputs_diverg_info) {
auto ret = diverg_info->ComputeThreshold();
auto ret = diverg_info.second->ComputeThreshold();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Compute threshold failed.";
return ret;
@ -64,10 +65,10 @@ int Calibrator::ComputeThreshold() {
break;
}
for (const auto &output_diverg_info : outputs_diverg_info.second) {
auto output_diverg_cnode = output_diverg_info->GetCNode();
auto output_diverg_cnode = output_diverg_info.second->GetCNode();
if (output_diverg_cnode == input_cnode) {
if (NodePrimitiveType(input_cnode) != lite::kNameTupleGetItem) {
*(input_infos[i]) = *output_diverg_info;
*(input_infos[i]) = *output_diverg_info.second;
input_infos[i]->GetCNode() = cnode;
already_computed = true;
break;
@ -88,18 +89,23 @@ int Calibrator::ComputeThreshold() {
return RET_OK;
}
// Calls UpdateInterval() on every recorded distribution and always returns RET_OK.
// NOTE(review): this span is taken from a rendered diff, so it interleaves two versions of the function:
// the one that received the map through a `diverg_info` pointer and iterated a container of smart pointers
// (`info->`), and the parameterless one that walks inputs_diverg_info_ and outputs_diverg_info_, whose
// entries are accessed through `.second`.
// NOTE(review): if the MS_ASSERT on `diverg_info` is still part of the parameterless version, it names a
// parameter that no longer exists - confirm against the real file and drop it there.
int Calibrator::UpdateDivergInterval(
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
int Calibrator::UpdateDivergInterval() {
MS_ASSERT(diverg_info != nullptr);
for (auto &kv : *diverg_info) {
for (auto &kv : inputs_diverg_info_) {
for (auto &info : kv.second) {
info->UpdateInterval();
info.second->UpdateInterval();
}
}
for (auto &kv : outputs_diverg_info_) {
for (auto &info : kv.second) {
info.second->UpdateInterval();
}
}
return RET_OK;
}
// Forwards `data` to diverg_info->UpdateHistogram() and returns that call's status unchanged.
// The single-line signature takes a DivergInfo holder; the two-line signature after it takes a
// DataDistribution holder (both versions appear because this span comes from a rendered diff).
int Calibrator::UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
int Calibrator::UpdateDataFrequency(const std::vector<float> &data,
const std::unique_ptr<DataDistribution> &diverg_info) {
// Debug-style precondition: the holder must not be empty before it is dereferenced below.
MS_ASSERT(diverg_info != nullptr);
return diverg_info->UpdateHistogram(data);
}
@ -110,14 +116,31 @@ int Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
return RET_ERROR;
}
auto node_name = cnode->fullname_with_scope();
std::unique_ptr<DivergInfo> input_diverg = std::make_unique<DivergInfo>(
cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
std::unique_ptr<DivergInfo> output_diverg = std::make_unique<DivergInfo>(
cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
inputs_diverg_info_[node_name].push_back(std::move(input_diverg));
outputs_diverg_info_[node_name].push_back(std::move(output_diverg));
auto size = cnode->inputs().size();
for (size_t i = 1; i < size; i++) {
std::unique_ptr<DataDistribution> input_diverg = std::make_unique<DataDistribution>(
cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
inputs_diverg_info_[node_name].insert({i - 1, std::move(input_diverg)});
}
if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
auto elements = tuple->elements();
MS_ASSERT(elements.size() > 1);
for (size_t i = 0; i < elements.size(); i++) {
std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
outputs_diverg_info_[node_name].insert({i, std::move(output_diverg)});
}
} else {
std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
outputs_diverg_info_[node_name].insert({0, std::move(output_diverg)});
}
return RET_OK;
}
@ -126,11 +149,11 @@ int Calibrator::GenerateInputData(const std::string &input_name, size_t image_in
return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
}
// Returns a pointer to the calibrator's own inputs_diverg_info_ member (no copy is made).
// Two return types are shown because this span comes from a rendered diff: a vector of DivergInfo holders
// per key in one version, a map from int to DataDistribution holder per key in the other.
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *Calibrator::GetInputDivergInfo() {
return &this->inputs_diverg_info_;
}
// Returns a pointer to the calibrator's own outputs_diverg_info_ member (no copy is made).
// Two return types are shown because this span comes from a rendered diff: a vector of DivergInfo holders
// per key in one version, a map from int to DataDistribution holder per key in the other.
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() {
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *Calibrator::GetOutputDivergInfo() {
return &this->outputs_diverg_info_;
}
} // namespace mindspore::lite::quant

View File

@ -23,7 +23,7 @@
#include <memory>
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/diverg_info.h"
#include "tools/converter/quantizer/data_distribution.h"
namespace mindspore::lite::quant {
class Calibrator {
@ -45,17 +45,17 @@ class Calibrator {
int AddQuantizedOp(const CNodePtr &cnode);
int RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
int RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DataDistribution> &diverg_info);
int UpdateDivergInterval(std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
int UpdateDivergInterval();
int UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
int UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DataDistribution> &diverg_info);
int ComputeThreshold();
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetInputDivergInfo();
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetInputDivergInfo();
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetOutputDivergInfo();
FullQuantParam full_quant_param_;
@ -64,9 +64,10 @@ class Calibrator {
int thread_ = 4;
private:
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_;
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;
// {node_name,{tensor_index,DataDistribution}}
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> inputs_diverg_info_;
// {node_name,{tensor_index,DataDistribution}}
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> outputs_diverg_info_;
size_t bit_num_;
int quant_max_;

View File

@ -14,20 +14,21 @@
* limitations under the License.
*/
#include "tools/converter/quantizer/diverg_info.h"
#include "tools/converter/quantizer/data_distribution.h"
#include <algorithm>
#include <vector>
#include <utility>
#include <set>
namespace mindspore::lite::quant {
// Folds every value of `data` into the running maximum and minimum members; always returns RET_OK.
// An empty `data` leaves both members untouched.
// The span shows two versions of the same function (rendered diff): members named max/min in the
// DivergInfo version and max_/min_ in the DataDistribution version.
int DivergInfo::RecordMaxMinValue(const std::vector<float> &data) {
int DataDistribution::RecordMaxMinValue(const std::vector<float> &data) {
for (float val : data) {
max = std::max(val, max);
min = std::min(val, min);
max_ = std::max(val, max_);
min_ = std::min(val, min_);
}
return RET_OK;
}
int DivergInfo::RecordMaxMinValueArray(const std::vector<float> &data) {
int DataDistribution::RecordMaxMinValueArray(const std::vector<float> &data) {
if (data.empty()) {
return RET_ERROR;
}
@ -37,42 +38,42 @@ int DivergInfo::RecordMaxMinValueArray(const std::vector<float> &data) {
max_num = std::max(val, max_num);
min_num = std::min(val, min_num);
}
this->max_datas.emplace_back(max_num);
this->min_datas.emplace_back(min_num);
this->max_datas_.emplace_back(max_num);
this->min_datas_.emplace_back(min_num);
return RET_OK;
}
// Sets the histogram bin width to max(|max|, |min|) divided by the number of bins.
// The bin count is only guarded by MS_ASSERT, so it is expected to be non-zero when this is called.
// The span shows two versions of the same body (rendered diff): un-suffixed member names in the
// DivergInfo version, trailing-underscore names in the DataDistribution version.
void DivergInfo::UpdateInterval() {
auto max_value = std::max(fabs(this->max), fabs(this->min));
MS_ASSERT(bin_num != 0);
this->interval = max_value / static_cast<float>(bin_num);
void DataDistribution::UpdateInterval() {
auto max_value = std::max(fabs(this->max_), fabs(this->min_));
MS_ASSERT(bin_num_ != 0);
this->interval_ = max_value / static_cast<float>(bin_num_);
}
// Adds the absolute value of every non-zero element of `data` to the histogram.
// Returns RET_ERROR (after logging) as soon as a non-zero element is met while the bin width is 0,
// otherwise RET_OK. Exact zeros are skipped and never counted.
// The span shows two versions of the same body (rendered diff): un-suffixed member names in the
// DivergInfo version, trailing-underscore names in the DataDistribution version.
int DivergInfo::UpdateHistogram(const std::vector<float> &data) {
int DataDistribution::UpdateHistogram(const std::vector<float> &data) {
for (auto value : data) {
if (value == 0) {
continue;
}
// Guard the division below; the bin width is produced by UpdateInterval().
if (this->interval == 0) {
if (this->interval_ == 0) {
MS_LOG(ERROR) << "divisor 'interval' cannot be 0.";
return RET_ERROR;
}
// Bin index is |value| / width, clamped so values at or beyond the top edge land in the last bin.
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
this->histogram[bin_index]++;
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval_), bin_num_ - 1);
this->histogram_[bin_index]++;
}
return RET_OK;
}
// Debug helper: logs the node's full scoped name at INFO level, then prints every histogram bin to
// std::cout separated by spaces and terminated by std::endl.
// The span shows two versions of the same body (rendered diff): cnode/histogram in the DivergInfo
// version, cnode_/histogram_ in the DataDistribution version.
void DivergInfo::DumpHistogram() {
MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram";
for (float item : this->histogram) {
void DataDistribution::DumpHistogram() {
MS_LOG(INFO) << "Print node " << cnode_->fullname_with_scope() << " histogram";
for (float item : this->histogram_) {
std::cout << item << " ";
}
std::cout << std::endl;
}
void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
std::vector<float> *expanded_histogram) {
void DataDistribution::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
std::vector<float> *expanded_histogram) {
MS_ASSERT(quantized_histogram != nullptr && expanded_histogram != nullptr);
MS_ASSERT(quant_bint_nums != 0);
const float bin_interval = static_cast<float>(bin_index) / static_cast<float>(quant_bint_nums);
@ -83,14 +84,14 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
const int left_upper = static_cast<int>(std::ceil(start));
if (left_upper > start) {
const double left_scale = left_upper - start;
quantized_histogram->at(i) += left_scale * this->histogram[left_upper - 1];
quantized_histogram->at(i) += left_scale * this->histogram_[left_upper - 1];
}
const int right_lower = static_cast<int>(std::floor(end));
if (right_lower < end) {
const double right_scale = end - right_lower;
quantized_histogram->at(i) += right_scale * this->histogram[right_lower];
quantized_histogram->at(i) += right_scale * this->histogram_[right_lower];
}
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower,
std::for_each(this->histogram_.begin() + left_upper, this->histogram_.begin() + right_lower,
[&quantized_histogram, i](float item) { quantized_histogram->at(i) += item; });
}
// expand target bins to i bins in order to calculate KL with reference_histogram
@ -102,7 +103,7 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
float left_scale = 0.0f;
if (left_upper > start) {
left_scale = left_upper - start;
if (this->histogram[left_upper - 1] != 0) {
if (this->histogram_[left_upper - 1] != 0) {
count += left_scale;
}
}
@ -110,11 +111,11 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
double right_scale = 0.0f;
if (right_lower < end) {
right_scale = end - right_lower;
if (this->histogram[right_lower] != 0) {
if (this->histogram_[right_lower] != 0) {
count += right_scale;
}
}
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, [&count](float item) {
std::for_each(this->histogram_.begin() + left_upper, this->histogram_.begin() + right_lower, [&count](float item) {
if (item != 0) {
count += 1;
}
@ -123,43 +124,43 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
continue;
}
const float average_num = quantized_histogram->at(i) / count;
if (left_upper > start && this->histogram[left_upper - 1] != 0) {
if (left_upper > start && this->histogram_[left_upper - 1] != 0) {
expanded_histogram->at(left_upper - 1) += average_num * left_scale;
}
if (right_lower < end && this->histogram[right_lower] != 0) {
if (right_lower < end && this->histogram_[right_lower] != 0) {
expanded_histogram->at(right_lower) += average_num * right_scale;
}
for (int k = left_upper; k < right_lower; ++k) {
if (this->histogram[k] != 0) {
if (this->histogram_[k] != 0) {
expanded_histogram->at(k) += average_num;
}
}
}
}
int DivergInfo::ComputeThreshold() {
if (activation_quant_method == MAX_MIN) {
this->best_T = std::max(fabs(this->max), fabs(this->min));
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
int DataDistribution::ComputeThreshold() {
if (activation_quant_method_ == MAX_MIN) {
this->best_T_ = std::max(fabs(this->max_), fabs(this->min_));
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T_;
return RET_OK;
}
if (activation_quant_method == REMOVAL_OUTLIER && !this->min_datas.empty()) {
this->percent_result = OutlierMethod(min_datas, max_datas);
this->best_T = std::max(std::fabs(percent_result.first), std::fabs(percent_result.second));
if (activation_quant_method_ == REMOVAL_OUTLIER && !this->min_datas_.empty()) {
this->percent_result_ = OutlierMethod(min_datas_, max_datas_);
this->best_T_ = std::max(std::fabs(percent_result_.first), std::fabs(percent_result_.second));
return RET_OK;
}
int threshold = INT8_MAX + 1;
float min_kl = FLT_MAX;
float after_threshold_sum = std::accumulate(this->histogram.begin() + INT8_MAX + 1, this->histogram.end(), 0.0f);
float after_threshold_sum = std::accumulate(this->histogram_.begin() + INT8_MAX + 1, this->histogram_.end(), 0.0f);
for (int i = INT8_MAX + 1; i < this->bin_num; ++i) {
for (int i = INT8_MAX + 1; i < this->bin_num_; ++i) {
std::vector<float> quantized_histogram(INT8_MAX + 1, 0);
std::vector<float> reference_histogram(this->histogram.begin(), this->histogram.begin() + i);
std::vector<float> reference_histogram(this->histogram_.begin(), this->histogram_.begin() + i);
std::vector<float> expanded_histogram(i, 0);
reference_histogram[i - 1] += after_threshold_sum;
after_threshold_sum -= this->histogram[i];
after_threshold_sum -= this->histogram_[i];
// handle bins for computing KL.
HandleBinForKL(INT8_MAX + 1, i, &quantized_histogram, &expanded_histogram);
auto KLDivergence = [](std::vector<float> p, std::vector<float> q) {
@ -189,41 +190,41 @@ int DivergInfo::ComputeThreshold() {
threshold = i;
}
}
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T
<< " max: " << std::max(fabs(this->max), fabs(this->min));
this->best_T_ = (static_cast<float>(threshold) + 0.5f) * this->interval_;
MS_LOG(DEBUG) << cnode_->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T_
<< " max: " << std::max(fabs(this->max_), fabs(this->min_));
return RET_OK;
}
// Computes the quantization scale as (max_value - min_value) / (quant_max - quant_min) and caches it in a
// member (scale_tmp / scale_). The real range is [-best_T, best_T] unless the method is REMOVAL_OUTLIER,
// in which case the (first, second) pair of percent_result is used as (min, max).
// The span shows two versions (rendered diff): the DivergInfo one returns a (cnode, scale) pair, the
// DataDistribution one returns the scale alone.
// NOTE(review): judging by the other MS_CHECK_TRUE_MSG uses in this file, the macro returns its second
// argument when the condition is false, so the DataDistribution version would return 0 for an empty or
// inverted quant range - confirm against the macro definition.
// NOTE(review): both MS_ASSERT lines require fabs(scale) <= 0.0f, i.e. a zero scale; that looks inverted
// (a non-zero scale is presumably what was meant) - confirm.
std::pair<CNodePtr, float> DivergInfo::GetScale() {
float max_value = this->best_T;
float DataDistribution::GetScale() {
float max_value = this->best_T_;
float min_value = -max_value;
if (this->activation_quant_method == REMOVAL_OUTLIER) {
min_value = percent_result.first;
max_value = percent_result.second;
if (this->activation_quant_method_ == REMOVAL_OUTLIER) {
min_value = percent_result_.first;
max_value = percent_result_.second;
}
MS_CHECK_TRUE_MSG(quant_max - quant_min != 0, {}, "quant_max - quant_min == 0");
float scale = (max_value - min_value) / (quant_max - quant_min);
this->scale_tmp = scale;
MS_ASSERT(fabs(scale) <= 0.0f);
return std::make_pair(this->cnode, scale);
MS_CHECK_TRUE_MSG(quant_max_ - quant_min_ > 0, 0, "quant_max_ - quant_min_ <= 0");
this->scale_ = (max_value - min_value) / (quant_max_ - quant_min_);
MS_ASSERT(fabs(this->scale_) <= 0.0f);
return this->scale_;
}
// Picks the zero point from the configured quant range: INT8_MAX + 1 for [0, UINT8_MAX], 0 for
// [INT_LEAST8_MIN + 1, INT8_MAX]; any other range only logs a warning and keeps 0. For REMOVAL_OUTLIER the
// value is then replaced by round(quant_max - percent_result.second / scale), using the cached scale.
// The span shows two versions (rendered diff): the DivergInfo one returns a (cnode, zero_point) pair, the
// DataDistribution one returns the zero point alone.
// NOTE(review): the MS_CHECK_TRUE_MSG conditions require fabs(scale) <= 0.0f, i.e. a zero scale, right
// before dividing by that scale. Judging by the other uses of the macro in this file (condition must hold,
// otherwise the second argument is returned), a non-zero scale would make the function return early and a
// zero scale would reach the division - the condition looks inverted; confirm and fix in the real file.
// NOTE(review): the division relies on the scale cached by GetScale(); presumably GetScale() must be called
// first - confirm against callers.
std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
// Support for asymmetry in the future
int32_t DataDistribution::GetZeroPoint() {
int zero_point = 0;
if (quant_min == 0 && quant_max == UINT8_MAX) {
if (quant_min_ == 0 && quant_max_ == UINT8_MAX) {
zero_point = INT8_MAX + 1;
} else if (quant_min == INT_LEAST8_MIN + 1 && quant_max == INT8_MAX) {
} else if (quant_min_ == INT_LEAST8_MIN + 1 && quant_max_ == INT8_MAX) {
zero_point = 0;
} else {
MS_LOG(WARNING) << "unexpected quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
MS_LOG(WARNING) << "unexpected quant range, quant_min_: " << quant_min_ << " quant_max_: " << quant_max_;
}
if (this->activation_quant_method == REMOVAL_OUTLIER) {
MS_CHECK_TRUE_MSG(fabs(scale_tmp) <= 0.0f, {}, "fabs(scale_tmp) > 0.0f");
zero_point = std::round(quant_max - percent_result.second / scale_tmp);
if (this->activation_quant_method_ == REMOVAL_OUTLIER) {
MS_CHECK_TRUE_MSG(fabs(scale_) <= 0.0f, 1, "fabs(scale) > 0.0f");
zero_point = std::round(quant_max_ - percent_result_.second / scale_);
}
return std::make_pair(this->cnode, zero_point);
return zero_point;
}
} // namespace mindspore::lite::quant

View File

@ -0,0 +1,84 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DATA_DISTRIBUTION_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DATA_DISTRIBUTION_H
#include <vector>
#include <utility>
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite::quant {
// Holds the statistics gathered for one tensor of a node (running min/max and a histogram with `bins`
// buckets) together with the quantization settings used to turn them into a threshold, scale and zero point.
class DataDistribution {
 public:
  DataDistribution() = default;

  // @param cnode                    node the statistics belong to; the handle is moved into the object.
  // @param bins                     number of histogram buckets.
  // @param bits                     bit width stored in bit_num_.
  // @param quant_max                upper bound of the quantized range.
  // @param quant_min                lower bound of the quantized range.
  // @param activation_quant_method  method stored in activation_quant_method_.
  //
  // Every histogram bucket starts at 1.0e-7 rather than 0 (presumably so later divergence math never sees
  // an exact zero - confirm). The running maximum/minimum start at -FLT_MAX/FLT_MAX.
  // Initializers are listed in member declaration order.
  DataDistribution(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min,
                   ActivationQuantizedMethod activation_quant_method)
      : histogram_(static_cast<size_t>(bins), static_cast<float>(1.0e-7)),
        cnode_(std::move(cnode)),
        bin_num_(bins),
        max_(-FLT_MAX),
        min_(FLT_MAX),
        bit_num_(bits),
        quant_max_(quant_max),
        quant_min_(quant_min),
        activation_quant_method_(activation_quant_method) {}

  int RecordMaxMinValue(const std::vector<float> &data);

  int RecordMaxMinValueArray(const std::vector<float> &data);

  void UpdateInterval();

  int UpdateHistogram(const std::vector<float> &data);

  void DumpHistogram();

  void HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
                      std::vector<float> *expanded_histogram);

  int ComputeThreshold();

  float GetScale();

  int32_t GetZeroPoint();

  // Read-only accessors; const-qualified so they can be used through const references.
  float GetMax() const { return this->max_; }

  float GetMin() const { return this->min_; }

  // Returns a copy of the node handle; assigning to the returned value does not change this object.
  CNodePtr GetCNode() const { return this->cnode_; }

 private:
  std::vector<float> histogram_;
  CNodePtr cnode_;
  int bin_num_ = 0;
  float interval_ = 0;
  float max_ = 0.0f;
  float min_ = 0.0f;
  float best_T_ = 0.0f;
  size_t bit_num_ = 0;
  int quant_max_ = 255;
  int quant_min_ = 0;
  ActivationQuantizedMethod activation_quant_method_ = MAX_MIN;
  std::vector<float> min_datas_;
  std::vector<float> max_datas_;
  std::pair<float, float> percent_result_{0.0, 0.0};
  float scale_ = 0;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DATA_DISTRIBUTION_H

View File

@ -1,84 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DIVERG_INFO_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DIVERG_INFO_H
#include <vector>
#include <utility>
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite::quant {
// Per-node statistics holder: running min/max, a histogram with `bins` buckets, and the quantization
// settings needed to derive a threshold, scale and zero point from them.
class DivergInfo {
 public:
  DivergInfo() = default;

  // Stores the node handle and quantization settings, sets the running maximum/minimum to
  // -FLT_MAX/FLT_MAX and creates `bins` histogram buckets, each holding 1.0e-7.
  DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min,
             ActivationQuantizedMethod activation_quant_method) {
    min = FLT_MAX;
    max = -FLT_MAX;
    bin_num = bins;
    bit_num = bits;
    this->quant_min = quant_min;
    this->quant_max = quant_max;
    this->activation_quant_method = activation_quant_method;
    this->cnode = std::move(cnode);
    // Size and seed the buckets in a single step.
    histogram.assign(bin_num, 1.0e-7);
  }

  int RecordMaxMinValue(const std::vector<float> &data);

  int RecordMaxMinValueArray(const std::vector<float> &data);

  void UpdateInterval();

  int UpdateHistogram(const std::vector<float> &data);

  void DumpHistogram();

  void HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
                      std::vector<float> *expanded_histogram);

  int ComputeThreshold();

  std::pair<CNodePtr, float> GetScale();

  std::pair<CNodePtr, int32_t> GetZeropoint();

  // Plain accessors for the recorded extrema and the node handle (returned by value).
  float GetMax() { return max; }

  float GetMin() { return min; }

  CNodePtr GetCNode() { return cnode; }

 private:
  std::vector<float> histogram;
  CNodePtr cnode;
  int bin_num = 0;
  float interval = 0;
  float max = 0.0f;
  float min = 0.0f;
  float best_T = 0.0f;
  size_t bit_num = 0;
  int quant_max = 255;
  int quant_min = 0;
  ActivationQuantizedMethod activation_quant_method = MAX_MIN;
  std::vector<float> min_datas;
  std::vector<float> max_datas;
  std::pair<float, float> percent_result{0.0, 0.0};
  float scale_tmp = 0;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DIVERG_INFO_H

View File

@ -32,6 +32,7 @@
#include "src/tensor.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_strategy.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/log_adapter.h"
#include "securec/include/securec.h"
@ -128,18 +129,18 @@ int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const s
FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type)
: Quantizer(std::move(graph)) {
MS_ASSERT(graph != nullptr);
this->bit_num = bit_num;
this->target_type_ = target_type;
this->bit_num_ = bit_num;
this->target_data_type_ = target_type;
if (target_type == kNumberTypeInt8) {
quant_max = (1 << (this->bit_num - 1)) - 1; // 127
quant_min = -quant_max; // -127
q_max_ = (1 << (this->bit_num_ - 1)) - 1; // 127
q_min_ = -q_max_; // -127
} else if (target_type == kNumberTypeUInt8) {
quant_max = (1 << this->bit_num) - 1; // 255
quant_min = 0;
q_max_ = (1 << this->bit_num_) - 1; // 255
q_min_ = 0;
} else {
MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
}
calibrator_ = std::make_unique<Calibrator>(this->bit_num, quant_max, quant_min);
calibrator_ = std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_);
if (calibrator_ == nullptr) {
MS_LOG(ERROR) << "create calibrator failed!";
return;
@ -153,7 +154,7 @@ FullQuantQuantizer::~FullQuantQuantizer() {
delete int8_model_;
}
int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DivergInfo> &info,
int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
const PrimitivePtr &primitive, bool is_input, size_t index) const {
auto quant_param_holder = GetCNodeQuantHolder(primitive);
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
@ -164,17 +165,17 @@ int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const s
return RET_ERROR;
}
if (type_id == kNumberTypeFloat32 && info != nullptr) {
auto scale = info->GetScale().second;
auto scale = info->GetScale();
if (scale == 0) {
MS_LOG(WARNING) << "The input or output values are very close to 0, so set the scale to 1.";
quant_param.scale = 1;
} else {
quant_param.scale = scale;
}
quant_param.zeroPoint = info->GetZeropoint().second;
quant_param.zeroPoint = info->GetZeroPoint();
quant_param.max = info->GetMax();
quant_param.min = info->GetMin();
quant_param.numBits = bit_num;
quant_param.numBits = bit_num_;
quant_param.narrowRange = true;
quant_param.inited = true;
quant_param.roundType = 1;
@ -210,13 +211,9 @@ int FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodeP
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_NULL_PTR;
}
auto bit_num_t = bit_num;
auto quant_max_t = quant_max;
auto quant_min_t = quant_min;
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
auto status =
FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_ALL, quant_max_t, quant_min_t,
bit_num_t, weight_quant_type, kNumberTypeInt8, input_index - 1);
auto status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_,
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, true);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
@ -332,6 +329,7 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const AnfNod
if (type_id == kNumberTypeInt8) {
return RET_CONTINUE;
}
// Only data whose data type is fp32 can be quantized.
if (type_id != kNumberTypeFloat32) {
ret = SetInOutQuantParam(input_node, nullptr, primitive, true, input_index - 1);
if (ret != RET_OK) {
@ -348,15 +346,7 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const AnfNod
return ret;
}
} else {
if (opt::CheckPrimitiveType(cnode, prim::kPrimMatMul)) {
if (input_index == FIRST_INPUT + 1) {
ret = DoWeightQuant(op_name, input_node, primitive, false, input_index);
} else {
ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
}
} else {
ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
}
ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Do bias quant failed.";
return ret;
@ -515,31 +505,24 @@ int FullQuantQuantizer::QuantNode() {
}
int FullQuantQuantizer::UpdateDivergeInterval() {
auto ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetInputDivergInfo());
auto ret = this->calibrator_->UpdateDivergInterval();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Update input diverge interval failed.";
return ret;
}
ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetOutputDivergInfo());
if (ret != RET_OK) {
MS_LOG(ERROR) << "Update output diverge interval failed.";
return ret;
}
return RET_OK;
}
/**
* Mark quantifiable nodes
**/
int FullQuantQuantizer::PreProcess() {
auto cnodes = funcGraph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
AnfNodePtr anf = cnode->cast<AnfNodePtr>();
if (anf == nullptr) {
auto anode = cnode->cast<AnfNodePtr>();
if (anode == nullptr) {
MS_LOG(ERROR) << " cnode is null";
return RET_NULL_PTR;
}
if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anf)) {
// Mark quantifiable nodes
if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anode)) {
auto ret = calibrator_->AddQuantizedOp(cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Add Quantized Op failed.";
@ -574,12 +557,44 @@ int FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
return RET_OK;
}
/**
* 1. create input tensor
* 2. insert callback to session
* 3. run session
**/
int FullQuantQuantizer::DoInference() {
// Feeds the tensors of one node into the calibrator.
//   node_name       - key used to look the node up in `diverg_info_map`.
//   tensors         - the node's input or output tensors; const or non-fp32 tensors are skipped.
//   diverg_info_map - per-node map from tensor index to its DataDistribution record.
//   collect_type    - MIN_MAX records min/max values, KL_BIN updates the data frequency.
// Returns RET_OK on success, including when the node is not tracked in the map;
// RET_NULL_PTR when `diverg_info_map` is null; RET_ERROR on any other failure.
int FullQuantQuantizer::CollectDataDistribution(
    const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
    std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
    CollectType collect_type) {
  if (diverg_info_map == nullptr) {
    MS_LOG(ERROR) << "diverg_info_map is nullptr.";
    return RET_NULL_PTR;
  }
  auto node_iter = diverg_info_map->find(node_name);
  if (node_iter == diverg_info_map->end()) {
    // Untracked node: nothing to collect. This must be RET_OK (0); returning `true` (1) from this
    // int function made callers that compare against RET_OK treat the node as a failure.
    return RET_OK;
  }
  auto &tensor_infos = node_iter->second;
  for (size_t i = 0; i < tensors.size(); i++) {
    auto tensor = tensors[i];
    if (tensor == nullptr) {
      MS_LOG(ERROR) << node_name << " tensor is nullptr.";
      return RET_ERROR;
    }
    if (tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
      continue;
    }
    const auto *tensor_data = static_cast<const float *>(tensor->data());
    if (tensor_data == nullptr) {
      MS_LOG(ERROR) << tensor->tensor_name() << " tensor_data is nullptr.";
      return RET_ERROR;
    }
    size_t elem_count = tensor->ElementsNum();
    MS_CHECK_GT(elem_count, 0, RET_ERROR);
    vector<float> data(tensor_data, tensor_data + elem_count);
    // The record is addressed by the tensor's position in `tensors`.
    if (collect_type == MIN_MAX) {
      auto ret = this->calibrator_->RecordMaxMinValue(data, tensor_infos[i]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << tensor->tensor_name() << " record max min value failed.";
        return RET_ERROR;
      }
    } else if (collect_type == KL_BIN) {
      auto ret = this->calibrator_->UpdateDataFrequency(data, tensor_infos[i]);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << tensor->tensor_name() << " update data frequency failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
int FullQuantQuantizer::DoInference(CollectType collect_type) {
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() != calibrator_->GetInputNum()) {
@ -599,34 +614,10 @@ int FullQuantQuantizer::DoInference() {
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const CallBackParam &callParam) -> bool {
auto diverg_info_map = calibrator_->GetInputDivergInfo();
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeOutputs) != RET_OK) {
return true;
}
bool is_init = beforeInputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
if (is_init) {
for (size_t i = 1; i < beforeInputs.size(); i++) {
if (beforeInputs.at(i)->data_type() != kNumberTypeFloat32 || beforeInputs.at(i)->IsConst()) {
continue;
}
auto input_diverg = std::make_unique<DivergInfo>();
MS_CHECK_TRUE_MSG(input_diverg != nullptr, false, "input_diverg is nullptr.");
*input_diverg = *((*diverg_info_map)[callParam.node_name][0]);
(*diverg_info_map)[callParam.node_name].push_back(std::move(input_diverg));
}
}
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
auto tensor = beforeInputs[i];
MS_CHECK_TRUE_MSG(tensor != nullptr, false, "tensor is nullptr.");
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
MS_CHECK_TRUE_MSG(tensor_data != nullptr, false, "tensor_data is nullptr.");
size_t elem_count = tensor->ElementsNum();
MS_CHECK_GT(elem_count, 0, false);
vector<float> data(tensor_data, tensor_data + elem_count);
auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][i]);
MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
auto ret = CollectDataDistribution(callParam.node_name, beforeInputs, diverg_info_map, collect_type);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CollectDataDistribution failed.";
return false;
}
return true;
};
@ -635,31 +626,10 @@ int FullQuantQuantizer::DoInference() {
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
const CallBackParam &callParam) -> bool {
auto diverg_info_map = calibrator_->GetOutputDivergInfo();
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
return true;
}
bool is_init = afterOutputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
if (is_init) {
for (size_t i = 1; i < afterOutputs.size(); i++) {
auto output_diverg = std::make_unique<DivergInfo>();
CHECK_NULL_RETURN(output_diverg);
*output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
(*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg));
}
}
size_t output_i = 0;
for (const auto &tensor : afterOutputs) {
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
CHECK_NULL_RETURN(tensor_data);
size_t elem_count = tensor->ElementsNum();
MS_CHECK_GT(elem_count, 0, false);
vector<float> data(tensor_data, tensor_data + elem_count);
auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][output_i]);
MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
output_i++;
auto ret = CollectDataDistribution(callParam.node_name, afterOutputs, diverg_info_map, collect_type);
if (ret != RET_OK) {
MS_LOG(ERROR) << "CollectDataDistribution failed.";
return false;
}
return true;
};
@ -735,14 +705,14 @@ int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
MS_LOG(ERROR) << "divisor 'calibrate_size' cannot be 0.";
return RET_ERROR;
}
for (auto &key_value : op_bias_diff_map) {
for (auto &key_value : op_bias_diff_map_) {
std::for_each(key_value.second.begin(), key_value.second.end(),
[this](float &data) { data = data / calibrator_->GetBatchNum(); });
}
auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto op_name = cnode->fullname_with_scope();
if (op_bias_diff_map.find(op_name) == op_bias_diff_map.end()) {
if (op_bias_diff_map_.find(op_name) == op_bias_diff_map_.end()) {
continue;
}
status = BiasCorrection(func_graph, cnode);
@ -756,7 +726,7 @@ int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto op_name = cnode->fullname_with_scope();
const auto &bias_diff = op_bias_diff_map[op_name];
const auto &bias_diff = op_bias_diff_map_[op_name];
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
@ -839,89 +809,6 @@ int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNo
return RET_OK;
}
// Runs the fp32 session once per calibration batch and, through the before/after kernel
// callbacks, passes the input and output tensor data of tracked nodes to
// Calibrator::UpdateDataFrequency.
// Returns RET_OK when every batch ran successfully, RET_ERROR otherwise.
int FullQuantQuantizer::CollectDataFrequency() {
// Fetch the session's input tensors; their count must match what the calibrator expects.
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() != calibrator_->GetInputNum()) {
MS_LOG(ERROR) << "model's input tensor cnt: " << inputs.size() << " != " << calibrator_->GetInputNum();
return RET_ERROR;
}
for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
// Fill every model input with the calibration data of batch i.
for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
int status = calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), i, inputs[input_index]);
if (status != RET_OK) {
MS_LOG(ERROR) << "generate input data from images failed!";
return RET_ERROR;
}
}
// Before-callback: updates the frequency data of the node's inputs.
KernelCallBack before_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
const CallBackParam &callParam) {
auto diverg_info_map = calibrator_->GetInputDivergInfo();
// Nodes that are not tracked, or whose tensors fail the fp32 check, are skipped by returning true.
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
return true;
}
// input_i only advances for tensors that are actually processed (fp32 and non-const).
int input_i = 0;
for (auto tensor : before_inputs) {
if (tensor->data_type() != kNumberTypeFloat32 || tensor->IsConst()) {
continue;
}
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
// NOTE(review): MS_ASSERT may be compiled out in release builds, leaving this unchecked - confirm.
MS_ASSERT(tensor_data != nullptr);
size_t elem_count = tensor->ElementsNum();
MS_CHECK_GT(elem_count, 0, false);
vector<float> data(tensor_data, tensor_data + elem_count);
auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name][input_i++]);
if (ret != RET_OK) {
return false;
}
}
return true;
};
// After-callback: updates the frequency data of the node's outputs.
KernelCallBack after_callBack = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
const CallBackParam &call_param) {
auto diverg_info_map = calibrator_->GetOutputDivergInfo();
if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) {
return true;
}
if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
return true;
}
int output_i = 0;
// Unlike the inputs, outputs are not filtered per tensor: all of them are assumed to share one dtype.
for (const auto &tensor : after_outputs) {
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
// NOTE(review): same MS_ASSERT caveat as in the before-callback.
MS_ASSERT(tensor_data != nullptr);
size_t elem_count = tensor->ElementsNum();
MS_CHECK_GT(elem_count, 0, false);
vector<float> data(tensor_data, tensor_data + elem_count);
auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[call_param.node_name][output_i++]);
if (ret != RET_OK) {
return false;
}
}
return true;
};
// Thread binding is enabled only for the duration of the graph run.
fp32_session_->BindThread(true);
auto status = fp32_session_->RunGraph(before_callback, after_callBack);
fp32_session_->BindThread(false);
if (status != RET_OK) {
MS_LOG(ERROR) << "run model failed!";
return RET_ERROR;
}
}
return RET_OK;
}
int FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
@ -944,7 +831,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
// anf -- fb
flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
MS_LOG(INFO) << "start create session";
auto sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum(), false);
auto sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
fp32_session_ = sm.session;
fp32_model_ = sm.model;
if (fp32_session_ == nullptr || fp32_model_ == nullptr) {
@ -952,7 +839,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
return RET_ERROR;
}
MS_LOG(INFO) << "start to update divergence's max value";
status = DoInference();
status = DoInference(MIN_MAX);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do inference failed.";
return status;
@ -964,7 +851,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
return status;
}
MS_LOG(INFO) << "start to collect data's distribution";
status = CollectDataFrequency();
status = DoInference(KL_BIN);
if (status != RET_OK) {
MS_LOG(ERROR) << "Collect data frequency failed.";
return status;
@ -995,7 +882,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
// init in8 session
MS_LOG(INFO) << "create quant session";
flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum(), false);
int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
int8_session_ = int8_sm.session;
int8_model_ = int8_sm.model;
if (int8_session_ == nullptr || int8_model_ == nullptr) {
@ -1013,21 +900,21 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
MS_ASSERT(data != nullptr);
std::lock_guard<std::mutex> lg(mutex_op_input);
std::lock_guard<std::mutex> lg(mutex_op_input_);
if (type == STORE) {
if (fp32_op_input_map.find(op_name) != fp32_op_input_map.end()) {
if (fp32_op_input_map_.find(op_name) != fp32_op_input_map_.end()) {
// the data has not been fetched by int8 model
return false;
}
fp32_op_input_map[op_name] = *data;
fp32_op_input_map_[op_name] = *data;
return true;
} else if (type == FETCH) {
if (fp32_op_input_map.find(op_name) == fp32_op_input_map.end()) {
if (fp32_op_input_map_.find(op_name) == fp32_op_input_map_.end()) {
// the data not generated by fp32 model yet
return false;
}
*data = fp32_op_input_map[op_name];
fp32_op_input_map.erase(op_name);
*data = fp32_op_input_map_[op_name];
fp32_op_input_map_.erase(op_name);
return true;
} else {
MS_LOG(ERROR) << "unexpected type: " << type;
@ -1037,21 +924,21 @@ bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_
bool FullQuantQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
MS_ASSERT(data != nullptr);
std::lock_guard<std::mutex> lg(mutex_op_output);
std::lock_guard<std::mutex> lg(mutex_op_output_);
if (type == STORE) {
if (fp32_op_output_ch_mean_map.find(op_name) != fp32_op_output_ch_mean_map.end()) {
if (fp32_op_output_ch_mean_map_.find(op_name) != fp32_op_output_ch_mean_map_.end()) {
// the data has not been fetched by int8 model
return false;
}
fp32_op_output_ch_mean_map[op_name] = *data;
fp32_op_output_ch_mean_map_[op_name] = *data;
return true;
} else if (type == FETCH) {
if (fp32_op_output_ch_mean_map.find(op_name) == fp32_op_output_ch_mean_map.end()) {
if (fp32_op_output_ch_mean_map_.find(op_name) == fp32_op_output_ch_mean_map_.end()) {
// the data not generated by fp32 model yet
return false;
}
*data = fp32_op_output_ch_mean_map[op_name];
fp32_op_output_ch_mean_map.erase(op_name);
*data = fp32_op_output_ch_mean_map_[op_name];
fp32_op_output_ch_mean_map_.erase(op_name);
return true;
} else {
MS_LOG(ERROR) << "unexpected type: " << type;
@ -1074,8 +961,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
size_t elem_count = tensor->ElementsNum();
MS_CHECK_GT(elem_count, 0, false);
std::vector<float> fp32_op_input(elem_count);
auto ret =
memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size());
auto ret = memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->data(), tensor->Size());
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return false;
@ -1112,7 +998,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
quant_param_t.scale = quant_params[0].scale;
quant_param_t.zeroPoint = quant_params[0].zeroPoint;
for (auto float_data : fp32_op_input) {
auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, quant_max, quant_min);
auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, q_max_, q_min_);
quant_datas.push_back(quant_data);
}
@ -1122,8 +1008,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
return false;
}
auto ret =
memcpy_s(tensor->MutableData(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
auto ret = memcpy_s(tensor->data(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return false;
@ -1158,7 +1043,7 @@ KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
return false;
}
const int8_t *tensor_data = static_cast<int8_t *>(tensor->MutableData());
const int8_t *tensor_data = static_cast<int8_t *>(tensor->data());
size_t elem_count = tensor->ElementsNum();
MS_CHECK_GT(elem_count, 0, false);
auto shapes = tensor->shape();
@ -1203,12 +1088,12 @@ KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
std::transform(fp32_op_output_ch_mean.begin(), fp32_op_output_ch_mean.end(), dequant_op_output_ch_mean.begin(),
dequant_op_output_ch_mean.begin(), std::minus<>());
if (op_bias_diff_map.find(callParam.node_name) != op_bias_diff_map.end()) {
auto &bias_diff = op_bias_diff_map[callParam.node_name];
if (op_bias_diff_map_.find(callParam.node_name) != op_bias_diff_map_.end()) {
auto &bias_diff = op_bias_diff_map_[callParam.node_name];
std::transform(bias_diff.begin(), bias_diff.end(), dequant_op_output_ch_mean.begin(), bias_diff.begin(),
std::plus<>());
} else {
op_bias_diff_map[callParam.node_name] = dequant_op_output_ch_mean;
op_bias_diff_map_[callParam.node_name] = dequant_op_output_ch_mean;
}
}
return true;
@ -1226,7 +1111,7 @@ KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
}
auto tensor = afterOutputs[0];
MS_ASSERT(tensor != nullptr);
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
const auto *tensor_data = static_cast<const float *>(tensor->data());
size_t elem_count = tensor->ElementsNum();
MS_CHECK_GT(elem_count, 0, false);
auto shapes = tensor->shape();

View File

@ -34,13 +34,17 @@
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "tools/converter/quantizer/calibrator.h"
#include "tools/converter/quantizer/diverg_info.h"
#include "tools/converter/quantizer/data_distribution.h"
namespace mindspore::lite::quant {
// Direction of an access to the cached per-op data shared between the fp32 and int8 runs.
enum OperationType {
// Save the given data under the op name; refused while an earlier entry is still unfetched.
STORE,
// Copy the saved data out and erase the entry; refused when no entry exists yet.
FETCH,
};
// Selects which statistic an inference pass collects for each tracked tensor.
enum CollectType {
// Record the minimum/maximum values of the tensor data.
MIN_MAX,
// Update the data frequency of the tensor data.
// NOTE(review): the name suggests this feeds a KL-divergence histogram - confirm.
KL_BIN,
};
class FullQuantQuantizer : public Quantizer {
public:
FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8);
@ -54,22 +58,19 @@ class FullQuantQuantizer : public Quantizer {
int PreProcess();
static int CheckFp32TensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec);
int CheckFp32TensorVec(const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensor_vec);
int DoInference();
int DoInference(CollectType collect_type);
int UpdateDivergeInterval();
int CollectDataFrequency();
int ComputeThreshold();
int QuantNodeSimpleOp(const CNodePtr &cnode);
int QuantNode();
int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DivergInfo> &info,
int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
const PrimitivePtr &primitive, bool is_input, size_t index) const;
int DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, const PrimitivePtr &primitive,
@ -77,7 +78,12 @@ class FullQuantQuantizer : public Quantizer {
int DoParameterNodeQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index);
static int DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive);
int CollectDataDistribution(
const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
CollectType collect_type);
int DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive);
int Int8Inference();
int BiasCorrection(const FuncGraphPtr &func_graph);
int BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
@ -87,22 +93,22 @@ class FullQuantQuantizer : public Quantizer {
KernelCallBack GetFloatAfterCallBack();
private:
TypeId target_type_{kNumberTypeInt8};
TypeId target_data_type_{kNumberTypeInt8};
std::unique_ptr<Calibrator> calibrator_{nullptr};
session::LiteSession *fp32_session_{nullptr};
Model *fp32_model_{nullptr};
session::LiteSession *int8_session_{nullptr};
Model *int8_model_{nullptr};
std::map<std::string, std::vector<float>> fp32_op_input_map; // concurrency
std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map; // concurrency
std::map<std::string, std::vector<float>> op_bias_diff_map; // only use by int8 model
std::mutex mutex_op_input;
std::mutex mutex_op_output;
std::map<std::string, std::vector<float>> fp32_op_input_map_; // concurrency
std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map_; // concurrency
std::map<std::string, std::vector<float>> op_bias_diff_map_; // only use by int8 model
std::mutex mutex_op_input_;
std::mutex mutex_op_output_;
size_t bit_num;
int quant_max{INT8_MAX};
int quant_min{INT8_MIN};
size_t bit_num_;
int q_max_{INT8_MAX};
int q_min_{INT8_MIN};
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H

View File

@ -0,0 +1,123 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/quant_strategy.h"
#include <set>
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "base/core_ops.h"
#include "src/common/log_adapter.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"
namespace mindspore::lite::quant {
bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) {
if (input_node == nullptr) {
MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
return false;
}
ParameterPtr param_node = nullptr;
if (input_node->isa<Parameter>()) {
param_node = input_node->cast<ParameterPtr>();
}
if (param_node == nullptr) {
MS_LOG(INFO) << "CanTensorQuantized invalid param_node!";
return false;
}
if (!param_node->has_default()) {
MS_LOG(INFO) << "param_node don't has default.";
return false;
}
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(INFO) << "abstract is nullptr";
return false;
}
if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << param_node->name();
return false;
}
auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
MS_ASSERT(weight_shape != nullptr);
if (weight_shape.size() < DIMENSION_2D) { // do not quant single dim tensors
return false;
}
int64_t total_shape_size = 1;
for (auto shape : weight_shape) {
MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(total_shape_size, shape), RET_ERROR, "Int mul overflow");
total_shape_size *= shape;
}
if (total_shape_size < 0 || static_cast<size_t>(total_shape_size) < min_quant_weight_size_) {
MS_LOG(INFO) << "shape_size " << total_shape_size << " less min_quant_weight_size_ " << min_quant_weight_size_;
return false;
}
// min_quant_weight_channel_ only supports convolution
if (weight_shape.size() > DIMENSION_2D &&
weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel_)) {
MS_LOG(INFO) << "preferred_dim shape:" << weight_shape[preferred_dim] << " less min_quant_weight_channel_ "
<< min_quant_weight_channel_;
return false;
}
return true;
}
bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
MS_CHECK_TRUE_RET(node != nullptr, false);
if (!node->isa<mindspore::CNode>()) {
return false;
}
const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
MS_ASSERT(cnode != nullptr);
auto type = NodePrimitiveType(cnode);
static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion, prim::kPrimActivation,
prim::kPrimAvgPoolFusion, prim::kPrimConcat,
prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimCrop, prim::kPrimFullConnection,
prim::kPrimGather, prim::kPrimLayerNormFusion,
prim::kPrimMatMul, prim::kPrimMaxPoolFusion,
prim::kPrimMulFusion, prim::kPrimReshape,
prim::kPrimSplit, prim::kPrimTranspose,
prim::kPrimReduceFusion, prim::kPrimDivFusion,
prim::kPrimSqrt, prim::kPrimPowFusion,
prim::kPrimUnsqueeze, prim::kPrimAffine};
// The return node does not need to be quantified.
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn) || opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
return false;
}
// These operators do not need to check the data type.
if (opt::CheckPrimitiveType(cnode, prim::kPrimShape) || opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
return true;
}
auto is_support_node = CheckNodeInSet(cnode, support_int8_ops);
if (!is_support_node && type != "Eltwise") {
MS_LOG(WARNING) << "node:" << cnode->fullname_with_scope() << " type:" << type << " is not support quantization.";
return false;
}
TypeId type_id;
auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
return false;
}
bool is_data_type_fp32 = type_id == kNumberTypeFloat32;
if (!is_data_type_fp32) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " type_id is " << type_id << " , and is not float32.";
}
return is_data_type_fp32;
}
} // namespace mindspore::lite::quant

View File

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
#include <cstddef>
#include "ir/anf.h"
namespace mindspore::lite::quant {
// Quantization eligibility rules.
// Instances store two thresholds, a minimum weight size and a minimum weight channel count,
// supplied at construction. CanOpFullQuantized is static and needs no instance.
class QuantStrategy {
 public:
  QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel)
      : min_quant_weight_size_{min_quant_weight_size}, min_quant_weight_channel_{min_quant_weight_channel} {}

  ~QuantStrategy() = default;

  // Tensor-level check; `preferred_dim` is the index of the tensor's channel dimension.
  bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim);

  // Operator-level check for full quantization.
  static bool CanOpFullQuantized(const AnfNodePtr &node);

 private:
  size_t min_quant_weight_size_;
  size_t min_quant_weight_channel_;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H

View File

@ -17,17 +17,11 @@
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <cmath>
#include <string>
#include <map>
#include <fstream>
#include <algorithm>
#include <memory>
#include <vector>
#include <set>
#include <functional>
#include "include/version.h"
#include "ops/affine.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/fusion/full_connection.h"
#include "ops/mat_mul.h"
#include "tools/converter/ops/ops_def.h"
@ -44,9 +38,6 @@ using std::string;
using std::vector;
namespace mindspore::lite::quant {
constexpr int kDim2 = 2;
constexpr int kDim4 = 4;
const int kLstmInputWeightIndex = 1;
const int kLstmStateWeightIndex = 2;
const int kLstmWeightShapeSize = 3;
@ -54,102 +45,6 @@ const int kSingleDirBiasTensorSize = 4;
const int kLstmBiasShapeSize = 2;
const int kLstmBiasIndex = 3;
bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
MS_CHECK_TRUE_RET(node != nullptr, false);
if (!node->isa<mindspore::CNode>()) {
return false;
}
const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
MS_ASSERT(cnode != nullptr);
auto type = NodePrimitiveType(cnode);
static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion, prim::kPrimActivation,
prim::kPrimAvgPoolFusion, prim::kPrimConcat,
prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
prim::kPrimCrop, prim::kPrimFullConnection,
prim::kPrimGather, prim::kPrimLayerNormFusion,
prim::kPrimMatMul, prim::kPrimMaxPoolFusion,
prim::kPrimMulFusion, prim::kPrimReshape,
prim::kPrimSplit, prim::kPrimTranspose,
prim::kPrimReduceFusion, prim::kPrimDivFusion,
prim::kPrimSqrt, prim::kPrimPowFusion,
prim::kPrimUnsqueeze, prim::kPrimAffine};
// The return node does not need to be quantified.
if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn) || opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
return false;
}
// These operators do not need to check the data type.
if (opt::CheckPrimitiveType(cnode, prim::kPrimShape) || opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
return true;
}
auto is_support_node = CheckNodeInSet(cnode, support_int8_ops);
if (!is_support_node && type != "Eltwise") {
MS_LOG(WARNING) << "node:" << cnode->fullname_with_scope() << " type:" << type << " is not support quantization.";
return false;
}
TypeId type_id;
auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
return false;
}
bool is_data_type_fp32 = type_id == kNumberTypeFloat32;
if (!is_data_type_fp32) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " type_id is " << type_id << " , and is not float32.";
}
return is_data_type_fp32;
}
// Decides whether the weight tensor behind `input_node` should be quantized.
//
// input_node:    expected to be a Parameter node carrying a default value.
// preferred_dim: index of the dimension treated as the channel axis; it is only
//                consulted for tensors with more than two dimensions.
//
// Returns true only when every check below passes:
//   - the node is a non-null Parameter with a default value, an abstract and a
//     ShapePtr shape track;
//   - the shape has at least two dimensions;
//   - the total element count does not overflow, is non-negative and is not
//     below min_quant_weight_size_;
//   - for tensors with more than two dimensions, `preferred_dim` is a valid
//     index and that dimension is larger than min_quant_weight_channel_.
bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) const {
  if (input_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
    return false;
  }
  ParameterPtr param_node = nullptr;
  if (input_node->isa<Parameter>()) {
    param_node = input_node->cast<ParameterPtr>();
  }
  if (param_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized invalid param_node!";
    return false;
  }
  if (!param_node->has_default()) {
    MS_LOG(INFO) << "param_node don't has default.";
    return false;
  }
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(INFO) << "abstract is nullptr";
    return false;
  }
  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << param_node->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  if (weight_shape.size() < kDim2) {  // do not quant single dim tensors
    return false;
  }
  int64_t total_shape_size = 1;
  for (auto shape : weight_shape) {
    // This function returns bool, so the failure value must be `false`.
    // An int status such as RET_ERROR is non-zero and would convert to `true`,
    // reporting an overflowing shape as quantizable.
    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(total_shape_size, shape), false, "Int mul overflow");
    total_shape_size *= shape;
  }
  if (total_shape_size < 0 || static_cast<size_t>(total_shape_size) < min_quant_weight_size_) {
    MS_LOG(INFO) << "shape_size " << total_shape_size << " less min_quant_weight_size_ " << min_quant_weight_size_;
    return false;
  }
  // min_quant_weight_channel_ only supports convolution
  if (weight_shape.size() > kDim2) {
    // Reject an out-of-range channel index instead of reading past the shape vector.
    if (preferred_dim < 0 || static_cast<size_t>(preferred_dim) >= weight_shape.size()) {
      MS_LOG(INFO) << "preferred_dim " << preferred_dim << " is out of range of shape size " << weight_shape.size();
      return false;
    }
    if (weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel_)) {
      MS_LOG(INFO) << "preferred_dim shape:" << weight_shape[preferred_dim] << " less min_quant_weight_channel_ "
                   << min_quant_weight_channel_;
      return false;
    }
  }
  return true;
}
QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
QuantParamHolderPtr quant_params_holder = nullptr;
@ -359,7 +254,7 @@ std::string NodePrimitiveType(const CNodePtr &cnode) {
}
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
int *size, bool is_debug) {
int *size) {
SessionModel sm;
auto meta_graph = Export(func_graph, true, true);
if (meta_graph == nullptr) {
@ -414,19 +309,15 @@ SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const conv
delete model;
return sm;
}
if (!is_debug) {
model->Free();
}
delete meta_graph;
sm.session = session;
sm.model = model;
return sm;
}
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
bool is_debug) {
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num) {
int size = 0;
return CreateSessionByFuncGraph(func_graph, flags, thread_num, &size, is_debug);
return CreateSessionByFuncGraph(func_graph, flags, thread_num, &size);
}
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
@ -550,7 +441,7 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
MS_LOG(ERROR) << " shape vector is empty.";
return;
}
if (primitive.value.type == schema::PrimitiveType_MatMul && static_cast<int>(shapes.size()) == kDim2) {
if (primitive.value.type == schema::PrimitiveType_MatMul && static_cast<int>(shapes.size()) == DIMENSION_2D) {
auto matmul_prim = primitive.value.AsMatMul();
MS_ASSERT(matmul_prim != nullptr);
*channel_at_first = index != 1 || matmul_prim->transpose_b;

View File

@ -60,36 +60,15 @@ constexpr size_t kMaxBit = 8;
constexpr size_t kMaxNum1024 = 1024;
constexpr float kPercentBase = 100.0;
constexpr size_t kMillisecondsBase = 10;
constexpr float delta = 0.1;
constexpr float ratio = 10.0;
constexpr int percent = 10;
// Pairs a lite session with the model it was created from.
// Both pointers default to nullptr.
// NOTE(review): ownership of the two raw pointers is not expressed by the type — confirm who deletes them.
struct SessionModel {
  session::LiteSession *session = nullptr;
  Model *model = nullptr;
};
/**
 * Holds the thresholds used to decide whether a weight tensor is quantized:
 * 1. min_quant_weight_size: threshold for the total element count of a weight tensor.
 * 2. min_quant_weight_channel: threshold for the extent of the preferred (channel) dimension
 *    of a weight tensor.
 * The op-level eligibility check is static and does not read these thresholds.
 * */
class QuantStrategy {
 public:
  QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel)
      : min_quant_weight_size_{min_quant_weight_size}, min_quant_weight_channel_{min_quant_weight_channel} {}

  ~QuantStrategy() = default;

  // Returns whether the op represented by `node` may take part in full quantization.
  static bool CanOpFullQuantized(const AnfNodePtr &node);

  // Returns whether the tensor behind `input_node` may be quantized;
  // `preferred_dim` selects the dimension treated as the channel axis.
  bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) const;

 private:
  size_t min_quant_weight_size_;     // element-count threshold
  size_t min_quant_weight_channel_;  // channel-extent threshold
};
constexpr float delta = 0.1;
constexpr float ratio = 10.0;
constexpr int percent = 10;
QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive);
std::pair<float, float> OutlierMethod(std::vector<float> min_datas, std::vector<float> max_datas);
@ -117,7 +96,8 @@ std::vector<int> ConvertShapeVectorToInt32(const ShapeVector &dims);
template <typename T>
int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
QuantType quant_type, int quant_max, int quant_min, size_t bit_num,
WeightQuantType weight_quant_type, TypeId quant_data_type, int index, bool k_means = false) {
WeightQuantType weight_quant_type, TypeId quant_data_type, int index, bool narrow_range = false,
bool k_means = false) {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive != nullptr);
auto dims = weight->shape();
@ -143,7 +123,7 @@ int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &
ret =
DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(),
static_cast<mindspore::schema::QuantType>(quant_type), &quant_params, quant_max, quant_min,
bit_num, k_means, &quant_data, ConvertShapeVectorToInt32(dims), preferred_dim);
bit_num, &quant_data, ConvertShapeVectorToInt32(dims), preferred_dim, narrow_range, k_means);
if (ret == RET_NO_CHANGE) {
return ret;
} else if (ret != RET_OK) {
@ -152,7 +132,7 @@ int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &
}
} else if (weight_quant_type == FIXED_BIT_PER_LAYER) {
ret = DoPerLayerQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
quant_min, bit_num, k_means, &quant_data);
quant_min, bit_num, &quant_data, narrow_range, k_means);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Do per layer quant failed.";
return ret;
@ -188,10 +168,9 @@ int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &
std::string NodePrimitiveType(const CNodePtr &cnode);
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num);
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
bool is_debug = false);
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
int *size, bool is_debug = false);
int *size);
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types);

View File

@ -28,6 +28,7 @@
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quant_strategy.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "ir/func_graph.h"
#include "ir/anf.h"