forked from mindspore-Ecosystem/mindspore
!26657 abstract QuantStrategy && optimize full quant prepare
Merge pull request !26657 from yeyunpeng2020/quant_bak
This commit is contained in:
commit 8d38efb2c4
@@ -20,8 +20,18 @@
namespace mindspore {
namespace lite {
// `symmetric` == true -> q range is [-127 , 127];
// abs_max = max(abs(r_min),abs(r_max)); r_min = -abs_max and r_max = abs_max.
// `symmetric` == false q range is [-128 , 127]. r_min or r_max keep the original value.
// `narrow_range` is used to adjust q_min, and symmetric is always true.
int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
                          bool narrow_range) {
                          bool symmetric, bool narrow_range) {
  if (symmetric) {
    auto abs_max = std::max(std::abs(real_min), std::abs(real_max));
    real_min = -abs_max;
    real_max = abs_max;
    narrow_range = true;
  }
  int quant_max = QuantMax(num_bits);
  int quant_min = QuantMin(num_bits, false, narrow_range);
  return CalQuantizationParams(quant_param, real_min, real_max, num_bits, narrow_range, quant_min, quant_max);
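Reviewer note: the symmetric branch above folds the real range into +/-abs_max before scale and zero point are derived. A minimal standalone sketch of that derivation for 8-bit data (hypothetical helper, not the project's CalQuantizationParams; assumes scale = (r_max - r_min) / (q_max - q_min)):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Hypothetical illustration of symmetric vs. asymmetric 8-bit parameter derivation.
struct SimpleQuantParam {
  double scale;
  int zero_point;
};

SimpleQuantParam DeriveParams(double real_min, double real_max, bool symmetric) {
  int quant_min = symmetric ? -127 : -128;  // narrow range when symmetric
  int quant_max = 127;
  if (symmetric) {
    double abs_max = std::max(std::abs(real_min), std::abs(real_max));
    real_min = -abs_max;
    real_max = abs_max;
  }
  double scale = (real_max - real_min) / (quant_max - quant_min);
  // zero_point maps real 0.0 onto the integer grid; it is 0 in the symmetric case.
  int zero_point = static_cast<int>(std::round(quant_min - real_min / scale));
  return {scale, zero_point};
}

int main() {
  auto p = DeriveParams(-0.3, 1.2, /*symmetric=*/true);
  std::printf("scale=%f zero_point=%d\n", p.scale, p.zero_point);
  return 0;
}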
@@ -100,8 +110,8 @@ int GetBucketIndex(const std::vector<int> &dims, int preferred_dim, int data_ind
  return (data_index / stride) % bucket_count;
}

int GetAllChannelMinMmax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
                         std::map<int, MinMax> *per_channel_min_max) {
int GetAllChannelMinMax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
                        std::map<int, MinMax> *per_channel_min_max) {
  // the key is bucket_index
  std::map<int, std::vector<float>> sorted_data;
  for (int i = 0; i < elem_count; ++i) {
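Reviewer note: the bucket index used above groups elements by their position along preferred_dim. A small self-contained sketch of that computation (the return expression is taken from the hunk; the stride computation is an assumption, and the helper name is illustrative only):

#include <cstdio>
#include <vector>

// Illustrative only: which channel "bucket" a flattened element belongs to,
// given the dimension chosen for per-channel quantization.
int BucketIndexOf(const std::vector<int> &dims, int preferred_dim, int data_index) {
  int bucket_count = dims[preferred_dim];
  // Assumed: stride = product of the dimensions after preferred_dim.
  int stride = 1;
  for (size_t i = preferred_dim + 1; i < dims.size(); ++i) {
    stride *= dims[i];
  }
  return (data_index / stride) % bucket_count;
}

int main() {
  std::vector<int> dims = {4, 3, 2};  // e.g. out_channels x h x w
  // Element 7 of the flattened tensor falls in channel 7 / (3*2) = 1 when preferred_dim == 0.
  std::printf("bucket=%d\n", BucketIndexOf(dims, 0, 7));
  return 0;
}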
@@ -172,10 +182,11 @@ int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vec
    average_dequants[bucket_index] = total_dequants[bucket_index] / bucket_volume;
  }

  constexpr int pow_exponent = 2;
  for (size_t data_index = 0; data_index < elem_count; data_index++) {
    auto bucket_index = GetBucketIndex(dims, preferred_dim, data_index);
    var_raws[bucket_index] += std::pow(raw_datas[data_index] - average_raws[bucket_index], 2);
    var_dequants[bucket_index] += std::pow(dequant_datas[data_index] - average_dequants[bucket_index], 2);
    var_raws[bucket_index] += std::pow(raw_datas[data_index] - average_raws[bucket_index], pow_exponent);
    var_dequants[bucket_index] += std::pow(dequant_datas[data_index] - average_dequants[bucket_index], pow_exponent);
  }
  for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
    var_raws[bucket_index] = std::sqrt(var_raws[bucket_index] / bucket_volume);
@@ -17,7 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
#define MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_

#include <float.h>
#include <cfloat>
#include <cmath>
#include <climits>
#include <limits>
@@ -66,14 +66,14 @@ int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, dou
                          bool narrow_range, int quant_min, int quant_max);

int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
                          bool narrow_range);
                          bool symmetric, bool narrow_range = false);

template <typename T>
T QuantizeData(float origin_data, const schema::QuantParamT *quantParam, int quant_max, int quant_min) {
  MS_ASSERT(quantParam != nullptr);
  MS_ASSERT(quantParam->inited);
  const auto scale = quantParam->scale;
  const int zero_point = quantParam->zeroPoint;
T QuantizeData(float origin_data, const schema::QuantParamT *quant_param, int quant_max, int quant_min) {
  MS_ASSERT(quant_param != nullptr);
  MS_ASSERT(quant_param->inited);
  const auto scale = quant_param->scale;
  const int zero_point = quant_param->zeroPoint;
  if (scale <= SCALE_THREASHOLD) {
    return 0;
  }
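Reviewer note: the body of QuantizeData after the scale check is not shown in this hunk. A generic sketch of the affine quantization step such a helper usually performs (assumed behavior, not the project's implementation):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Generic sketch of affine quantization: q = clamp(round(r / scale) + zero_point, quant_min, quant_max).
template <typename T>
T AffineQuantize(float value, double scale, int zero_point, int quant_min, int quant_max) {
  int q = static_cast<int>(std::round(value / scale)) + zero_point;
  q = std::max(quant_min, std::min(quant_max, q));
  return static_cast<T>(q);
}

int main() {
  // With scale 0.02 and zero_point 0, 1.0f maps to 50.
  return AffineQuantize<int8_t>(1.0f, 0.02, 0, -127, 127) == 50 ? 0 : 1;
}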
@@ -101,8 +101,8 @@ T QuantizeData(const float origin_data, const schema::QuantParamT *quant_param)

template <typename T>
int DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
                    const int &quant_max, const int &quant_min, const size_t &bit_num, const bool &k_means,
                    std::vector<T> *quant_datas) {
                    const int &quant_max, const int &quant_min, const size_t &bit_num, std::vector<T> *quant_datas,
                    bool narrow_range = false, bool k_means = false) {
  if (k_means) {
    MS_LOG(ERROR) << "Unsupported K-means.";
    return RET_ERROR;
@@ -115,7 +115,7 @@ int DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schem
  }

  schema::QuantParamT quant_param;
  int status = CalQuantizationParams(&quant_param, min, max, bit_num, false, quant_min, quant_max);
  int status = CalQuantizationParams(&quant_param, min, max, bit_num, narrow_range, quant_min, quant_max);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
    return status;
@@ -137,8 +137,8 @@ int GetBucketIndex(const std::vector<int> &dims, int preferred_dim, int data_ind
int CalPerChannelGain(size_t bit_num, const std::vector<int> &dims, int preferred_dim);

// Get the min max of each channel
int GetAllChannelMinMmax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
                         std::map<int, MinMax> *per_channel_min_max);
int GetAllChannelMinMax(const float *raw_datas, int elem_count, const std::vector<int> &dims, int preferred_dim,
                        std::map<int, MinMax> *per_channel_min_max);

// Calculate the distribution difference between quant and origin
int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
@@ -147,8 +147,12 @@ int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vec
template <typename T>
int DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::QuantType &quant_type,
                      std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
                      const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas,
                      const std::vector<int> &dims, int preferred_dim) {
                      const size_t &bit_num, std::vector<T> *quant_datas, const std::vector<int> &dims,
                      int preferred_dim, bool narrow_range = false, bool k_means = false) {
  if (k_means) {
    MS_LOG(ERROR) << "Unsupported K-means.";
    return RET_ERROR;
  }
  int ret;
  auto count = std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<>());
  if (static_cast<size_t>(count) != elem_count) {
@@ -167,7 +171,7 @@ int DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::Q
  std::vector<float> dequant_datas(quant_datas->size());
  // the key is bucket_index
  std::map<int, MinMax> per_channel_min_max;
  ret = GetAllChannelMinMmax(raw_datas, elem_count, dims, preferred_dim, &per_channel_min_max);
  ret = GetAllChannelMinMax(raw_datas, elem_count, dims, preferred_dim, &per_channel_min_max);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Get all channel min max failed.";
    return ret;
@@ -178,7 +182,7 @@ int DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::Q
    float min = min_max_map.second.min;
    float max = min_max_map.second.max;
    schema::QuantParamT quant_param;
    ret = CalQuantizationParams(&quant_param, min, max, bit_num, false, quant_min, quant_max);
    ret = CalQuantizationParams(&quant_param, min, max, bit_num, narrow_range, quant_min, quant_max);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Cal quantization params failed.";
      return ret;
@@ -125,11 +125,11 @@ int TrainExport::QuantTensorData(schema::TensorT *dest_tensor, const lite::Tenso
  STATUS ret = RET_OK;
  if (channels == kPerTensor) {
    ret = DoPerLayerQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                  &(quant_params), quant_max, quant_min, bit_num, false, &data);
                                  &(quant_params), quant_max, quant_min, bit_num, &data, false, false);
  } else {
    ret = DoPerChannelQuant<int8_t>(reinterpret_cast<float *>(src_tensor->data()), src_tensor->ElementsNum(),
                                    schema::QuantType_QUANT_WEIGHT, &(quant_params), quant_max, quant_min, bit_num,
                                    false, &data, dest_tensor->dims, preferred_dim);
                                    &data, dest_tensor->dims, preferred_dim, false, false);
  }
  if (ret == RET_NO_CHANGE) {
    MS_LOG(DEBUG) << "No Need to quant per channel";
@@ -375,7 +375,7 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
  if (config->commonQuantParam.is_debug) {
    converter::Flags new_flag = *config;
    new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
    origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, thread_num, true);
    origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, thread_num);
  }
  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
    this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
@@ -431,7 +431,7 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
    }
  }
  if (config->commonQuantParam.is_debug) {
    quant = quant::CreateSessionByFuncGraph(old_graph, *config, thread_num, true);
    quant = quant::CreateSessionByFuncGraph(old_graph, *config, thread_num);
    std::map<std::string, OpParameter *> op_parameters;
    FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
    DebugInfoManager manager;
@@ -25,7 +25,8 @@ namespace mindspore::lite::quant {
namespace {
constexpr int kDefaultBinNumber = 2048;
}
int Calibrator::RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
int Calibrator::RecordMaxMinValue(const std::vector<float> &data,
                                  const std::unique_ptr<DataDistribution> &diverg_info) {
  auto ret = diverg_info->RecordMaxMinValue(data);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Record max min value failed.";
@@ -43,7 +44,7 @@ int Calibrator::ComputeThreshold() {
  for (auto &kv : this->outputs_diverg_info_) {
    auto &outputs_diverg_info = kv.second;
    for (auto &diverg_info : outputs_diverg_info) {
      auto ret = diverg_info->ComputeThreshold();
      auto ret = diverg_info.second->ComputeThreshold();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Compute threshold failed.";
        return ret;
@@ -64,10 +65,10 @@ int Calibrator::ComputeThreshold() {
        break;
      }
      for (const auto &output_diverg_info : outputs_diverg_info.second) {
        auto output_diverg_cnode = output_diverg_info->GetCNode();
        auto output_diverg_cnode = output_diverg_info.second->GetCNode();
        if (output_diverg_cnode == input_cnode) {
          if (NodePrimitiveType(input_cnode) != lite::kNameTupleGetItem) {
            *(input_infos[i]) = *output_diverg_info;
            *(input_infos[i]) = *output_diverg_info.second;
            input_infos[i]->GetCNode() = cnode;
            already_computed = true;
            break;
@@ -88,18 +89,23 @@ int Calibrator::ComputeThreshold() {
  return RET_OK;
}

int Calibrator::UpdateDivergInterval(
  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
int Calibrator::UpdateDivergInterval() {
  MS_ASSERT(diverg_info != nullptr);
  for (auto &kv : *diverg_info) {
  for (auto &kv : inputs_diverg_info_) {
    for (auto &info : kv.second) {
      info->UpdateInterval();
      info.second->UpdateInterval();
    }
  }
  for (auto &kv : outputs_diverg_info_) {
    for (auto &info : kv.second) {
      info.second->UpdateInterval();
    }
  }
  return RET_OK;
}

int Calibrator::UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
int Calibrator::UpdateDataFrequency(const std::vector<float> &data,
                                    const std::unique_ptr<DataDistribution> &diverg_info) {
  MS_ASSERT(diverg_info != nullptr);
  return diverg_info->UpdateHistogram(data);
}
@@ -110,14 +116,31 @@ int Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
    return RET_ERROR;
  }
  auto node_name = cnode->fullname_with_scope();
  std::unique_ptr<DivergInfo> input_diverg = std::make_unique<DivergInfo>(
    cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
  MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
  std::unique_ptr<DivergInfo> output_diverg = std::make_unique<DivergInfo>(
    cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
  MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
  inputs_diverg_info_[node_name].push_back(std::move(input_diverg));
  outputs_diverg_info_[node_name].push_back(std::move(output_diverg));
  auto size = cnode->inputs().size();
  for (size_t i = 1; i < size; i++) {
    std::unique_ptr<DataDistribution> input_diverg = std::make_unique<DataDistribution>(
      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
    MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
    inputs_diverg_info_[node_name].insert({i - 1, std::move(input_diverg)});
  }

  if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) {
    auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
    MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr");
    auto elements = tuple->elements();
    MS_ASSERT(elements.size() > 1);
    for (size_t i = 0; i < elements.size(); i++) {
      std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
        cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
      MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
      outputs_diverg_info_[node_name].insert({i, std::move(output_diverg)});
    }
  } else {
    std::unique_ptr<DataDistribution> output_diverg = std::make_unique<DataDistribution>(
      cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
    MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
    outputs_diverg_info_[node_name].insert({0, std::move(output_diverg)});
  }
  return RET_OK;
}
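Reviewer note: the calibrator now keys each distribution by tensor index instead of relying on push-back order, matching the {node_name,{tensor_index,DataDistribution}} comment added in calibrator.h. A tiny sketch of addressing one entry in such a structure (with a simplified stand-in value type; names are illustrative only):

#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>

// Simplified stand-in for DataDistribution, for illustration only.
struct FakeDistribution {
  float max_value = 0.0f;
};

using DivergMap = std::unordered_map<std::string, std::map<int, std::unique_ptr<FakeDistribution>>>;

int main() {
  DivergMap inputs_diverg_info;
  // Register distribution objects for the first two inputs of one node.
  inputs_diverg_info["Conv2D-op1"].insert({0, std::make_unique<FakeDistribution>()});
  inputs_diverg_info["Conv2D-op1"].insert({1, std::make_unique<FakeDistribution>()});
  // Later, a callback can address a specific tensor by its index directly.
  inputs_diverg_info["Conv2D-op1"][1]->max_value = 3.5f;
  std::printf("inputs tracked: %zu\n", inputs_diverg_info["Conv2D-op1"].size());
  return 0;
}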
@@ -126,11 +149,11 @@ int Calibrator::GenerateInputData(const std::string &input_name, size_t image_in
  return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
}

std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *Calibrator::GetInputDivergInfo() {
  return &this->inputs_diverg_info_;
}

std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() {
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *Calibrator::GetOutputDivergInfo() {
  return &this->outputs_diverg_info_;
}
}  // namespace mindspore::lite::quant
@@ -23,7 +23,7 @@
#include <memory>
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/diverg_info.h"
#include "tools/converter/quantizer/data_distribution.h"

namespace mindspore::lite::quant {
class Calibrator {
@@ -45,17 +45,17 @@ class Calibrator {

  int AddQuantizedOp(const CNodePtr &cnode);

  int RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
  int RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DataDistribution> &diverg_info);

  int UpdateDivergInterval(std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
  int UpdateDivergInterval();

  int UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
  int UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DataDistribution> &diverg_info);

  int ComputeThreshold();

  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetInputDivergInfo();
  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetInputDivergInfo();

  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *GetOutputDivergInfo();

  FullQuantParam full_quant_param_;
@@ -64,9 +64,10 @@ class Calibrator {
  int thread_ = 4;

 private:
  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_;

  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;
  // {node_name,{tensor_index,DataDistribution}}
  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> inputs_diverg_info_;
  // {node_name,{tensor_index,DataDistribution}}
  std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> outputs_diverg_info_;

  size_t bit_num_;
  int quant_max_;
@@ -14,20 +14,21 @@
 * limitations under the License.
 */

#include "tools/converter/quantizer/diverg_info.h"
#include "tools/converter/quantizer/data_distribution.h"
#include <algorithm>
#include <vector>
#include <utility>
#include <set>
namespace mindspore::lite::quant {
int DivergInfo::RecordMaxMinValue(const std::vector<float> &data) {
int DataDistribution::RecordMaxMinValue(const std::vector<float> &data) {
  for (float val : data) {
    max = std::max(val, max);
    min = std::min(val, min);
    max_ = std::max(val, max_);
    min_ = std::min(val, min_);
  }
  return RET_OK;
}

int DivergInfo::RecordMaxMinValueArray(const std::vector<float> &data) {
int DataDistribution::RecordMaxMinValueArray(const std::vector<float> &data) {
  if (data.empty()) {
    return RET_ERROR;
  }
@ -37,42 +38,42 @@ int DivergInfo::RecordMaxMinValueArray(const std::vector<float> &data) {
|
|||
max_num = std::max(val, max_num);
|
||||
min_num = std::min(val, min_num);
|
||||
}
|
||||
this->max_datas.emplace_back(max_num);
|
||||
this->min_datas.emplace_back(min_num);
|
||||
this->max_datas_.emplace_back(max_num);
|
||||
this->min_datas_.emplace_back(min_num);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void DivergInfo::UpdateInterval() {
|
||||
auto max_value = std::max(fabs(this->max), fabs(this->min));
|
||||
MS_ASSERT(bin_num != 0);
|
||||
this->interval = max_value / static_cast<float>(bin_num);
|
||||
void DataDistribution::UpdateInterval() {
|
||||
auto max_value = std::max(fabs(this->max_), fabs(this->min_));
|
||||
MS_ASSERT(bin_num_ != 0);
|
||||
this->interval_ = max_value / static_cast<float>(bin_num_);
|
||||
}
|
||||
|
||||
int DivergInfo::UpdateHistogram(const std::vector<float> &data) {
|
||||
int DataDistribution::UpdateHistogram(const std::vector<float> &data) {
|
||||
for (auto value : data) {
|
||||
if (value == 0) {
|
||||
continue;
|
||||
}
|
||||
if (this->interval == 0) {
|
||||
if (this->interval_ == 0) {
|
||||
MS_LOG(ERROR) << "divisor 'interval' cannot be 0.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
|
||||
this->histogram[bin_index]++;
|
||||
int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval_), bin_num_ - 1);
|
||||
this->histogram_[bin_index]++;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void DivergInfo::DumpHistogram() {
|
||||
MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram";
|
||||
for (float item : this->histogram) {
|
||||
void DataDistribution::DumpHistogram() {
|
||||
MS_LOG(INFO) << "Print node " << cnode_->fullname_with_scope() << " histogram";
|
||||
for (float item : this->histogram_) {
|
||||
std::cout << item << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
|
||||
std::vector<float> *expanded_histogram) {
|
||||
void DataDistribution::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
|
||||
std::vector<float> *expanded_histogram) {
|
||||
MS_ASSERT(quantized_histogram != nullptr && expanded_histogram != nullptr);
|
||||
MS_ASSERT(quant_bint_nums != 0);
|
||||
const float bin_interval = static_cast<float>(bin_index) / static_cast<float>(quant_bint_nums);
|
||||
|
@ -83,14 +84,14 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
|
|||
const int left_upper = static_cast<int>(std::ceil(start));
|
||||
if (left_upper > start) {
|
||||
const double left_scale = left_upper - start;
|
||||
quantized_histogram->at(i) += left_scale * this->histogram[left_upper - 1];
|
||||
quantized_histogram->at(i) += left_scale * this->histogram_[left_upper - 1];
|
||||
}
|
||||
const int right_lower = static_cast<int>(std::floor(end));
|
||||
if (right_lower < end) {
|
||||
const double right_scale = end - right_lower;
|
||||
quantized_histogram->at(i) += right_scale * this->histogram[right_lower];
|
||||
quantized_histogram->at(i) += right_scale * this->histogram_[right_lower];
|
||||
}
|
||||
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower,
|
||||
std::for_each(this->histogram_.begin() + left_upper, this->histogram_.begin() + right_lower,
|
||||
[&quantized_histogram, i](float item) { quantized_histogram->at(i) += item; });
|
||||
}
|
||||
// expand target bins to i bins in order to calculate KL with reference_histogram
|
||||
|
@ -102,7 +103,7 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
|
|||
float left_scale = 0.0f;
|
||||
if (left_upper > start) {
|
||||
left_scale = left_upper - start;
|
||||
if (this->histogram[left_upper - 1] != 0) {
|
||||
if (this->histogram_[left_upper - 1] != 0) {
|
||||
count += left_scale;
|
||||
}
|
||||
}
|
||||
|
@ -110,11 +111,11 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
|
|||
double right_scale = 0.0f;
|
||||
if (right_lower < end) {
|
||||
right_scale = end - right_lower;
|
||||
if (this->histogram[right_lower] != 0) {
|
||||
if (this->histogram_[right_lower] != 0) {
|
||||
count += right_scale;
|
||||
}
|
||||
}
|
||||
std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, [&count](float item) {
|
||||
std::for_each(this->histogram_.begin() + left_upper, this->histogram_.begin() + right_lower, [&count](float item) {
|
||||
if (item != 0) {
|
||||
count += 1;
|
||||
}
|
||||
|
@ -123,43 +124,43 @@ void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<
|
|||
continue;
|
||||
}
|
||||
const float average_num = quantized_histogram->at(i) / count;
|
||||
if (left_upper > start && this->histogram[left_upper - 1] != 0) {
|
||||
if (left_upper > start && this->histogram_[left_upper - 1] != 0) {
|
||||
expanded_histogram->at(left_upper - 1) += average_num * left_scale;
|
||||
}
|
||||
if (right_lower < end && this->histogram[right_lower] != 0) {
|
||||
if (right_lower < end && this->histogram_[right_lower] != 0) {
|
||||
expanded_histogram->at(right_lower) += average_num * right_scale;
|
||||
}
|
||||
for (int k = left_upper; k < right_lower; ++k) {
|
||||
if (this->histogram[k] != 0) {
|
||||
if (this->histogram_[k] != 0) {
|
||||
expanded_histogram->at(k) += average_num;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int DivergInfo::ComputeThreshold() {
|
||||
if (activation_quant_method == MAX_MIN) {
|
||||
this->best_T = std::max(fabs(this->max), fabs(this->min));
|
||||
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
|
||||
int DataDistribution::ComputeThreshold() {
|
||||
if (activation_quant_method_ == MAX_MIN) {
|
||||
this->best_T_ = std::max(fabs(this->max_), fabs(this->min_));
|
||||
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T_;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (activation_quant_method == REMOVAL_OUTLIER && !this->min_datas.empty()) {
|
||||
this->percent_result = OutlierMethod(min_datas, max_datas);
|
||||
this->best_T = std::max(std::fabs(percent_result.first), std::fabs(percent_result.second));
|
||||
if (activation_quant_method_ == REMOVAL_OUTLIER && !this->min_datas_.empty()) {
|
||||
this->percent_result_ = OutlierMethod(min_datas_, max_datas_);
|
||||
this->best_T_ = std::max(std::fabs(percent_result_.first), std::fabs(percent_result_.second));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int threshold = INT8_MAX + 1;
|
||||
float min_kl = FLT_MAX;
|
||||
float after_threshold_sum = std::accumulate(this->histogram.begin() + INT8_MAX + 1, this->histogram.end(), 0.0f);
|
||||
float after_threshold_sum = std::accumulate(this->histogram_.begin() + INT8_MAX + 1, this->histogram_.end(), 0.0f);
|
||||
|
||||
for (int i = INT8_MAX + 1; i < this->bin_num; ++i) {
|
||||
for (int i = INT8_MAX + 1; i < this->bin_num_; ++i) {
|
||||
std::vector<float> quantized_histogram(INT8_MAX + 1, 0);
|
||||
std::vector<float> reference_histogram(this->histogram.begin(), this->histogram.begin() + i);
|
||||
std::vector<float> reference_histogram(this->histogram_.begin(), this->histogram_.begin() + i);
|
||||
std::vector<float> expanded_histogram(i, 0);
|
||||
reference_histogram[i - 1] += after_threshold_sum;
|
||||
after_threshold_sum -= this->histogram[i];
|
||||
after_threshold_sum -= this->histogram_[i];
|
||||
// handle bins for computing KL.
|
||||
HandleBinForKL(INT8_MAX + 1, i, &quantized_histogram, &expanded_histogram);
|
||||
auto KLDivergence = [](std::vector<float> p, std::vector<float> q) {
|
||||
|
@ -189,41 +190,41 @@ int DivergInfo::ComputeThreshold() {
|
|||
threshold = i;
|
||||
}
|
||||
}
|
||||
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
|
||||
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T
|
||||
<< " max: " << std::max(fabs(this->max), fabs(this->min));
|
||||
this->best_T_ = (static_cast<float>(threshold) + 0.5f) * this->interval_;
|
||||
MS_LOG(DEBUG) << cnode_->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T_
|
||||
<< " max: " << std::max(fabs(this->max_), fabs(this->min_));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
std::pair<CNodePtr, float> DivergInfo::GetScale() {
|
||||
float max_value = this->best_T;
|
||||
float DataDistribution::GetScale() {
|
||||
float max_value = this->best_T_;
|
||||
float min_value = -max_value;
|
||||
|
||||
if (this->activation_quant_method == REMOVAL_OUTLIER) {
|
||||
min_value = percent_result.first;
|
||||
max_value = percent_result.second;
|
||||
if (this->activation_quant_method_ == REMOVAL_OUTLIER) {
|
||||
min_value = percent_result_.first;
|
||||
max_value = percent_result_.second;
|
||||
}
|
||||
|
||||
MS_CHECK_TRUE_MSG(quant_max - quant_min != 0, {}, "quant_max - quant_min == 0");
|
||||
float scale = (max_value - min_value) / (quant_max - quant_min);
|
||||
this->scale_tmp = scale;
|
||||
MS_ASSERT(fabs(scale) <= 0.0f);
|
||||
return std::make_pair(this->cnode, scale);
|
||||
MS_CHECK_TRUE_MSG(quant_max_ - quant_min_ > 0, 0, "quant_max_ - quant_min_ <= 0");
|
||||
this->scale_ = (max_value - min_value) / (quant_max_ - quant_min_);
|
||||
MS_ASSERT(fabs(this->scale_) <= 0.0f);
|
||||
return this->scale_;
|
||||
}
|
||||
|
||||
std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
|
||||
// Support for asymmetry in the future
|
||||
int32_t DataDistribution::GetZeroPoint() {
|
||||
int zero_point = 0;
|
||||
if (quant_min == 0 && quant_max == UINT8_MAX) {
|
||||
if (quant_min_ == 0 && quant_max_ == UINT8_MAX) {
|
||||
zero_point = INT8_MAX + 1;
|
||||
} else if (quant_min == INT_LEAST8_MIN + 1 && quant_max == INT8_MAX) {
|
||||
} else if (quant_min_ == INT_LEAST8_MIN + 1 && quant_max_ == INT8_MAX) {
|
||||
zero_point = 0;
|
||||
} else {
|
||||
MS_LOG(WARNING) << "unexpected quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
|
||||
MS_LOG(WARNING) << "unexpected quant range, quant_min_: " << quant_min_ << " quant_max_: " << quant_max_;
|
||||
}
|
||||
if (this->activation_quant_method == REMOVAL_OUTLIER) {
|
||||
MS_CHECK_TRUE_MSG(fabs(scale_tmp) <= 0.0f, {}, "fabs(scale_tmp) > 0.0f");
|
||||
zero_point = std::round(quant_max - percent_result.second / scale_tmp);
|
||||
if (this->activation_quant_method_ == REMOVAL_OUTLIER) {
|
||||
MS_CHECK_TRUE_MSG(fabs(scale_) <= 0.0f, 1, "fabs(scale) > 0.0f");
|
||||
zero_point = std::round(quant_max_ - percent_result_.second / scale_);
|
||||
}
|
||||
return std::make_pair(this->cnode, zero_point);
|
||||
return zero_point;
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
|
@@ -0,0 +1,84 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DATA_DISTRIBUTION_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DATA_DISTRIBUTION_H
#include <vector>
#include <utility>
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quantize_util.h"
namespace mindspore::lite::quant {
class DataDistribution {
 public:
  DataDistribution() = default;
  DataDistribution(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min,
                   ActivationQuantizedMethod activation_quant_method) {
    this->activation_quant_method_ = activation_quant_method;
    this->cnode_ = std::move(cnode);
    this->bin_num_ = bins;
    this->bit_num_ = bits;
    histogram_.resize(bin_num_);
    max_ = -FLT_MAX;
    min_ = FLT_MAX;
    this->quant_max_ = quant_max;
    this->quant_min_ = quant_min;
    std::fill(histogram_.begin(), histogram_.end(), 1.0e-7);
  }

  int RecordMaxMinValue(const std::vector<float> &data);

  int RecordMaxMinValueArray(const std::vector<float> &data);

  void UpdateInterval();

  int UpdateHistogram(const std::vector<float> &data);

  void DumpHistogram();

  void HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
                      std::vector<float> *expanded_histogram);

  int ComputeThreshold();

  float GetScale();

  int32_t GetZeroPoint();

  float GetMax() { return this->max_; }

  float GetMin() { return this->min_; }

  CNodePtr GetCNode() { return this->cnode_; }

 private:
  std::vector<float> histogram_;
  CNodePtr cnode_;
  int bin_num_ = 0;
  float interval_ = 0;
  float max_ = 0.0f;
  float min_ = 0.0f;
  float best_T_ = 0.0f;
  size_t bit_num_ = 0;
  int quant_max_ = 255;
  int quant_min_ = 0;
  ActivationQuantizedMethod activation_quant_method_ = MAX_MIN;
  std::vector<float> min_datas_;
  std::vector<float> max_datas_;
  std::pair<float, float> percent_result_{0.0, 0.0};
  float scale_ = 0;
};
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DATA_DISTRIBUTION_H
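Reviewer note: the new DataDistribution header above gathers per-tensor statistics in two passes (min/max first, then histogram bins for the KL-based threshold search in ComputeThreshold). The body of the KLDivergence lambda is not shown in this diff; a generic sketch of the kind of comparison it presumably performs between the clipped reference histogram and the quantized/expanded one (simplified, assumed behavior):

#include <cmath>
#include <cstdio>
#include <vector>

// Generic KL divergence between two normalized histograms of equal size.
// Assumption: bins where either value is zero are skipped; the project's
// exact zero-bin handling is not visible in this diff.
float KLDivergenceSketch(const std::vector<float> &p, const std::vector<float> &q) {
  float result = 0.0f;
  for (size_t i = 0; i < p.size(); ++i) {
    if (p[i] != 0 && q[i] != 0) {
      result += p[i] * std::log(p[i] / q[i]);
    }
  }
  return result;
}

int main() {
  std::vector<float> reference = {0.1f, 0.4f, 0.4f, 0.1f};
  std::vector<float> candidate = {0.1f, 0.5f, 0.3f, 0.1f};
  std::printf("KL = %f\n", KLDivergenceSketch(reference, candidate));
  return 0;
}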
@ -1,84 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DIVERG_INFO_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DIVERG_INFO_H
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "tools/converter/quantizer/quant_params.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
namespace mindspore::lite::quant {
|
||||
class DivergInfo {
|
||||
public:
|
||||
DivergInfo() = default;
|
||||
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min,
|
||||
ActivationQuantizedMethod activation_quant_method) {
|
||||
this->activation_quant_method = activation_quant_method;
|
||||
this->cnode = std::move(cnode);
|
||||
this->bin_num = bins;
|
||||
this->bit_num = bits;
|
||||
histogram.resize(bin_num);
|
||||
max = -FLT_MAX;
|
||||
min = FLT_MAX;
|
||||
this->quant_max = quant_max;
|
||||
this->quant_min = quant_min;
|
||||
std::fill(histogram.begin(), histogram.end(), 1.0e-7);
|
||||
}
|
||||
|
||||
int RecordMaxMinValue(const std::vector<float> &data);
|
||||
|
||||
int RecordMaxMinValueArray(const std::vector<float> &data);
|
||||
|
||||
void UpdateInterval();
|
||||
|
||||
int UpdateHistogram(const std::vector<float> &data);
|
||||
|
||||
void DumpHistogram();
|
||||
|
||||
void HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
|
||||
std::vector<float> *expanded_histogram);
|
||||
|
||||
int ComputeThreshold();
|
||||
|
||||
std::pair<CNodePtr, float> GetScale();
|
||||
|
||||
std::pair<CNodePtr, int32_t> GetZeropoint();
|
||||
|
||||
float GetMax() { return this->max; }
|
||||
|
||||
float GetMin() { return this->min; }
|
||||
|
||||
CNodePtr GetCNode() { return this->cnode; }
|
||||
|
||||
private:
|
||||
std::vector<float> histogram;
|
||||
CNodePtr cnode;
|
||||
int bin_num = 0;
|
||||
float interval = 0;
|
||||
float max = 0.0f;
|
||||
float min = 0.0f;
|
||||
float best_T = 0.0f;
|
||||
size_t bit_num = 0;
|
||||
int quant_max = 255;
|
||||
int quant_min = 0;
|
||||
ActivationQuantizedMethod activation_quant_method = MAX_MIN;
|
||||
std::vector<float> min_datas;
|
||||
std::vector<float> max_datas;
|
||||
std::pair<float, float> percent_result{0.0, 0.0};
|
||||
float scale_tmp = 0;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_DIVERG_INFO_H
|
|
@ -32,6 +32,7 @@
|
|||
#include "src/tensor.h"
|
||||
#include "tools/converter/quantizer/quant_cast.h"
|
||||
#include "tools/converter/quantizer/quantize_util.h"
|
||||
#include "tools/converter/quantizer/quant_strategy.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "securec/include/securec.h"
|
||||
|
@ -128,18 +129,18 @@ int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const s
|
|||
FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type)
|
||||
: Quantizer(std::move(graph)) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
this->bit_num = bit_num;
|
||||
this->target_type_ = target_type;
|
||||
this->bit_num_ = bit_num;
|
||||
this->target_data_type_ = target_type;
|
||||
if (target_type == kNumberTypeInt8) {
|
||||
quant_max = (1 << (this->bit_num - 1)) - 1; // 127
|
||||
quant_min = -quant_max; // -127
|
||||
q_max_ = (1 << (this->bit_num_ - 1)) - 1; // 127
|
||||
q_min_ = -q_max_; // -127
|
||||
} else if (target_type == kNumberTypeUInt8) {
|
||||
quant_max = (1 << this->bit_num) - 1; // 255
|
||||
quant_min = 0;
|
||||
q_max_ = (1 << this->bit_num_) - 1; // 255
|
||||
q_min_ = 0;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
|
||||
}
|
||||
calibrator_ = std::make_unique<Calibrator>(this->bit_num, quant_max, quant_min);
|
||||
calibrator_ = std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_);
|
||||
if (calibrator_ == nullptr) {
|
||||
MS_LOG(ERROR) << "create calibrator failed!";
|
||||
return;
|
||||
|
@ -153,7 +154,7 @@ FullQuantQuantizer::~FullQuantQuantizer() {
|
|||
delete int8_model_;
|
||||
}
|
||||
|
||||
int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DivergInfo> &info,
|
||||
int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
|
||||
const PrimitivePtr &primitive, bool is_input, size_t index) const {
|
||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
|
||||
|
@ -164,17 +165,17 @@ int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const s
|
|||
return RET_ERROR;
|
||||
}
|
||||
if (type_id == kNumberTypeFloat32 && info != nullptr) {
|
||||
auto scale = info->GetScale().second;
|
||||
auto scale = info->GetScale();
|
||||
if (scale == 0) {
|
||||
MS_LOG(WARNING) << "The input or output values are very close to 0, so set the scale to 1.";
|
||||
quant_param.scale = 1;
|
||||
} else {
|
||||
quant_param.scale = scale;
|
||||
}
|
||||
quant_param.zeroPoint = info->GetZeropoint().second;
|
||||
quant_param.zeroPoint = info->GetZeroPoint();
|
||||
quant_param.max = info->GetMax();
|
||||
quant_param.min = info->GetMin();
|
||||
quant_param.numBits = bit_num;
|
||||
quant_param.numBits = bit_num_;
|
||||
quant_param.narrowRange = true;
|
||||
quant_param.inited = true;
|
||||
quant_param.roundType = 1;
|
||||
|
@ -210,13 +211,9 @@ int FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodeP
|
|||
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
auto bit_num_t = bit_num;
|
||||
auto quant_max_t = quant_max;
|
||||
auto quant_min_t = quant_min;
|
||||
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
auto status =
|
||||
FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_ALL, quant_max_t, quant_min_t,
|
||||
bit_num_t, weight_quant_type, kNumberTypeInt8, input_index - 1);
|
||||
auto status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, QuantType_QUANT_ALL, q_max_, q_min_,
|
||||
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, true);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
return status;
|
||||
|
@ -332,6 +329,7 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const AnfNod
|
|||
if (type_id == kNumberTypeInt8) {
|
||||
return RET_CONTINUE;
|
||||
}
|
||||
// Only data the data type is fp32 can be quant.
|
||||
if (type_id != kNumberTypeFloat32) {
|
||||
ret = SetInOutQuantParam(input_node, nullptr, primitive, true, input_index - 1);
|
||||
if (ret != RET_OK) {
|
||||
|
@ -348,15 +346,7 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const AnfNod
|
|||
return ret;
|
||||
}
|
||||
} else {
|
||||
if (opt::CheckPrimitiveType(cnode, prim::kPrimMatMul)) {
|
||||
if (input_index == FIRST_INPUT + 1) {
|
||||
ret = DoWeightQuant(op_name, input_node, primitive, false, input_index);
|
||||
} else {
|
||||
ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
|
||||
}
|
||||
} else {
|
||||
ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
|
||||
}
|
||||
ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do bias quant failed.";
|
||||
return ret;
|
||||
|
@ -515,31 +505,24 @@ int FullQuantQuantizer::QuantNode() {
|
|||
}
|
||||
|
||||
int FullQuantQuantizer::UpdateDivergeInterval() {
|
||||
auto ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetInputDivergInfo());
|
||||
auto ret = this->calibrator_->UpdateDivergInterval();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Update input diverge interval failed.";
|
||||
return ret;
|
||||
}
|
||||
ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetOutputDivergInfo());
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Update output diverge interval failed.";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark quantifiable nodes
|
||||
**/
|
||||
int FullQuantQuantizer::PreProcess() {
|
||||
auto cnodes = funcGraph->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
AnfNodePtr anf = cnode->cast<AnfNodePtr>();
|
||||
if (anf == nullptr) {
|
||||
auto anode = cnode->cast<AnfNodePtr>();
|
||||
if (anode == nullptr) {
|
||||
MS_LOG(ERROR) << " cnode is null";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anf)) {
|
||||
// Mark quantifiable nodes
|
||||
if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anode)) {
|
||||
auto ret = calibrator_->AddQuantizedOp(cnode);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Add Quantized Op failed.";
|
||||
|
@ -574,12 +557,44 @@ int FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
/**
|
||||
* 1. create input tensor
|
||||
* 2. insert callback to session
|
||||
* 3. run session
|
||||
**/
|
||||
int FullQuantQuantizer::DoInference() {
|
||||
int FullQuantQuantizer::CollectDataDistribution(
|
||||
const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
|
||||
std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
|
||||
CollectType collect_type) {
|
||||
if (diverg_info_map->find(node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
for (size_t i = 0; i < tensors.size(); i++) {
|
||||
auto tensor = tensors[i];
|
||||
if (tensor->IsConst() || tensor->data_type() != kNumberTypeFloat32) {
|
||||
continue;
|
||||
}
|
||||
const auto *tensor_data = static_cast<const float *>(tensor->data());
|
||||
if (tensor_data == nullptr) {
|
||||
MS_LOG(ERROR) << tensor->tensor_name() << " tensor_data is nullptr.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
size_t elem_count = tensor->ElementsNum();
|
||||
MS_CHECK_GT(elem_count, 0, RET_ERROR);
|
||||
vector<float> data(tensor_data, tensor_data + elem_count);
|
||||
if (collect_type == MIN_MAX) {
|
||||
auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[node_name][i]);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << tensor->tensor_name() << " record max min value failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (collect_type == KL_BIN) {
|
||||
auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[node_name][i]);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << tensor->tensor_name() << " update data frequency failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FullQuantQuantizer::DoInference(CollectType collect_type) {
|
||||
// get input tensor
|
||||
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
|
||||
if (inputs.size() != calibrator_->GetInputNum()) {
|
||||
|
@ -599,34 +614,10 @@ int FullQuantQuantizer::DoInference() {
|
|||
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
|
||||
const CallBackParam &callParam) -> bool {
|
||||
auto diverg_info_map = calibrator_->GetInputDivergInfo();
|
||||
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeOutputs) != RET_OK) {
|
||||
return true;
|
||||
}
|
||||
bool is_init = beforeInputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
|
||||
if (is_init) {
|
||||
for (size_t i = 1; i < beforeInputs.size(); i++) {
|
||||
if (beforeInputs.at(i)->data_type() != kNumberTypeFloat32 || beforeInputs.at(i)->IsConst()) {
|
||||
continue;
|
||||
}
|
||||
auto input_diverg = std::make_unique<DivergInfo>();
|
||||
MS_CHECK_TRUE_MSG(input_diverg != nullptr, false, "input_diverg is nullptr.");
|
||||
*input_diverg = *((*diverg_info_map)[callParam.node_name][0]);
|
||||
(*diverg_info_map)[callParam.node_name].push_back(std::move(input_diverg));
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
|
||||
auto tensor = beforeInputs[i];
|
||||
MS_CHECK_TRUE_MSG(tensor != nullptr, false, "tensor is nullptr.");
|
||||
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
|
||||
MS_CHECK_TRUE_MSG(tensor_data != nullptr, false, "tensor_data is nullptr.");
|
||||
size_t elem_count = tensor->ElementsNum();
|
||||
MS_CHECK_GT(elem_count, 0, false);
|
||||
vector<float> data(tensor_data, tensor_data + elem_count);
|
||||
auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][i]);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
|
||||
auto ret = CollectDataDistribution(callParam.node_name, beforeInputs, diverg_info_map, collect_type);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "CollectDataDistribution failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
@ -635,31 +626,10 @@ int FullQuantQuantizer::DoInference() {
|
|||
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
|
||||
const CallBackParam &callParam) -> bool {
|
||||
auto diverg_info_map = calibrator_->GetOutputDivergInfo();
|
||||
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
|
||||
return true;
|
||||
}
|
||||
bool is_init = afterOutputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
|
||||
if (is_init) {
|
||||
for (size_t i = 1; i < afterOutputs.size(); i++) {
|
||||
auto output_diverg = std::make_unique<DivergInfo>();
|
||||
CHECK_NULL_RETURN(output_diverg);
|
||||
*output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
|
||||
(*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg));
|
||||
}
|
||||
}
|
||||
size_t output_i = 0;
|
||||
for (const auto &tensor : afterOutputs) {
|
||||
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
|
||||
CHECK_NULL_RETURN(tensor_data);
|
||||
size_t elem_count = tensor->ElementsNum();
|
||||
MS_CHECK_GT(elem_count, 0, false);
|
||||
vector<float> data(tensor_data, tensor_data + elem_count);
|
||||
auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][output_i]);
|
||||
MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
|
||||
output_i++;
|
||||
auto ret = CollectDataDistribution(callParam.node_name, afterOutputs, diverg_info_map, collect_type);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "CollectDataDistribution failed.";
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
@ -735,14 +705,14 @@ int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
|
|||
MS_LOG(ERROR) << "divisor 'calibrate_size' cannot be 0.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (auto &key_value : op_bias_diff_map) {
|
||||
for (auto &key_value : op_bias_diff_map_) {
|
||||
std::for_each(key_value.second.begin(), key_value.second.end(),
|
||||
[this](float &data) { data = data / calibrator_->GetBatchNum(); });
|
||||
}
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
if (op_bias_diff_map.find(op_name) == op_bias_diff_map.end()) {
|
||||
if (op_bias_diff_map_.find(op_name) == op_bias_diff_map_.end()) {
|
||||
continue;
|
||||
}
|
||||
status = BiasCorrection(func_graph, cnode);
|
||||
|
@ -756,7 +726,7 @@ int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
|
|||
|
||||
int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
const auto &bias_diff = op_bias_diff_map[op_name];
|
||||
const auto &bias_diff = op_bias_diff_map_[op_name];
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive is nullptr";
|
||||
|
@ -839,89 +809,6 @@ int FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNo
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int FullQuantQuantizer::CollectDataFrequency() {
|
||||
// get input tensor
|
||||
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
|
||||
if (inputs.size() != calibrator_->GetInputNum()) {
|
||||
MS_LOG(ERROR) << "model's input tensor cnt: " << inputs.size() << " != " << calibrator_->GetInputNum();
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
|
||||
// set multi-input data
|
||||
for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
|
||||
int status = calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), i, inputs[input_index]);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "generate input data from images failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
KernelCallBack before_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
|
||||
const CallBackParam &callParam) {
|
||||
auto diverg_info_map = calibrator_->GetInputDivergInfo();
|
||||
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
|
||||
return true;
|
||||
}
|
||||
int input_i = 0;
|
||||
for (auto tensor : before_inputs) {
|
||||
if (tensor->data_type() != kNumberTypeFloat32 || tensor->IsConst()) {
|
||||
continue;
|
||||
}
|
||||
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
|
||||
MS_ASSERT(tensor_data != nullptr);
|
||||
size_t elem_count = tensor->ElementsNum();
|
||||
MS_CHECK_GT(elem_count, 0, false);
|
||||
vector<float> data(tensor_data, tensor_data + elem_count);
|
||||
auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name][input_i++]);
|
||||
if (ret != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
KernelCallBack after_callBack = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
|
||||
const CallBackParam &call_param) {
|
||||
auto diverg_info_map = calibrator_->GetOutputDivergInfo();
|
||||
if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
|
||||
return true;
|
||||
}
|
||||
int output_i = 0;
|
||||
// all outputs are same dtype.
|
||||
for (const auto &tensor : after_outputs) {
|
||||
const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
|
||||
MS_ASSERT(tensor_data != nullptr);
|
||||
size_t elem_count = tensor->ElementsNum();
|
||||
MS_CHECK_GT(elem_count, 0, false);
|
||||
vector<float> data(tensor_data, tensor_data + elem_count);
|
||||
auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[call_param.node_name][output_i++]);
|
||||
if (ret != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
fp32_session_->BindThread(true);
|
||||
auto status = fp32_session_->RunGraph(before_callback, after_callBack);
|
||||
fp32_session_->BindThread(false);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "run model failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
|
||||
|
||||
int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||
|
@ -944,7 +831,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
// anf -- fb
|
||||
flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
|
||||
MS_LOG(INFO) << "start create session";
|
||||
auto sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum(), false);
|
||||
auto sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
|
||||
fp32_session_ = sm.session;
|
||||
fp32_model_ = sm.model;
|
||||
if (fp32_session_ == nullptr || fp32_model_ == nullptr) {
|
||||
|
@ -952,7 +839,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
MS_LOG(INFO) << "start to update divergence's max value";
|
||||
status = DoInference();
|
||||
status = DoInference(MIN_MAX);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do inference failed.";
|
||||
return status;
|
||||
|
@ -964,7 +851,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
return status;
|
||||
}
|
||||
MS_LOG(INFO) << "start to collect data's distribution";
|
||||
status = CollectDataFrequency();
|
||||
status = DoInference(KL_BIN);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Collect data frequency failed.";
|
||||
return status;
|
||||
|
@ -995,7 +882,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
// init in8 session
|
||||
MS_LOG(INFO) << "create quant session";
|
||||
flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
|
||||
int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum(), false);
|
||||
int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
|
||||
int8_session_ = int8_sm.session;
|
||||
int8_model_ = int8_sm.model;
|
||||
if (int8_session_ == nullptr || int8_model_ == nullptr) {
|
||||
|
@@ -1013,21 +900,21 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {

bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
  MS_ASSERT(data != nullptr);
  std::lock_guard<std::mutex> lg(mutex_op_input);
  std::lock_guard<std::mutex> lg(mutex_op_input_);
  if (type == STORE) {
    if (fp32_op_input_map.find(op_name) != fp32_op_input_map.end()) {
    if (fp32_op_input_map_.find(op_name) != fp32_op_input_map_.end()) {
      // the data has not been fetched by the int8 model
      return false;
    }
    fp32_op_input_map[op_name] = *data;
    fp32_op_input_map_[op_name] = *data;
    return true;
  } else if (type == FETCH) {
    if (fp32_op_input_map.find(op_name) == fp32_op_input_map.end()) {
    if (fp32_op_input_map_.find(op_name) == fp32_op_input_map_.end()) {
      // the data has not been generated by the fp32 model yet
      return false;
    }
    *data = fp32_op_input_map[op_name];
    fp32_op_input_map.erase(op_name);
    *data = fp32_op_input_map_[op_name];
    fp32_op_input_map_.erase(op_name);
    return true;
  } else {
    MS_LOG(ERROR) << "unexpected type: " << type;
@@ -1037,21 +924,21 @@ bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_

bool FullQuantQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
  MS_ASSERT(data != nullptr);
  std::lock_guard<std::mutex> lg(mutex_op_output);
  std::lock_guard<std::mutex> lg(mutex_op_output_);
  if (type == STORE) {
    if (fp32_op_output_ch_mean_map.find(op_name) != fp32_op_output_ch_mean_map.end()) {
    if (fp32_op_output_ch_mean_map_.find(op_name) != fp32_op_output_ch_mean_map_.end()) {
      // the data has not been fetched by the int8 model
      return false;
    }
    fp32_op_output_ch_mean_map[op_name] = *data;
    fp32_op_output_ch_mean_map_[op_name] = *data;
    return true;
  } else if (type == FETCH) {
    if (fp32_op_output_ch_mean_map.find(op_name) == fp32_op_output_ch_mean_map.end()) {
    if (fp32_op_output_ch_mean_map_.find(op_name) == fp32_op_output_ch_mean_map_.end()) {
      // the data has not been generated by the fp32 model yet
      return false;
    }
    *data = fp32_op_output_ch_mean_map[op_name];
    fp32_op_output_ch_mean_map.erase(op_name);
    *data = fp32_op_output_ch_mean_map_[op_name];
    fp32_op_output_ch_mean_map_.erase(op_name);
    return true;
  } else {
    MS_LOG(ERROR) << "unexpected type: " << type;
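OpInputDataHandle and OpOutputChMeanDataHandle implement a simple op-name-keyed hand-off between the fp32 and int8 sessions: STORE publishes a value and fails while the previous one has not been consumed, FETCH takes the value and erases it. A hedged illustration of how the two callback sides pair up (the retry-with-sleep loop is an assumption about how the callbacks synchronize; it is not shown in this hunk):

  // fp32 before-callback side: publish this op's float input for the int8 run.
  while (!OpInputDataHandle(STORE, call_param.node_name, &fp32_op_input)) {
    std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));  // int8 side has not drained yet
  }

  // int8 before-callback side: wait for the fp32 run to publish, then consume the value.
  std::vector<float> fp32_op_input;
  while (!OpInputDataHandle(FETCH, call_param.node_name, &fp32_op_input)) {
    std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));  // fp32 side has not produced yet
  }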
@@ -1074,8 +961,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
      size_t elem_count = tensor->ElementsNum();
      MS_CHECK_GT(elem_count, 0, false);
      std::vector<float> fp32_op_input(elem_count);
      auto ret =
        memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size());
      auto ret = memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->data(), tensor->Size());
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy error: " << ret;
        return false;
@@ -1112,7 +998,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
      quant_param_t.scale = quant_params[0].scale;
      quant_param_t.zeroPoint = quant_params[0].zeroPoint;
      for (auto float_data : fp32_op_input) {
        auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, quant_max, quant_min);
        auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, q_max_, q_min_);
        quant_datas.push_back(quant_data);
      }
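QuantizeData<int8_t> is applied element-wise with the fp32 tensor's scale/zero-point and the quantizer's q_min_/q_max_ bounds. Assuming it implements the usual affine mapping (a sketch, not the body from this patch):

  // q = clamp(round(r / scale) + zero_point, q_min, q_max)
  int8_t AffineQuantize(float real_value, double scale, int zero_point, int q_min, int q_max) {
    MS_ASSERT(scale > 0);
    int q = static_cast<int>(std::round(real_value / scale)) + zero_point;
    q = std::max(q_min, std::min(q_max, q));  // clamp into the representable int8 range
    return static_cast<int8_t>(q);
  }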
@@ -1122,8 +1008,7 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
        return false;
      }

      auto ret =
        memcpy_s(tensor->MutableData(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
      auto ret = memcpy_s(tensor->data(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy error: " << ret;
        return false;
@@ -1158,7 +1043,7 @@ KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
        MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
        return false;
      }
      const int8_t *tensor_data = static_cast<int8_t *>(tensor->MutableData());
      const int8_t *tensor_data = static_cast<int8_t *>(tensor->data());
      size_t elem_count = tensor->ElementsNum();
      MS_CHECK_GT(elem_count, 0, false);
      auto shapes = tensor->shape();
@@ -1203,12 +1088,12 @@ KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
      std::transform(fp32_op_output_ch_mean.begin(), fp32_op_output_ch_mean.end(), dequant_op_output_ch_mean.begin(),
                     dequant_op_output_ch_mean.begin(), std::minus<>());

      if (op_bias_diff_map.find(callParam.node_name) != op_bias_diff_map.end()) {
        auto &bias_diff = op_bias_diff_map[callParam.node_name];
      if (op_bias_diff_map_.find(callParam.node_name) != op_bias_diff_map_.end()) {
        auto &bias_diff = op_bias_diff_map_[callParam.node_name];
        std::transform(bias_diff.begin(), bias_diff.end(), dequant_op_output_ch_mean.begin(), bias_diff.begin(),
                       std::plus<>());
      } else {
        op_bias_diff_map[callParam.node_name] = dequant_op_output_ch_mean;
        op_bias_diff_map_[callParam.node_name] = dequant_op_output_ch_mean;
      }
    }
    return true;
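The block above keeps, per output channel, a running sum of (fp32 output mean - dequantized int8 output mean); BiasCorrection later folds that accumulated difference into the op's bias. Written out explicitly as a sketch of the same arithmetic (averaging over calibration rounds is assumed to happen in BiasCorrection):

  std::vector<float> AccumulateBiasDiff(const std::vector<float> &fp32_mean,
                                        const std::vector<float> &dequant_mean,
                                        std::vector<float> bias_diff) {
    MS_ASSERT(fp32_mean.size() == dequant_mean.size());
    if (bias_diff.empty()) {
      bias_diff.assign(fp32_mean.size(), 0.0f);  // first round for this op
    }
    for (size_t c = 0; c < fp32_mean.size(); ++c) {
      bias_diff[c] += fp32_mean[c] - dequant_mean[c];  // channel-wise mean error of the int8 op
    }
    return bias_diff;
  }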
@@ -1226,7 +1111,7 @@ KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
    }
    auto tensor = afterOutputs[0];
    MS_ASSERT(tensor != nullptr);
    const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
    const auto *tensor_data = static_cast<const float *>(tensor->data());
    size_t elem_count = tensor->ElementsNum();
    MS_CHECK_GT(elem_count, 0, false);
    auto shapes = tensor->shape();
@@ -34,13 +34,17 @@
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "tools/converter/quantizer/calibrator.h"
#include "tools/converter/quantizer/diverg_info.h"
#include "tools/converter/quantizer/data_distribution.h"

namespace mindspore::lite::quant {
enum OperationType {
  STORE,
  FETCH,
};
enum CollectType {
  MIN_MAX,
  KL_BIN,
};
class FullQuantQuantizer : public Quantizer {
 public:
  FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8);
@@ -54,22 +58,19 @@ class FullQuantQuantizer : public Quantizer {

  int PreProcess();

  static int CheckFp32TensorVec(const std::string &node_name,
                                const std::vector<mindspore::tensor::MSTensor *> &tensor_vec);
  int CheckFp32TensorVec(const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensor_vec);

  int DoInference();
  int DoInference(CollectType collect_type);

  int UpdateDivergeInterval();

  int CollectDataFrequency();

  int ComputeThreshold();

  int QuantNodeSimpleOp(const CNodePtr &cnode);

  int QuantNode();

  int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DivergInfo> &info,
  int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
                         const PrimitivePtr &primitive, bool is_input, size_t index) const;

  int DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, const PrimitivePtr &primitive,
@@ -77,7 +78,12 @@ class FullQuantQuantizer : public Quantizer {

  int DoParameterNodeQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index);

  static int DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive);
  int CollectDataDistribution(
    const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensors,
    std::unordered_map<std::string, std::map<int, std::unique_ptr<DataDistribution>>> *diverg_info_map,
    CollectType collect_type);

  int DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive);
  int Int8Inference();
  int BiasCorrection(const FuncGraphPtr &func_graph);
  int BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
@@ -87,22 +93,22 @@ class FullQuantQuantizer : public Quantizer {
  KernelCallBack GetFloatAfterCallBack();

 private:
  TypeId target_type_{kNumberTypeInt8};
  TypeId target_data_type_{kNumberTypeInt8};
  std::unique_ptr<Calibrator> calibrator_{nullptr};
  session::LiteSession *fp32_session_{nullptr};
  Model *fp32_model_{nullptr};
  session::LiteSession *int8_session_{nullptr};
  Model *int8_model_{nullptr};

  std::map<std::string, std::vector<float>> fp32_op_input_map;           // concurrency
  std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map;  // concurrency
  std::map<std::string, std::vector<float>> op_bias_diff_map;            // only used by the int8 model
  std::mutex mutex_op_input;
  std::mutex mutex_op_output;
  std::map<std::string, std::vector<float>> fp32_op_input_map_;           // concurrency
  std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map_;  // concurrency
  std::map<std::string, std::vector<float>> op_bias_diff_map_;            // only used by the int8 model
  std::mutex mutex_op_input_;
  std::mutex mutex_op_output_;

  size_t bit_num;
  int quant_max{INT8_MAX};
  int quant_min{INT8_MIN};
  size_t bit_num_;
  int q_max_{INT8_MAX};
  int q_min_{INT8_MIN};
};
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
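Taken together, the public surface of the class stays small. A minimal driver sketch, assuming DoQuantize is the public entry point inherited from Quantizer (the bit width and target type follow the constructor defaults above; RunFullQuant is a hypothetical wrapper, not part of this patch):

  int RunFullQuant(const FuncGraphPtr &func_graph) {
    auto quantizer = std::make_unique<FullQuantQuantizer>(func_graph, 8, kNumberTypeInt8);
    if (quantizer == nullptr) {
      MS_LOG(ERROR) << "Create FullQuantQuantizer failed.";
      return RET_ERROR;
    }
    // Runs calibration (MIN_MAX + KL_BIN passes), quantizes nodes, then applies bias correction.
    return quantizer->DoQuantize(func_graph);
  }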
@@ -0,0 +1,123 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/quantizer/quant_strategy.h"
#include <set>
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "base/core_ops.h"
#include "src/common/log_adapter.h"
#include "src/common/log_util.h"
#include "nnacl/op_base.h"

namespace mindspore::lite::quant {
bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) {
  if (input_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
    return false;
  }
  ParameterPtr param_node = nullptr;
  if (input_node->isa<Parameter>()) {
    param_node = input_node->cast<ParameterPtr>();
  }
  if (param_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized invalid param_node!";
    return false;
  }
  if (!param_node->has_default()) {
    MS_LOG(INFO) << "param_node don't has default.";
    return false;
  }
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(INFO) << "abstract is nullptr";
    return false;
  }
  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << param_node->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  MS_ASSERT(weight_shape != nullptr);
  if (weight_shape.size() < DIMENSION_2D) {  // do not quant single dim tensors
    return false;
  }
  int64_t total_shape_size = 1;
  for (auto shape : weight_shape) {
    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(total_shape_size, shape), RET_ERROR, "Int mul overflow");
    total_shape_size *= shape;
  }
  if (total_shape_size < 0 || static_cast<size_t>(total_shape_size) < min_quant_weight_size_) {
    MS_LOG(INFO) << "shape_size " << total_shape_size << " less min_quant_weight_size_ " << min_quant_weight_size_;
    return false;
  }

  // min_quant_weight_channel_ only supports convolution
  if (weight_shape.size() > DIMENSION_2D &&
      weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel_)) {
    MS_LOG(INFO) << "preferred_dim shape:" << weight_shape[preferred_dim] << " less min_quant_weight_channel_ "
                 << min_quant_weight_channel_;
    return false;
  }
  return true;
}
bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
  MS_CHECK_TRUE_RET(node != nullptr, false);
  if (!node->isa<mindspore::CNode>()) {
    return false;
  }
  const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
  MS_ASSERT(cnode != nullptr);
  auto type = NodePrimitiveType(cnode);
  static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion, prim::kPrimActivation,
                                                          prim::kPrimAvgPoolFusion, prim::kPrimConcat,
                                                          prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
                                                          prim::kPrimCrop, prim::kPrimFullConnection,
                                                          prim::kPrimGather, prim::kPrimLayerNormFusion,
                                                          prim::kPrimMatMul, prim::kPrimMaxPoolFusion,
                                                          prim::kPrimMulFusion, prim::kPrimReshape,
                                                          prim::kPrimSplit, prim::kPrimTranspose,
                                                          prim::kPrimReduceFusion, prim::kPrimDivFusion,
                                                          prim::kPrimSqrt, prim::kPrimPowFusion,
                                                          prim::kPrimUnsqueeze, prim::kPrimAffine};
  // The return node does not need to be quantized.
  if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn) || opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
    return false;
  }
  // These operators do not need to check the data type.
  if (opt::CheckPrimitiveType(cnode, prim::kPrimShape) || opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
    return true;
  }
  auto is_support_node = CheckNodeInSet(cnode, support_int8_ops);
  if (!is_support_node && type != "Eltwise") {
    MS_LOG(WARNING) << "node:" << cnode->fullname_with_scope() << " type:" << type << " is not support quantization.";
    return false;
  }
  TypeId type_id;
  auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
    return false;
  }

  bool is_data_type_fp32 = type_id == kNumberTypeFloat32;
  if (!is_data_type_fp32) {
    MS_LOG(INFO) << cnode->fullname_with_scope() << " type_id is " << type_id << " , and is not float32.";
  }
  return is_data_type_fp32;
}
}  // namespace mindspore::lite::quant
@@ -0,0 +1,39 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
#include <cstddef>
#include "ir/anf.h"

namespace mindspore::lite::quant {
class QuantStrategy {
 public:
  QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel)
      : min_quant_weight_size_(min_quant_weight_size), min_quant_weight_channel_(min_quant_weight_channel) {}

  ~QuantStrategy() = default;

  static bool CanOpFullQuantized(const AnfNodePtr &node);
  bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim);

 private:
  size_t min_quant_weight_size_;
  size_t min_quant_weight_channel_;
};
}  // namespace mindspore::lite::quant

#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
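A short usage sketch for the extracted QuantStrategy; the two threshold values are placeholders that would normally come from the weight-quant configuration, and the wrapper functions are hypothetical:

  // Gate weight quantization on tensor size and channel count.
  bool ShouldQuantWeight(const mindspore::AnfNodePtr &weight_node, int preferred_dim) {
    mindspore::lite::quant::QuantStrategy strategy(5000, 16);  // min elements, min channels (placeholders)
    return strategy.CanTensorQuantized(weight_node, preferred_dim);
  }

  // Static check used by full quantization: op type is supported for int8 and its output is fp32.
  bool ShouldFullQuantNode(const mindspore::AnfNodePtr &node) {
    return mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(node);
  }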
@@ -17,17 +17,11 @@
#include "mindspore/lite/tools/converter/quantizer/quantize_util.h"
#include <cmath>
#include <string>
#include <map>
#include <fstream>
#include <algorithm>
#include <memory>
#include <vector>
#include <set>
#include <functional>
#include "include/version.h"
#include "ops/affine.h"
#include "ops/fusion/conv2d_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/fusion/full_connection.h"
#include "ops/mat_mul.h"
#include "tools/converter/ops/ops_def.h"
@@ -44,9 +38,6 @@ using std::string;
using std::vector;

namespace mindspore::lite::quant {
constexpr int kDim2 = 2;
constexpr int kDim4 = 4;

const int kLstmInputWeightIndex = 1;
const int kLstmStateWeightIndex = 2;
const int kLstmWeightShapeSize = 3;
@@ -54,102 +45,6 @@ const int kSingleDirBiasTensorSize = 4;
const int kLstmBiasShapeSize = 2;
const int kLstmBiasIndex = 3;

bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
  MS_CHECK_TRUE_RET(node != nullptr, false);
  if (!node->isa<mindspore::CNode>()) {
    return false;
  }
  const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
  MS_ASSERT(cnode != nullptr);
  auto type = NodePrimitiveType(cnode);
  static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion, prim::kPrimActivation,
                                                          prim::kPrimAvgPoolFusion, prim::kPrimConcat,
                                                          prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
                                                          prim::kPrimCrop, prim::kPrimFullConnection,
                                                          prim::kPrimGather, prim::kPrimLayerNormFusion,
                                                          prim::kPrimMatMul, prim::kPrimMaxPoolFusion,
                                                          prim::kPrimMulFusion, prim::kPrimReshape,
                                                          prim::kPrimSplit, prim::kPrimTranspose,
                                                          prim::kPrimReduceFusion, prim::kPrimDivFusion,
                                                          prim::kPrimSqrt, prim::kPrimPowFusion,
                                                          prim::kPrimUnsqueeze, prim::kPrimAffine};
  // The return node does not need to be quantized.
  if (opt::CheckPrimitiveType(cnode, prim::kPrimReturn) || opt::CheckPrimitiveType(cnode, prim::kPrimMakeTuple)) {
    return false;
  }
  // These operators do not need to check the data type.
  if (opt::CheckPrimitiveType(cnode, prim::kPrimShape) || opt::CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
    return true;
  }
  auto is_support_node = CheckNodeInSet(cnode, support_int8_ops);
  if (!is_support_node && type != "Eltwise") {
    MS_LOG(WARNING) << "node:" << cnode->fullname_with_scope() << " type:" << type << " is not support quantization.";
    return false;
  }
  TypeId type_id;
  auto ret = opt::GetDataTypeFromAnfNode(cnode, &type_id);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Fetch DataType from cnode failed.";
    return false;
  }

  bool is_data_type_fp32 = type_id == kNumberTypeFloat32;
  if (!is_data_type_fp32) {
    MS_LOG(INFO) << cnode->fullname_with_scope() << " type_id is " << type_id << " , and is not float32.";
  }
  return is_data_type_fp32;
}

bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) const {
  if (input_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
    return false;
  }
  ParameterPtr param_node = nullptr;
  if (input_node->isa<Parameter>()) {
    param_node = input_node->cast<ParameterPtr>();
  }
  if (param_node == nullptr) {
    MS_LOG(INFO) << "CanTensorQuantized invalid param_node!";
    return false;
  }
  if (!param_node->has_default()) {
    MS_LOG(INFO) << "param_node don't has default.";
    return false;
  }
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(INFO) << "abstract is nullptr";
    return false;
  }
  if (!utils::isa<abstract::ShapePtr>(abstract_base->GetShapeTrack())) {
    MS_LOG(INFO) << "Shape of Abstract of parameter should be ShapePtr " << param_node->name();
    return false;
  }
  auto weight_shape = utils::cast<abstract::ShapePtr>(abstract_base->GetShapeTrack())->shape();
  MS_ASSERT(weight_shape != nullptr);
  if (weight_shape.size() < kDim2) {  // do not quant single dim tensors
    return false;
  }
  int64_t total_shape_size = 1;
  for (auto shape : weight_shape) {
    MS_CHECK_FALSE_MSG(INT_MUL_OVERFLOW(total_shape_size, shape), RET_ERROR, "Int mul overflow");
    total_shape_size *= shape;
  }
  if (total_shape_size < 0 || static_cast<size_t>(total_shape_size) < min_quant_weight_size_) {
    MS_LOG(INFO) << "shape_size " << total_shape_size << " less min_quant_weight_size_ " << min_quant_weight_size_;
    return false;
  }

  // min_quant_weight_channel_ only supports convolution
  if (weight_shape.size() > kDim2 && weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel_)) {
    MS_LOG(INFO) << "preferred_dim shape:" << weight_shape[preferred_dim] << " less min_quant_weight_channel_ "
                 << min_quant_weight_channel_;
    return false;
  }
  return true;
}

QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive) {
  MS_CHECK_TRUE_RET(primitive != nullptr, nullptr);
  QuantParamHolderPtr quant_params_holder = nullptr;
@@ -359,7 +254,7 @@ std::string NodePrimitiveType(const CNodePtr &cnode) {
}

SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
                                      int *size, bool is_debug) {
                                      int *size) {
  SessionModel sm;
  auto meta_graph = Export(func_graph, true, true);
  if (meta_graph == nullptr) {
@@ -414,19 +309,15 @@ SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const conv
    delete model;
    return sm;
  }
  if (!is_debug) {
    model->Free();
  }
  delete meta_graph;
  sm.session = session;
  sm.model = model;
  return sm;
}

SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
                                      bool is_debug) {
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num) {
  int size = 0;
  return CreateSessionByFuncGraph(func_graph, flags, thread_num, &size, is_debug);
  return CreateSessionByFuncGraph(func_graph, flags, thread_num, &size);
}

void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
@@ -550,7 +441,7 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
    MS_LOG(ERROR) << " shape vector is empty.";
    return;
  }
  if (primitive.value.type == schema::PrimitiveType_MatMul && static_cast<int>(shapes.size()) == kDim2) {
  if (primitive.value.type == schema::PrimitiveType_MatMul && static_cast<int>(shapes.size()) == DIMENSION_2D) {
    auto matmul_prim = primitive.value.AsMatMul();
    MS_ASSERT(matmul_prim != nullptr);
    *channel_at_first = index != 1 || matmul_prim->transpose_b;
|
@ -60,36 +60,15 @@ constexpr size_t kMaxBit = 8;
|
|||
constexpr size_t kMaxNum1024 = 1024;
|
||||
constexpr float kPercentBase = 100.0;
|
||||
constexpr size_t kMillisecondsBase = 10;
|
||||
constexpr float delta = 0.1;
|
||||
constexpr float ratio = 10.0;
|
||||
constexpr int percent = 10;
|
||||
|
||||
struct SessionModel {
|
||||
session::LiteSession *session{nullptr};
|
||||
Model *model{nullptr};
|
||||
};
|
||||
|
||||
/**
|
||||
* 1. when op's weight size > mWeightSize just skip
|
||||
* 2. only do conv/deconv/convdepthwise/deconvdepthwise/mul/matmul/batchmatmul quantization
|
||||
* 3. when conv/deconv/convdepthwise/deconvdepthwise ops' weight channel size > covWeightQuantChannelThreshold just skip
|
||||
* */
|
||||
class QuantStrategy {
|
||||
public:
|
||||
QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel)
|
||||
: min_quant_weight_size_(min_quant_weight_size), min_quant_weight_channel_(min_quant_weight_channel) {}
|
||||
|
||||
~QuantStrategy() = default;
|
||||
|
||||
static bool CanOpFullQuantized(const AnfNodePtr &node);
|
||||
bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) const;
|
||||
|
||||
private:
|
||||
size_t min_quant_weight_size_;
|
||||
size_t min_quant_weight_channel_;
|
||||
};
|
||||
|
||||
constexpr float delta = 0.1;
|
||||
constexpr float ratio = 10.0;
|
||||
constexpr int percent = 10;
|
||||
|
||||
QuantParamHolderPtr GetCNodeQuantHolder(const PrimitivePtr &primitive);
|
||||
|
||||
std::pair<float, float> OutlierMethod(std::vector<float> min_datas, std::vector<float> max_datas);
|
||||
|
@@ -117,7 +96,8 @@ std::vector<int> ConvertShapeVectorToInt32(const ShapeVector &dims);
template <typename T>
int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
                        QuantType quant_type, int quant_max, int quant_min, size_t bit_num,
                        WeightQuantType weight_quant_type, TypeId quant_data_type, int index, bool k_means = false) {
                        WeightQuantType weight_quant_type, TypeId quant_data_type, int index, bool narrow_range = false,
                        bool k_means = false) {
  MS_ASSERT(weight != nullptr);
  MS_ASSERT(primitive != nullptr);
  auto dims = weight->shape();
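With the added trailing narrow_range parameter, a call site that wants a symmetric 8-bit weight range passes true so that the lower quantization bound is tightened (typically -127 instead of -128). A hedged call sketch, where every lowercase identifier is a placeholder for a value already in scope at the call site:

  auto status = FixedBitQuantFilter<int8_t>(parameter, weight, primitive, quant_type,
                                            quant_max, quant_min, bit_num, FIXED_BIT_PER_LAYER,
                                            kNumberTypeInt8, index, /*narrow_range=*/true);
  if (status != RET_OK && status != RET_NO_CHANGE) {
    MS_LOG(ERROR) << "FixedBitQuantFilter failed: " << status;
  }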
@@ -143,7 +123,7 @@ int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &
    ret =
      DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(),
                           static_cast<mindspore::schema::QuantType>(quant_type), &quant_params, quant_max, quant_min,
                           bit_num, k_means, &quant_data, ConvertShapeVectorToInt32(dims), preferred_dim);
                           bit_num, &quant_data, ConvertShapeVectorToInt32(dims), preferred_dim, narrow_range, k_means);
    if (ret == RET_NO_CHANGE) {
      return ret;
    } else if (ret != RET_OK) {
@@ -152,7 +132,7 @@ int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &
    }
  } else if (weight_quant_type == FIXED_BIT_PER_LAYER) {
    ret = DoPerLayerQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
                             quant_min, bit_num, k_means, &quant_data);
                             quant_min, bit_num, &quant_data, narrow_range, k_means);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Do per layer quant failed.";
      return ret;
@@ -188,10 +168,9 @@ int FixedBitQuantFilter(const ParameterPtr &parameter, const tensor::TensorPtr &

std::string NodePrimitiveType(const CNodePtr &cnode);

SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num);
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
                                      bool is_debug = false);
SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const converter::Flags &flags, int thread_num,
                                      int *size, bool is_debug = false);
                                      int *size);
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);

bool CheckNodeInSet(const CNodePtr &cnode, const std::set<PrimitivePtr> &support_primitive_types);
@@ -28,6 +28,7 @@
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quant_strategy.h"
#include "tools/converter/preprocess/preprocess_param.h"
#include "ir/func_graph.h"
#include "ir/anf.h"