!22918 fix mixed bit quant bug && delete unused weight quant code

Merge pull request !22918 from yeyunpeng2020/quant_bak
i-robot 2021-09-06 06:13:24 +00:00 committed by Gitee
commit 5bbb0647ee
22 changed files with 248 additions and 736 deletions

View File

@ -5,7 +5,6 @@ source ./scripts/base_functions.sh
function Run_Converter() {
# Unzip x86 runtime and converter
cd ${x86_path} || exit 1
tar -zxf ${x86_path}/avx/mindspore-lite-${version}-linux-x64.tar.gz || exit 1
tar -zxf mindspore-lite-${version}-linux-x64.tar.gz || exit 1
cd ${x86_path}/mindspore-lite-${version}-linux-x64/ || exit 1

View File

@ -23,15 +23,8 @@ SET INSTRUCTION=%3
SET BASEPATH=%CD%
SET RET_CODE=0
SET PACKAGE_PATH_CONVERT=%PACKAGE_PATH:"=%\windows_x64\avx
SET PACKAGE_PATH=%PACKAGE_PATH:"=%\windows_x64\%INSTRUCTION%
7z x -r "%PACKAGE_PATH_CONVERT%\mindspore-lite-*.zip"
IF NOT %errorlevel% == 0 (
echo "Decompression of runtime tool fail!"
SET RET_CODE=1
goto run_eof
)
echo A | 7z x -r "%PACKAGE_PATH%\mindspore-lite-*.zip"
7z x -r "%PACKAGE_PATH%\mindspore-lite-*.zip"
IF NOT %errorlevel% == 0 (
echo "Decompression of runtime tool fail!"
SET RET_CODE=1

View File

@ -42,7 +42,7 @@
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
#include "tools/converter/quantizer/fse_encoder.h"
#include "nnacl/op_base.h"

View File

@ -84,6 +84,7 @@ int SplitLineToMap(std::ifstream *ifs, std::map<std::string, std::map<std::strin
}
auto split_vector = SplitStringToVector(raw_line, split_delimiter);
if (split_vector.size() != 2) {
MS_LOG(ERROR) << "split vector size != 2";
return RET_ERROR;
}
std::string key = split_vector.at(0);
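For context, the size-2 guard above enforces that every raw config line splits into exactly one key and one value. A minimal self-contained sketch of that rule (the Split helper below is an assumed stand-in for SplitStringToVector):
#include <sstream>
#include <string>
#include <vector>
// Assumed stand-in for SplitStringToVector: split on a single-character delimiter.
std::vector<std::string> Split(const std::string &s, char delim) {
  std::vector<std::string> out;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) out.push_back(item);
  return out;
}
// "bit_num=8" is accepted; "bit_num" (size 1) and "a=b=c" (size 3) are rejected.
bool IsValidConfigLine(const std::string &line) { return Split(line, '=').size() == 2; }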

View File

@ -62,7 +62,7 @@
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/optimizer/graph/dump_graph.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/optimizer/parallel/split_strategy.h"
@ -278,9 +278,9 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// quant
if (config->commonQuantParam.quant_type == schema::QuantType_PostTraining) {
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->commonQuantParam.bit_num);
this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
if (m_quantizer_ == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
MS_LOG(ERROR) << "New FullQuantQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}

View File

@ -25,7 +25,7 @@ int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) {
std::map<std::string, std::map<std::string, std::string>> maps;
auto ret = mindspore::lite::ParseConfigFile(config_file_path, &maps);
if (ret != RET_OK) {
MS_LOG(ERROR) << "image_path=input1:/mnt/calibration_input1_path";
MS_LOG(ERROR) << "Parse config file failed.";
return ret;
}
ret = ParseDataPreProcessString(maps);

View File

@ -121,62 +121,20 @@ int PreprocessParser::ParseCalibratePath(const std::string &str, std::map<std::s
int PreprocessParser::ParseImagePreProcess(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.resize_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_width, &image_pre_process->resize_width)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_width <= 0) {
MS_LOG(ERROR) << "resize_width must be > 0";
return RET_INPUT_PARAM_INVALID;
}
auto ret = ParseImageNormalize(data_pre_process_str, image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse image normalize failed.";
return ret;
}
if (!data_pre_process_str.resize_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_height, &image_pre_process->resize_height)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_height <= 0) {
MS_LOG(ERROR) << "resize_height must be > 0";
return RET_INPUT_PARAM_INVALID;
}
ret = ParseImageCenterCrop(data_pre_process_str, image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse image center crop failed.";
return ret;
}
if (!data_pre_process_str.resize_method.empty()) {
image_pre_process->resize_method = preprocess::ConvertResizeMethod(data_pre_process_str.resize_method);
}
if (!data_pre_process_str.normalize_mean.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_mean, &image_pre_process->normalize_mean)) {
MS_LOG(ERROR) << "Convert normalize_mean failed.";
return RET_INPUT_PARAM_INVALID;
}
if (!data_pre_process_str.normalize_std.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_std, &image_pre_process->normalize_std)) {
MS_LOG(ERROR) << "Convert normalize_std failed.";
return RET_INPUT_PARAM_INVALID;
}
if (!data_pre_process_str.center_crop_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_width, &image_pre_process->center_crop_width)) {
MS_LOG(ERROR) << "center_crop_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_width <= 0) {
MS_LOG(ERROR) << "center_crop_width must be > 0";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.center_crop_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_height, &image_pre_process->center_crop_height)) {
MS_LOG(ERROR) << "center_crop_height should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_height <= 0) {
MS_LOG(ERROR) << "center_crop_height must be > 0";
return RET_INPUT_PARAM_INVALID;
}
ret = ParseImageResize(data_pre_process_str, image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse image resize failed.";
return ret;
}
return RET_OK;
}
@ -239,6 +197,73 @@ int PreprocessParser::CollectCalibInputs(const std::map<std::string, std::string
}
return RET_OK;
}
int PreprocessParser::ParseImageNormalize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.normalize_mean.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_mean, &image_pre_process->normalize_mean)) {
MS_LOG(ERROR) << "Convert normalize_mean failed.";
return RET_INPUT_PARAM_INVALID;
}
if (!data_pre_process_str.normalize_std.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_std, &image_pre_process->normalize_std)) {
MS_LOG(ERROR) << "Convert normalize_std failed.";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.resize_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_width, &image_pre_process->resize_width)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_width <= 0 || image_pre_process->resize_width > 65535) {
MS_LOG(ERROR) << "resize_width must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.resize_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_height, &image_pre_process->resize_height)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_height <= 0 || image_pre_process->resize_height > 65535) {
MS_LOG(ERROR) << "resize_height must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.resize_method.empty()) {
image_pre_process->resize_method = preprocess::ConvertResizeMethod(data_pre_process_str.resize_method);
}
return RET_OK;
}
int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.center_crop_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_width, &image_pre_process->center_crop_width)) {
MS_LOG(ERROR) << "center_crop_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_width <= 0 || image_pre_process->center_crop_width > 65535) {
MS_LOG(ERROR) << "center_crop_width must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.center_crop_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_height, &image_pre_process->center_crop_height)) {
MS_LOG(ERROR) << "center_crop_height should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_height <= 0 || image_pre_process->center_crop_height > 65535) {
MS_LOG(ERROR) << "center_crop_height must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
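The hunk above splits the old monolithic ParseImagePreProcess into three focused helpers; its remaining body reduces to the chain visible in the diff (normalize, then center crop, then resize). A condensed view of the refactored flow:
int ParseImagePreProcess(const DataPreProcessString &str, preprocess::ImagePreProcessParam *param) {
  auto ret = ParseImageNormalize(str, param);
  if (ret != RET_OK) return ret;
  ret = ParseImageCenterCrop(str, param);
  if (ret != RET_OK) return ret;
  return ParseImageResize(str, param);  // each helper logs and returns on failure
}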

View File

@ -35,6 +35,15 @@ class PreprocessParser {
int ParseImagePreProcess(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageNormalize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageResize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageCenterCrop(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageToFormat(const std::string &image_to_format_str, preprocess::ImageToFormat *image_to_format);
int ParseCalibratePath(const std::string &str, std::map<std::string, std::string> *value);

View File

@ -58,13 +58,13 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (common_quant->min_quant_weight_size < 0) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should be greater than or equal to zero." << std::endl;
if (common_quant->min_quant_weight_size < 0 || common_quant->min_quant_weight_size > 65535) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should in [0,65535]." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
if (common_quant->min_quant_weight_channel < 0) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be greater than or equal to zero." << std::endl;
if (common_quant->min_quant_weight_channel < 0 || common_quant->min_quant_weight_channel > 65535) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be in [0,65535]." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
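Both thresholds now share the same closed range check. A hypothetical helper (illustration only, not part of the patch) expressing the rule once:
constexpr int kMaxQuantThreshold = 65535;
// min_quant_weight_size and min_quant_weight_channel must both lie in [0, 65535].
inline bool InClosedRange(int v) { return v >= 0 && v <= kMaxQuantThreshold; }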

View File

@ -8,14 +8,14 @@ file(GLOB QUANTIZER
${CMAKE_CURRENT_SOURCE_DIR}/quant_helper/*
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/full_quant_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/huffman_encode.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_decoder.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_bit_stream.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_encoder.cc
${CMAKE_CURRENT_SOURCE_DIR}/fix_bit_weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/mixed_bit_weight_quantizer.cc
)
set_property(SOURCE ${QUANTIZER} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(quantizer_mid OBJECT ${QUANTIZER})

View File

@ -19,7 +19,7 @@
#include <vector>
#include "tools/converter/quantizer/fse_bit_stream.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
namespace mindspore::lite::quant {
constexpr int MAX_SYMS = 65534;
constexpr int MAX_TABLE_LOG = 16;

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include <dirent.h>
#include <future>
#include <map>
@ -43,7 +43,6 @@
#include "tools/common/tensor_util.h"
#include "src/common/quant_utils.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/preprocess/image_preprocess.h"
#include "nnacl/op_base.h"
@ -330,53 +329,6 @@ std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
return std::make_pair(this->cnode, zero_point);
}
std::unordered_map<CNodePtr, float> Calibrator::GetScale(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
MS_ASSERT(diverg_info != nullptr);
std::unordered_map<CNodePtr, float> result;
for (auto &iter : *diverg_info) {
DivergInfo *info = iter.second.get();
auto item = info->GetScale();
result.insert(item);
}
return result;
}
std::unordered_map<CNodePtr, int32_t> Calibrator::GetZeropoint(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
MS_ASSERT(diverg_info != nullptr);
std::unordered_map<CNodePtr, int32_t> result;
for (auto &iter : *diverg_info) {
DivergInfo *info = iter.second.get();
auto zeropoint = info->GetZeropoint();
result.insert(zeropoint);
}
return result;
}
std::map<CNodePtr, MaxMin> Calibrator::GetMinMax(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
MS_ASSERT(diverg_info != nullptr);
std::map<CNodePtr, MaxMin> result;
for (auto &iter : *diverg_info) {
DivergInfo *info = iter.second.get();
mindspore::lite::quant::MaxMin input_maxmin{};
input_maxmin.min = info->min;
input_maxmin.max = info->max;
result[info->cnode] = input_maxmin;
}
return result;
}
void Calibrator::Dump() {
for (auto &kv : this->inputs_diverg_info_) {
auto &infos = kv.second;
for (auto &info : infos) {
info->DumpHistogram();
}
}
}
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
return &this->inputs_diverg_info_;
}
@ -479,7 +431,7 @@ STATUS Calibrator::GenerateInputData(const std::string &input_name, size_t image
return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
}
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type, bool per_channel)
FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type, bool per_channel)
: Quantizer(std::move(graph)) {
MS_ASSERT(graph != nullptr);
this->per_channel_ = per_channel;
@ -501,15 +453,15 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, int bit_num, Ty
}
}
PostTrainingQuantizer::~PostTrainingQuantizer() {
FullQuantQuantizer::~FullQuantQuantizer() {
delete fp32_session_;
delete fp32_model_;
delete int8_session_;
delete int8_model_;
}
STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
const PrimitivePtr &primitive) const {
STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
const PrimitivePtr &primitive) const {
MS_ASSERT(max_min != nullptr);
MS_ASSERT(primitive != nullptr);
auto quant_param_holder = GetCNodeQuantHolder(primitive);
@ -529,8 +481,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru
return RET_OK;
}
STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
const PrimitivePtr &primitive) const {
STATUS FullQuantQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
const PrimitivePtr &primitive) const {
MS_ASSERT(max_min != nullptr);
MS_ASSERT(primitive != nullptr);
auto quant_param_holder = GetCNodeQuantHolder(primitive);
@ -550,8 +502,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
return RET_OK;
}
STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
const PrimitivePtr &primitive, bool per_channel) const {
STATUS FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
const PrimitivePtr &primitive, bool per_channel) const {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive != nullptr);
// perlayer
@ -572,19 +524,9 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An
auto bit_num_t = bit_num;
auto quant_max_t = quant_max;
auto quant_min_t = quant_min;
if (calibrator_->full_quant_param_.mixed) {
auto opname_iter = opname_bit_.find(op_name);
if (opname_iter == opname_bit_.end()) {
MS_LOG(WARNING) << op_name << " not in the opname_bit_ map";
} else {
bit_num_t = opname_iter->second;
quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
}
}
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
auto status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t, bit_num_t,
weight_quant_type, kNumberTypeInt8);
auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t,
bit_num_t, weight_quant_type, kNumberTypeInt8);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
@ -608,7 +550,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An
return RET_OK;
}
STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive) {
STATUS FullQuantQuantizer::DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive) {
if (primitive == nullptr || bias == nullptr) {
MS_LOG(ERROR) << "null pointer!";
return RET_NULL_PTR;
@ -698,7 +640,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv
return RET_OK;
}
STATUS PostTrainingQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
MS_ASSERT(cnode != nullptr);
auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
@ -780,7 +722,7 @@ STATUS PostTrainingQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
return RET_OK;
}
STATUS PostTrainingQuantizer::QuantNode() {
STATUS FullQuantQuantizer::QuantNode() {
auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();
@ -879,7 +821,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
return RET_OK;
}
STATUS PostTrainingQuantizer::UpdateDivergInverval() {
STATUS FullQuantQuantizer::UpdateDivergInverval() {
this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo());
this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo());
return RET_OK;
@ -894,7 +836,7 @@ STATUS PostTrainingQuantizer::UpdateDivergInverval() {
* 2.1 parse image files to input tensor
* 3. save quantized node
**/
STATUS PostTrainingQuantizer::PreProcess() {
STATUS FullQuantQuantizer::PreProcess() {
// 3. collect to be quantized operators
// from user input
QuantStrategy strategy(kMillisecondsBase);
@ -920,8 +862,8 @@ STATUS PostTrainingQuantizer::PreProcess() {
return RET_OK;
}
STATUS PostTrainingQuantizer::CheckFp32TensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
STATUS FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
MS_ASSERT(tensor_vec != nullptr);
if (tensor_vec.empty()) {
MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0";
@ -942,7 +884,7 @@ STATUS PostTrainingQuantizer::CheckFp32TensorVec(const std::string &node_name,
* 2. insert callback to session
* 3. run session
**/
STATUS PostTrainingQuantizer::DoInference() {
STATUS FullQuantQuantizer::DoInference() {
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() != calibrator_->GetInputNum()) {
@ -968,7 +910,7 @@ STATUS PostTrainingQuantizer::DoInference() {
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
return false;
}
if ((*diverg_info_map)[callParam.node_name].size() == 1 &&
@ -999,7 +941,7 @@ STATUS PostTrainingQuantizer::DoInference() {
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
return false;
}
if ((*diverg_info_map)[callParam.node_name].size() == 1 && afterOutputs.size() > 1) {
@ -1029,7 +971,7 @@ STATUS PostTrainingQuantizer::DoInference() {
return RET_OK;
}
STATUS PostTrainingQuantizer::Int8Inference() {
STATUS FullQuantQuantizer::Int8Inference() {
// int8 inference
vector<mindspore::tensor::MSTensor *> inputs = int8_session_->GetInputs();
for (auto input_tensor : inputs) {
@ -1059,8 +1001,8 @@ STATUS PostTrainingQuantizer::Int8Inference() {
return RET_OK;
}
STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
std::future<STATUS> int8_inference = std::async(std::launch::async, &PostTrainingQuantizer::Int8Inference, this);
STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
std::future<STATUS> int8_inference = std::async(std::launch::async, &FullQuantQuantizer::Int8Inference, this);
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() != 1) {
@ -1115,7 +1057,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
return status;
}
STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
auto op_name = cnode->fullname_with_scope();
const auto &bias_diff = op_bias_diff_map[op_name];
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
@ -1196,7 +1138,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con
return RET_OK;
}
STATUS PostTrainingQuantizer::CollectDataFrequency() {
STATUS FullQuantQuantizer::CollectDataFrequency() {
// get input tensor
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
if (inputs.size() != calibrator_->GetInputNum()) {
@ -1221,7 +1163,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
return false;
}
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
@ -1243,7 +1185,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) {
return true;
}
if (PostTrainingQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
return false;
}
int output_i = 0;
@ -1267,9 +1209,9 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
return RET_OK;
}
STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
STATUS FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_LOG(INFO) << "start to parse config file";
if (this->calibrator_ == nullptr) {
MS_LOG(ERROR) << "calibrator is null!";
@ -1278,22 +1220,6 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
calibrator_->full_quant_param_ = flags.fullQuantParam;
calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
STATUS status;
if (calibrator_->full_quant_param_.mixed) { // get opname_bit map
auto weight_quant_func_graph = CopyFuncGraph(func_graph);
if (weight_quant_func_graph == nullptr) {
MS_LOG(ERROR) << "CopyFuncGraph error";
return RET_ERROR;
}
WeightQuantizer weight_quantizer(weight_quant_func_graph, calibrator_->full_quant_param_);
weight_quantizer.flags = flags;
status = weight_quantizer.DoQuantize(weight_quant_func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Do mix weight quant error";
return RET_ERROR;
}
opname_bit_ = weight_quantizer.opname_bit_;
}
status = PreProcess();
if (status != RET_OK) {
MS_LOG(ERROR) << "do pre process failed!";
@ -1366,7 +1292,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
return RET_OK;
}
bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
MS_ASSERT(data != nullptr);
std::lock_guard<std::mutex> lg(mutex_op_input);
if (type == STORE) {
@ -1390,8 +1316,7 @@ bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &
return false;
}
bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name,
std::vector<float> *data) {
bool FullQuantQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
MS_ASSERT(data != nullptr);
std::lock_guard<std::mutex> lg(mutex_op_output);
if (type == STORE) {
@ -1415,14 +1340,14 @@ bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const s
return false;
}
KernelCallBack PostTrainingQuantizer::GetBeforeCallBack(bool int8_op) {
KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
KernelCallBack before_call_back;
if (!int8_op) {
before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
const CallBackParam &callParam) -> bool {
if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
return false;
}
auto tensor = beforeInputs[0];
@ -1495,7 +1420,7 @@ KernelCallBack PostTrainingQuantizer::GetBeforeCallBack(bool int8_op) {
return before_call_back;
}
KernelCallBack PostTrainingQuantizer::GetAfterCallBack(bool int8_op) {
KernelCallBack FullQuantQuantizer::GetAfterCallBack(bool int8_op) {
KernelCallBack after_call_back;
if (!int8_op) {
return GetFloatAfterCallBack();
@ -1503,7 +1428,7 @@ KernelCallBack PostTrainingQuantizer::GetAfterCallBack(bool int8_op) {
return GetInt8AfterCallBack();
}
KernelCallBack PostTrainingQuantizer::GetInt8AfterCallBack() {
KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
const CallBackParam &callParam) -> bool {
@ -1580,12 +1505,12 @@ KernelCallBack PostTrainingQuantizer::GetInt8AfterCallBack() {
return after_call_back;
}
KernelCallBack PostTrainingQuantizer::GetFloatAfterCallBack() {
KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
const CallBackParam &callParam) -> bool {
if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
return false;
}
auto tensor = afterOutputs[0];

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
#include <string>
#include <memory>
@ -31,7 +31,6 @@
#include "tools/converter/converter.h"
#include "include/ms_tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/preprocess/preprocess_param.h"
@ -46,10 +45,10 @@ struct MaxMin {
constexpr int kDefaultBinNumber = 2048;
class PostTrainingQuantizer : public Quantizer {
class FullQuantQuantizer : public Quantizer {
public:
PostTrainingQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8, bool per_channel = true);
~PostTrainingQuantizer() override;
FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8, bool per_channel = true);
~FullQuantQuantizer() override;
STATUS DoQuantize(FuncGraphPtr func_graph) override;
@ -201,19 +200,9 @@ class Calibrator {
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
static STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
void Dump();
STATUS ComputeThreshold();
static std::unordered_map<CNodePtr, float> GetScale(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
static std::unordered_map<CNodePtr, int32_t> GetZeropoint(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
static std::map<CNodePtr, MaxMin> GetMinMax(
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetInputDivergInfo();
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
@ -232,4 +221,4 @@ class Calibrator {
int quant_min_;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H

View File

@ -14,15 +14,15 @@
* limitations under the License.
*/
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
#include <cmath>
namespace mindspore::lite::quant {
// the error is currently measured per channel.
// it could be measured per layer but it would be less good.
// the `preferred` dim should point to the output channels dimension.
float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
float scale) {
float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
float scale) {
MS_ASSERT(weights != nullptr);
MS_ASSERT(shape != nullptr);
int numel = 1;
@ -61,13 +61,13 @@ float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int
}
variance_dequant = std::sqrt(variance_dequant / numel);
variance_raw = std::sqrt(variance_raw / numel);
var_corr = variance_raw / variance_dequant;
mean_corr = average_raw - average_dequant * var_corr;
var_corr_ = variance_raw / variance_dequant;
mean_corr_ = average_raw - average_dequant * var_corr_;
for (int i = 0; i < numel; i++) {
int bucket = (i / bucket_volume) % bucket_count;
norms2[bucket] += weights[i] * weights[i];
float dequant = var_corr * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr;
float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr_;
float d = weights[i] - dequant;
dnorms2[bucket] += d * d;
}
@ -82,7 +82,7 @@ float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int
return t / (c + 1e-7);
}
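In spirit, MeasureQuantizationError estimates the relative L2 error introduced by round-to-nearest quantization at a given scale, averaged over buckets along the preferred dimension. A single-bucket sketch, omitting the variance/mean correction:
#include <cmath>
#include <vector>
// Sketch only: relative L2 error of quantizing w at the given scale,
// mirroring the norms2/dnorms2 accumulation in the hunk above.
float RelativeQuantError(const std::vector<float> &w, float scale) {
  float norm2 = 0.0f, dnorm2 = 0.0f;
  for (float x : w) {
    float dequant = scale * std::floor(x / scale + 0.5f);  // round to nearest step
    norm2 += x * x;
    dnorm2 += (x - dequant) * (x - dequant);
  }
  return std::sqrt(dnorm2 / (norm2 + 1e-7f));
}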
MinMax FixBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
MinMax MixedBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
MS_ASSERT(arr != nullptr);
MinMax min_max = {INFINITY, -INFINITY};
for (int i = 0; i < arrc; i++)
@ -93,9 +93,9 @@ MinMax FixBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
return min_max;
}
BinarySearchResult FixBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
int preferred_dim, int max_iters,
float target_err, float rel_tol) {
BinarySearchResult MixedBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
int preferred_dim, int max_iters,
float target_err, float rel_tol) {
MS_ASSERT(weights != nullptr);
MS_ASSERT(shape != nullptr);
int element_num = 1;
@ -136,9 +136,9 @@ BinarySearchResult FixBitWeightQuantizer::BinarySearchForQuantizationScale(float
}
}
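The search is a plain bisection on the scale: a larger scale quantizes more coarsely and raises the measured error, so the bracket narrows until the error is within rel_tol of target_err or the iteration budget runs out. A generic sketch under those assumptions:
#include <cmath>
#include <functional>
struct ScaleSearchResult { int status; float scale; };  // status 0 = converged, 1 = max iters
// Sketch only: bisect over the scale given an error-at-scale callback.
ScaleSearchResult BisectScale(const std::function<float(float)> &err_at, float lo, float hi,
                              int max_iters, float target_err, float rel_tol) {
  for (int it = 0; it < max_iters; ++it) {
    float mid = 0.5f * (lo + hi);
    float err = err_at(mid);
    if (std::fabs(err - target_err) <= rel_tol * target_err) return {0, mid};
    if (err > target_err) hi = mid;  // too coarse: shrink the scale
    else lo = mid;                   // too fine: grow the scale
  }
  return {1, 0.0f};  // corresponds to the reached_max_iters error below
}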
int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
std::vector<schema::QuantParamT> *quant_params,
std::vector<int16_t> *quant_datas) {
int MixedBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
std::vector<schema::QuantParamT> *quant_params,
std::vector<int16_t> *quant_datas) {
MS_ASSERT(weights != nullptr);
int weight_count = 1;
int dims = shape.size();
@ -148,8 +148,8 @@ int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> s
input_shape[i] = shape[i];
}
BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters,
target_relative_err, target_search_tolerance);
BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters_,
target_relative_err_, target_search_tolerance_);
if (br.status != 0) {
MS_LOG(ERROR) << "reached_max_iters";
return RET_ERROR;
@ -166,15 +166,15 @@ int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> s
return RET_OK;
}
int FixBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
int MixedBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
MS_ASSERT(weights != nullptr);
for (int i = 0; i < weightsc; i++) {
auto q = static_cast<int>(floorf(weights[i] / scale + 0.5));
quant_datas->at(i) = q;
}
quant_params->meanCorr = mean_corr;
quant_params->varCorr = var_corr;
quant_params->meanCorr = mean_corr_;
quant_params->varCorr = var_corr_;
quant_params->scale = scale;
quant_params->zeroPoint = 0;
quant_params->numBits = 0;
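For reference, a consumer reconstructs a weight from the stored int16 code by inverting QuantizeByScale and applying the recorded corrections. A self-contained sketch with a local stand-in for schema::QuantParamT:
#include <cstdint>
// Local stand-in for the schema::QuantParamT fields used here.
struct QuantParam { float scale; float varCorr; float meanCorr; };
// Inverse of QuantizeByScale: matches the dequant expression in MeasureQuantizationError.
inline float Dequantize(int16_t q, const QuantParam &p) {
  return p.varCorr * (q * p.scale) + p.meanCorr;
}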

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIX_BIT_WEIGHT_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIX_BIT_WEIGHT_QUANTIZER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZER_H
#include <cstdint>
#include <vector>
#include <cmath>
@ -34,14 +34,14 @@ typedef struct {
float max;
} MinMax;
class FixBitWeightQuantizer {
class MixedBitWeightQuantizer {
public:
explicit FixBitWeightQuantizer(float target_relative_err = 0.01, float target_search_tolerance = 0.01,
int max_search_iters = 100)
: target_relative_err(target_relative_err),
target_search_tolerance(target_search_tolerance),
max_search_iters(max_search_iters) {}
~FixBitWeightQuantizer() = default;
explicit MixedBitWeightQuantizer(float target_relative_err = 0.01, float target_search_tolerance = 0.01,
int max_search_iters = 100)
: target_relative_err_(target_relative_err),
target_search_tolerance_(target_search_tolerance),
max_search_iters_(max_search_iters) {}
~MixedBitWeightQuantizer() = default;
int DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
std::vector<schema::QuantParamT> *quant_params, std::vector<int16_t> *quant_datas);
@ -58,11 +58,11 @@ class FixBitWeightQuantizer {
int max_iters, float target_err, float rel_tol);
private:
float var_corr{1};
float mean_corr{0};
float target_relative_err;
float target_search_tolerance;
int max_search_iters;
float var_corr_{1};
float mean_corr_{0};
float target_relative_err_;
float target_search_tolerance_;
int max_search_iters_;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIX_BIT_WEIGHT_QUANTIZER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZER_H

View File

@ -42,8 +42,6 @@ struct FullQuantParam {
ActivationQuantizedMethod activation_quant_method = MAX_MIN;
bool bias_correction = true;
int thread_num = 1;
bool mixed = false;
float mean_error_threshold = 0.04;
};
} // namespace mindspore::lite::quant

View File

@ -726,8 +726,8 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
}
}
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
WeightQuantType weight_quant_type, TypeId quant_data_type, int index) {
STATUS MixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
WeightQuantType weight_quant_type, TypeId quant_data_type, double init_scale, int index) {
MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
auto dims = weight->shape();
@ -748,13 +748,13 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
std::vector<int16_t> quant_data(elem_count);
int ret = RET_OK;
if (weight_quant_type == MIXED_BIT_PER_LAYER) {
FixBitWeightQuantizer quantizer(0.02);
MixedBitWeightQuantizer quantizer(init_scale);
quantizer.DoQuantization(static_cast<float *>(weight->data_c()), weight->shape_c(), 0, &quant_params, &quant_data);
} else {
MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
}
auto status =
UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), TypeId::kNumberTypeInt16);
UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), quant_data_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
return RET_ERROR;

View File

@ -42,7 +42,7 @@
#include "abstract/dshape.h"
#include "tools/converter/quantizer/huffman_encode.h"
#include "tools/converter/quantizer/bitpacking.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
#include "src/lite_session.h"
#include "tools/converter/graphdef_transform.h"
#include "src/common/file_utils.h"
@ -147,13 +147,13 @@ STATUS DoBitPack(const tensor::TensorPtr &weight, const size_t &bit_num, const s
return RET_OK;
}
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
WeightQuantType weight_quant_type, TypeId quant_data_type, int index = 1);
STATUS MixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
WeightQuantType weight_quant_type, TypeId quant_data_type, double init_scale, int index);
template <typename T>
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type, int quant_max,
int quant_min, size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type,
int index = 1, bool k_means = false) {
STATUS FixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
int quant_max, int quant_min, size_t bit_num, WeightQuantType weight_quant_type,
TypeId quant_data_type, int index = 1, bool k_means = false) {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive != nullptr);
auto dims = weight->shape();
@ -240,7 +240,5 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
int ConvertInputShapeMapToVector(FullQuantParam *config_param_, const std::vector<tensor::MSTensor *> &inputs,
std::vector<std::vector<int>> *shapes);
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_

View File

@ -1,31 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mindspore/lite/tools/converter/quantizer/quantizer.h"
namespace mindspore::lite::quant {
STATUS Quantizer::GenerateQuantParam() { return RET_OK; }
STATUS Quantizer::RemoveFakeQuant() { return RET_OK; }
STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; }
STATUS FbQuantizer::GenerateQuantParam() { return RET_OK; }
STATUS FbQuantizer::RemoveFakeQuant() { return RET_OK; }
STATUS FbQuantizer::DetermineNodeQuantType() { return RET_OK; }
} // namespace mindspore::lite::quant

View File

@ -36,12 +36,6 @@ class Quantizer {
virtual ~Quantizer() = default;
virtual STATUS RemoveFakeQuant();
virtual STATUS GenerateQuantParam();
virtual STATUS DetermineNodeQuantType();
virtual STATUS DoQuantize(FuncGraphPtr func_graph) = 0;
converter::Flags flags;
@ -49,23 +43,5 @@ class Quantizer {
protected:
FuncGraphPtr funcGraph = nullptr;
};
class FbQuantizer {
public:
explicit FbQuantizer(schema::MetaGraphT *graph) : graph(graph) {}
virtual ~FbQuantizer() = default;
virtual STATUS RemoveFakeQuant();
virtual STATUS GenerateQuantParam();
virtual STATUS DetermineNodeQuantType();
virtual STATUS DoQuantize() = 0;
protected:
schema::MetaGraphT *graph = nullptr;
};
} // namespace mindspore::lite::quant
#endif

View File

@ -27,11 +27,6 @@ using std::string;
using std::vector;
namespace mindspore::lite::quant {
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const FullQuantParam &config) : Quantizer(std::move(graph)) {
quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
config_param_ = config;
}
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(std::move(graph)) {
auto quant_size = config.commonQuantParam.min_quant_weight_size;
this->bit_num_ = config.commonQuantParam.bit_num;
@ -39,17 +34,19 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &con
type_id_ = kNumberTypeInt16;
this->is_mixed_bit_ = true;
}
auto convQuantWeightChannelThreshold = config.commonQuantParam.min_quant_weight_channel;
quant_strategy_ = std::make_unique<QuantStrategy>(quant_size, convQuantWeightChannelThreshold);
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
// parse type_id_
if (this->bit_num_ > 0 && this->bit_num_ <= kMaxBit) {
type_id_ = kNumberTypeInt8;
} else if (this->bit_num_ <= (kMaxBit * 2)) {
type_id_ = kNumberTypeInt16;
} else {
MS_LOG(ERROR) << "invalid input bits";
quant_strategy_ = std::make_unique<QuantStrategy>(quant_size, config.commonQuantParam.min_quant_weight_channel);
// parse param for fixed bit quant.
if (!this->is_mixed_bit_) {
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
// parse type_id_
if (this->bit_num_ > 0 && this->bit_num_ <= kMaxBit) {
type_id_ = kNumberTypeInt8;
} else if (this->bit_num_ <= (kMaxBit * 2)) {
type_id_ = kNumberTypeInt16;
} else {
MS_LOG(ERROR) << "invalid input bits";
}
}
}
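A worked example of the fixed-bit range computation above, which now runs only when mixed-bit quantization is off:
#include <cstdio>
#include <initializer_list>
int main() {
  for (unsigned bit_num : {8u, 16u}) {
    int quant_max = (1 << (bit_num - 1)) - 1;  // 127 for 8 bits, 32767 for 16
    int quant_min = -(1 << (bit_num - 1));     // -128 for 8 bits, -32768 for 16
    std::printf("bit_num=%u -> [%d, %d]\n", bit_num, quant_min, quant_max);
  }
  return 0;
}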
@ -116,18 +113,19 @@ STATUS WeightQuantizer::DoConvQuantize(const CNodePtr &cnode) {
auto status = RET_ERROR;
if (is_mixed_bit_) {
type_id_ = kNumberTypeInt16;
status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_);
status = MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
type_id_, flags.mixedBitWeightQuantParam.init_scale, 1);
} else if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
}
if (status == RET_CONTINUE) {
return RET_OK;
} else if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
return status;
}
status = SetAbstract(tensor_info, param_node, primitive);
@ -160,19 +158,21 @@ STATUS WeightQuantizer::DoMulQuantize(const CNodePtr &cnode) {
weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
}
if (is_mixed_bit_) {
status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
type_id_, i - 1);
status =
MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
type_id_, flags.mixedBitWeightQuantParam.init_scale, i - 1);
} else if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, weight_quant_type, type_id_, i - 1);
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, weight_quant_type, type_id_, i - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, weight_quant_type, type_id_, i - 1);
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, weight_quant_type, type_id_, i - 1);
}
if (status == RET_CONTINUE) {
continue;
} else if (status != RET_OK) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << i << " QuantFilter failed : " << status;
MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << i
<< " MixedBitQuantFilter failed : " << status;
return status;
}
status = SetAbstract(tensor_info, param_node, primitive);
@ -241,19 +241,19 @@ STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) {
auto status = RET_ERROR;
if (is_mixed_bit_) {
status =
QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, 0);
status = MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
type_id_, flags.mixedBitWeightQuantParam.init_scale, 0);
} else if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
}
if (status == RET_CONTINUE) {
return RET_OK;
} else if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
return status;
}
status = SetAbstract(tensor_info, param_node, primitive);
@ -293,14 +293,14 @@ STATUS WeightQuantizer::DoOptimizerQuantize(const CNodePtr &cnode) {
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
}
if (status != RET_OK && status != RET_CONTINUE) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
return status;
}
status = SetAbstract(tensor_info, param_node, primitive);
@ -367,19 +367,19 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
}
auto status = RET_ERROR;
if (is_mixed_bit_) {
status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_,
index - 1);
status = MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
type_id_, flags.mixedBitWeightQuantParam.init_scale, index - 1);
} else if (type_id_ == kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
}
if (status == RET_CONTINUE) {
return RET_OK;
} else if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
return status;
}
status = SetAbstract(tensor_info, param_node, primitive);
@ -390,347 +390,24 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
return RET_OK;
}
constexpr float relative_tolerance = 1e-5;
constexpr float abs_tolerance = 1e-4;
template <typename T>
float CompareOutputData(const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &expected_tensor,
const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &compare_tensor) {
auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); };
float total_mean_error = 0.0f;
int tensor_cnt = expected_tensor.size();
if (tensor_cnt <= 0) {
MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt;
return RET_ERROR;
}
for (const auto &exp_tensor_pair : expected_tensor) {
float mean_error = 0.0f;
int error_cnt = 0;
auto exp_tensor_name = exp_tensor_pair.first;
auto exp_tensor = exp_tensor_pair.second;
auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name);
if (cmp_tensor_find_iter == compare_tensor.end()) {
MS_LOG(ERROR) << "can not find: " << exp_tensor_name;
return RET_ERROR;
}
auto cmp_tensor = cmp_tensor_find_iter->second;
auto exp_tensor_shape = exp_tensor->shape();
auto cmp_tensor_shape = cmp_tensor->shape();
if (exp_tensor_shape != cmp_tensor_shape) {
MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum()
<< " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum();
return RET_ERROR;
}
auto exp_data = static_cast<T *>(exp_tensor->MutableData());
auto cmp_data = static_cast<T *>(cmp_tensor->MutableData());
auto elem_cnt = exp_tensor->ElementsNum();
for (int i = 0; i < elem_cnt; i++) {
if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) {
MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i;
return RET_ERROR;
}
auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]);
auto abs_error = std::fabs(exp_data[i] - cmp_data[i]);
if (abs_error > tolerance) {
if (fabs(exp_data[i]) == 0) {
if (abs_error > 1e-5) {
mean_error += abs_error;
error_cnt++;
} else {
// it is ok, very close to 0
continue;
}
} else {
mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN);
error_cnt++;
}
} else {
// it is ok, no error
continue;
}
} // end one tensor data loop
total_mean_error += mean_error / elem_cnt;
} // end tensor loop
return total_mean_error / tensor_cnt;
}
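The comparator above applies the standard mixed absolute/relative tolerance test built from the two constants at its head; the per-element rule in isolation:
#include <cmath>
// An element passes when |exp - cmp| <= abs_tolerance + relative_tolerance * |exp|.
inline bool WithinTolerance(float exp_v, float cmp_v,
                            float rel_tol = 1e-5f, float abs_tol = 1e-4f) {
  return std::fabs(exp_v - cmp_v) <= abs_tol + rel_tol * std::fabs(exp_v);
}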
STATUS WeightQuantizer::RunFp32Graph(const FuncGraphPtr &func_graph) {
auto image_cnt = images_.at(0).size();
// 0.1 Create Fp32 Session
flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
auto fp32_sm = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
auto fp32_session = fp32_sm.session;
auto fp32_model = fp32_sm.model;
if (fp32_session == nullptr || fp32_model == nullptr) {
MS_LOG(ERROR) << "CreateSessoin fail";
delete fp32_model;
return RET_ERROR;
}
auto fp32_inputs = fp32_session->GetInputs();
fp32_output_tensors_.resize(image_cnt);
// 0.3 save fp32 output
for (size_t i = 0; i < image_cnt; i++) {
for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) {
auto status = preprocess::PreProcess(flags.dataPreProcessParam, fp32_inputs[input_index]->tensor_name(), i,
fp32_inputs[input_index]);
if (status != RET_OK) {
MS_LOG(ERROR) << "generate input data from images failed!";
delete fp32_sm.session;
delete fp32_sm.model;
return RET_ERROR;
}
}
auto status = fp32_session->RunGraph();
if (status != RET_OK) {
MS_LOG(ERROR) << "RunGraph fail";
delete fp32_sm.session;
delete fp32_sm.model;
return RET_ERROR;
}
auto fp32_outputs = fp32_session->GetOutputs();
for (const auto &kv : fp32_outputs) {
auto *tensor = kv.second;
auto *lite_tensor = reinterpret_cast<lite::Tensor *>(tensor);
if (lite_tensor == nullptr) {
MS_LOG(ERROR) << "not lite tensor";
delete fp32_sm.session;
delete fp32_sm.model;
return RET_ERROR;
}
auto *new_tensor = Tensor::CopyTensor(*lite_tensor, true);
fp32_output_tensors_[i][kv.first] = new_tensor;
}
}
delete fp32_sm.session;
delete fp32_sm.model;
return RET_OK;
}
STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
STATUS WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
auto cnodes = func_graph->GetOrderedCnodes();
int status = RET_OK;
for (auto &cnode : cnodes) {
if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
status = DoLstmQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuantize error";
return RET_ERROR;
}
} else if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
status = DoGatherQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuantize error";
return RET_ERROR;
}
}
}
return status;
}
STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
if (!input_node->isa<Parameter>()) {
MS_LOG(WARNING) << op_name << " the second input is not parameter";
return RET_CONTINUE;
}
*param_node = input_node->cast<ParameterPtr>();
if (!(*param_node)->has_default()) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
return RET_CONTINUE;
}
*tensor_info = std::static_pointer_cast<tensor::Tensor>((*param_node)->default_param());
if (*tensor_info == nullptr) {
MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
return RET_CONTINUE;
}
if ((*tensor_info)->data_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << op_name << " the second input type is not float";
return RET_CONTINUE;
}
return RET_OK;
}
STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param_node,
const tensor::TensorPtr &tensor_info, const PrimitivePtr &primitive) {
MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
MS_CHECK_TRUE_RET(tensor_info != nullptr, RET_NULL_PTR);
MS_CHECK_TRUE_RET(param_node != nullptr, RET_NULL_PTR);
int status;
type_id_ = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
} else {
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;
}
if (status == RET_CONTINUE) {
return RET_OK;
} else if (status != RET_OK) {
MS_LOG(ERROR) << "quant filter failed.";
return RET_ERROR;
}
status = SetAbstract(tensor_info, param_node, primitive);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return status;
}
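The candidate range used above is the standard signed, symmetric range for a b-bit integer: quant_max = 2^(b-1) - 1 and quant_min = -(2^(b-1)); a compile-time sketch of the same arithmetic:

// Signed b-bit quantization range, matching quant_max_t/quant_min_t above.
constexpr int QuantMax(unsigned int bit_num) { return (1 << (bit_num - 1)) - 1; }
constexpr int QuantMin(unsigned int bit_num) { return -(1 << (bit_num - 1)); }
static_assert(QuantMax(8) == 127 && QuantMin(8) == -128, "8-bit range is [-128, 127]");
static_assert(QuantMax(2) == 1 && QuantMin(2) == -2, "2-bit range is [-2, 1]");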
STATUS WeightQuantizer::EvaluateQuant(const FuncGraphPtr &func_graph, size_t image_cnt, float *mean_error) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
if (mean_error == nullptr) {
MS_LOG(ERROR) << "mean_error is nullptr.";
return RET_ERROR;
}
// 2.1 create quant session, get input, output tensor
flags.commonQuantParam.quant_type = schema::QuantType_WeightQuant;
auto quant_sm = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
auto quant_session = std::unique_ptr<session::LiteSession>(quant_sm.session);
int status;
if (quant_session == nullptr) {
MS_LOG(ERROR) << "create session error.";
delete quant_sm.model;
return RET_ERROR;
}
auto quant_inputs = quant_session->GetInputs();
for (size_t i = 0; i < image_cnt; i++) {
// set multi-input data
for (size_t input_index = 0; input_index < quant_inputs.size(); input_index++) {
status = preprocess::PreProcess(flags.dataPreProcessParam, quant_inputs[input_index]->tensor_name(), i,
quant_inputs[input_index]);
if (status != RET_OK) {
MS_LOG(ERROR) << "generate input data from images failed!";
delete quant_sm.model;
return RET_ERROR;
}
}
status = quant_session->RunGraph();
if (status != RET_OK) {
MS_LOG(ERROR) << "quant session run error";
delete quant_sm.model;
return RET_ERROR;
}
// 3. compare between quant and fp32
auto quant_outputs = quant_session->GetOutputs();
(*mean_error) += CompareOutputData<float>(fp32_output_tensors_[i], quant_outputs);
} // end_for: calib data loop
delete quant_sm.model;
(*mean_error) = (*mean_error) / image_cnt;
return RET_OK;
}
STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
auto cnodes = func_graph->GetOrderedCnodes();
size_t image_cnt = images_.at(0).size();
int status = RET_OK;
  for (auto iter = cnodes.end(); iter != cnodes.begin();) {
    auto cnode = *(--iter);
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "primitive is null.";
      continue;
    }
auto op_name = cnode->fullname_with_scope();
MS_LOG(DEBUG) << "process node: " << op_name << " type: " << primitive->name();
if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
auto input_node = cnode->input(2);
ParameterPtr param_node;
tensor::TensorPtr tensor_info;
status = GetParamNodeAndValue(input_node, op_name, &param_node, &tensor_info);
if (status == RET_CONTINUE) {
continue;
}
// copy origin data in case to recover
auto *raw_data = static_cast<float *>(tensor_info->data_c());
auto elem_count = tensor_info->DataSize();
      auto origin_data = std::make_unique<float[]>(elem_count);
      MS_CHECK_TRUE_RET(origin_data != nullptr, RET_NULL_PTR);
auto ret = memcpy_s(origin_data.get(), sizeof(float) * elem_count, raw_data, tensor_info->Size());
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy fail: "
<< " dst size: " << sizeof(float) * elem_count << " src size: " << tensor_info->Size();
return RET_ERROR;
}
// 1. try quant
for (size_t bit_num_t = 2; bit_num_t <= kMaxBit; bit_num_t++) {
status = TryQuant(bit_num_t, param_node, tensor_info, primitive);
if (status != RET_OK) {
MS_LOG(ERROR) << "TryQuant failed.";
return RET_ERROR;
}
// 2. evaluate the quant
float mean_error = 0.0f;
status = EvaluateQuant(func_graph, image_cnt, &mean_error);
if (status != RET_OK) {
MS_LOG(ERROR) << "EvaluateQuant failed.";
return RET_ERROR;
}
if (mean_error <= config_param_.mean_error_threshold) {
MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error;
opname_bit_[op_name] = bit_num_t;
break;
} else if (bit_num_t != kMaxBit) {
MS_LOG(DEBUG) << "op: " << op_name << " intermediate bit: " << bit_num_t << " mean_error: " << mean_error
<< " [recover]";
// recover
status =
UpdateTensorDataAndSize(tensor_info, origin_data.get(), sizeof(float) * elem_count, kNumberTypeFloat32);
if (status != RET_OK) {
MS_LOG(ERROR) << "UpdateTensorDataAndSize fail";
return RET_ERROR;
}
} else {
MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error;
opname_bit_[op_name] = bit_num_t;
}
} // end bit loop
} // if: conv and matmul
} // end loop: all cnode
return status;
}
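Stripped of the session plumbing, the search above is a greedy first-fit over bit widths; a sketch of that control flow, where evaluate and restore are hypothetical stand-ins for the TryQuant/EvaluateQuant pair and the weight roll-back:

#include <functional>

// Try widths from 2 bits upward; accept the first one whose mean error is
// within the threshold, rolling the weights back after every rejected try
// except the last, which is accepted unconditionally.
int SearchBitWidth(int max_bit, float mean_error_threshold,
                   const std::function<float(int)> &evaluate, const std::function<void()> &restore) {
  for (int bit = 2; bit <= max_bit; bit++) {
    float mean_error = evaluate(bit);
    if (mean_error <= mean_error_threshold) {
      return bit;
    }
    if (bit != max_bit) {
      restore();
    }
  }
  return max_bit;
}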
STATUS WeightQuantizer::DoMixedQuant(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
auto status = RunFp32Graph(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "RunFp32Graph failed.";
return RET_ERROR;
}
status = DoMixedQuantize(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoMixedQuantize failed.";
return RET_ERROR;
}
status = DoQuantSearch(func_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantSearch failed.";
return RET_ERROR;
}
for (const auto &kv : opname_bit_) {
MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
}
return RET_OK;
}
STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
weight_quantized_tensors_.clear();
@ -779,35 +456,4 @@ STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
}
return MarkWeightQuantizationInNodes(func_graph);
}
STATUS WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
for (auto &cnode : func_graph->GetOrderedCnodes()) {
auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
continue;
}
auto status = DoMarkWeightQuantizeIfQuantized(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "MarkWeightQuantizationInNodes error marking " << cnode->fullname_with_scope();
return RET_ERROR;
}
}
return RET_OK;
}
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
if (config_param_.mixed) {
bit_num_ = kMaxBit;
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
type_id_ = kNumberTypeInt8;
MS_LOG(INFO) << "Do mixed bit quantization";
return DoMixedQuant(func_graph);
}
return DoFixedQuant(func_graph);
}
} // namespace mindspore::lite::quant

View File

@ -39,7 +39,6 @@ namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer {
public:
WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
WeightQuantizer(FuncGraphPtr graph, const FullQuantParam &config);
~WeightQuantizer() override;
STATUS DoQuantize(FuncGraphPtr func_graph) override;
@ -54,33 +53,18 @@ class WeightQuantizer : public Quantizer {
int quant_max_{127};
int quant_min_{-128};
TypeId type_id_{kNumberTypeInt8};
std::map<std::string, int> opname_bit_;
private:
std::unique_ptr<QuantStrategy> quant_strategy_;
size_t bit_num_{8};
std::map<tensor::TensorPtr, ParameterPtr> weight_quantized_tensors_;
FullQuantParam config_param_;
  std::map<std::string, std::vector<std::string>> images_;  // multi_input, [[model_input_0], [model_input_1]...]
std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
bool is_mixed_bit_ = false;
STATUS DoMixedQuant(const FuncGraphPtr &);
STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr &param_node,
const PrimitivePtr &primitive);
STATUS DoFixedQuant(const FuncGraphPtr &);
STATUS MarkWeightQuantizationInNodes(const FuncGraphPtr &);
STATUS DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
STATUS RunFp32Graph(const FuncGraphPtr &);
STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
STATUS CheckImageCnt();
static STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info,
const PrimitivePtr &primitive);
STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
STATUS EvaluateQuant(const FuncGraphPtr &func_graph, size_t image_cnt, float *mean_error);
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_