forked from mindspore-Ecosystem/mindspore
!22918 fix mixed bit quant bug && delete unused weight quant code
Merge pull request !22918 from yeyunpeng2020/quant_bak
This commit is contained in: commit 5bbb0647ee
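For orientation: this change renames quant::PostTrainingQuantizer to quant::FullQuantQuantizer and FixBitWeightQuantizer to MixedBitWeightQuantizer, and splits the old QuantFilter entry point into FixedBitQuantFilter and MixedBitQuantFilter. A minimal call-site sketch of the new entry points, with signatures taken from the quantize_util.h hunk further down (tensor_info, primitive, quant_max, quant_min, bit_num and init_scale are assumed to exist in the caller):

// fixed-bit path (8/16 bit):
auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max, quant_min,
                                          bit_num, WeightQuantType::FIXED_BIT_PER_CHANNEL, kNumberTypeInt8);
// mixed-bit path; init_scale is read from flags.mixedBitWeightQuantParam.init_scale in the diff:
status = MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
                             kNumberTypeInt16, init_scale, 1);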
@@ -5,7 +5,6 @@ source ./scripts/base_functions.sh
function Run_Converter() {
# Unzip x86 runtime and converter
cd ${x86_path} || exit 1
tar -zxf ${x86_path}/avx/mindspore-lite-${version}-linux-x64.tar.gz || exit 1
tar -zxf mindspore-lite-${version}-linux-x64.tar.gz || exit 1
cd ${x86_path}/mindspore-lite-${version}-linux-x64/ || exit 1
@@ -23,15 +23,8 @@ SET INSTRUCTION=%3
SET BASEPATH=%CD%
SET RET_CODE=0
SET PACKAGE_PATH_CONVERT=%PACKAGE_PATH:"=%\windows_x64\avx
SET PACKAGE_PATH=%PACKAGE_PATH:"=%\windows_x64\%INSTRUCTION%
7z x -r "%PACKAGE_PATH_CONVERT%\mindspore-lite-*.zip"
IF NOT %errorlevel% == 0 (
echo "Decompression of runtime tool fail!"
SET RET_CODE=1
goto run_eof
)
echo A | 7z x -r "%PACKAGE_PATH%\mindspore-lite-*.zip"
7z x -r "%PACKAGE_PATH%\mindspore-lite-*.zip"
IF NOT %errorlevel% == 0 (
echo "Decompression of runtime tool fail!"
SET RET_CODE=1
@@ -42,7 +42,7 @@
#include "tools/common/node_util.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
#include "tools/converter/quantizer/fse_encoder.h"
#include "nnacl/op_base.h"
@@ -84,6 +84,7 @@ int SplitLineToMap(std::ifstream *ifs, std::map<std::string, std::map<std::strin
}
auto split_vector = SplitStringToVector(raw_line, split_delimiter);
if (split_vector.size() != 2) {
MS_LOG(ERROR) << "split vector size != 2";
return RET_ERROR;
}
std::string key = split_vector.at(0);
@@ -62,7 +62,7 @@
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/optimizer/graph/dump_graph.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/optimizer/parallel/split_strategy.h"
@@ -278,9 +278,9 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
// quant
if (config->commonQuantParam.quant_type == schema::QuantType_PostTraining) {
this->m_quantizer_ = std::make_unique<quant::PostTrainingQuantizer>(old_graph, config->commonQuantParam.bit_num);
this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
if (m_quantizer_ == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
MS_LOG(ERROR) << "New FullQuantQuantizer failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
return RET_ERROR;
}
@@ -25,7 +25,7 @@ int ConfigFileParser::ParseConfigFile(const std::string &config_file_path) {
std::map<std::string, std::map<std::string, std::string>> maps;
auto ret = mindspore::lite::ParseConfigFile(config_file_path, &maps);
if (ret != RET_OK) {
MS_LOG(ERROR) << "image_path=input1:/mnt/calibration_input1_path";
MS_LOG(ERROR) << "Parse config file failed.";
return ret;
}
ret = ParseDataPreProcessString(maps);
@@ -121,62 +121,20 @@ int PreprocessParser::ParseCalibratePath(const std::string &str, std::map<std::s
int PreprocessParser::ParseImagePreProcess(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.resize_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_width, &image_pre_process->resize_width)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_width <= 0) {
MS_LOG(ERROR) << "resize_width must be > 0";
return RET_INPUT_PARAM_INVALID;
}
auto ret = ParseImageNormalize(data_pre_process_str, image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse image normalize failed.";
return ret;
}
if (!data_pre_process_str.resize_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_height, &image_pre_process->resize_height)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_height <= 0) {
MS_LOG(ERROR) << "resize_height must be > 0";
return RET_INPUT_PARAM_INVALID;
}
ret = ParseImageCenterCrop(data_pre_process_str, image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse image center crop failed.";
return ret;
}
if (!data_pre_process_str.resize_method.empty()) {
image_pre_process->resize_method = preprocess::ConvertResizeMethod(data_pre_process_str.resize_method);
}
if (!data_pre_process_str.normalize_mean.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_mean, &image_pre_process->normalize_mean)) {
MS_LOG(ERROR) << "Convert normalize_mean failed.";
return RET_INPUT_PARAM_INVALID;
}
if (!data_pre_process_str.normalize_std.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_std, &image_pre_process->normalize_std)) {
MS_LOG(ERROR) << "Convert normalize_std failed.";
return RET_INPUT_PARAM_INVALID;
}
if (!data_pre_process_str.center_crop_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_width, &image_pre_process->center_crop_width)) {
MS_LOG(ERROR) << "center_crop_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_width <= 0) {
MS_LOG(ERROR) << "center_crop_width must be > 0";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.center_crop_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_height, &image_pre_process->center_crop_height)) {
MS_LOG(ERROR) << "center_crop_height should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_height <= 0) {
MS_LOG(ERROR) << "center_crop_height must be > 0";
return RET_INPUT_PARAM_INVALID;
}
ret = ParseImageResize(data_pre_process_str, image_pre_process);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Parse image resize failed.";
return ret;
}
return RET_OK;
}
@@ -239,6 +197,73 @@ int PreprocessParser::CollectCalibInputs(const std::map<std::string, std::string
}
return RET_OK;
}
int PreprocessParser::ParseImageNormalize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.normalize_mean.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_mean, &image_pre_process->normalize_mean)) {
MS_LOG(ERROR) << "Convert normalize_mean failed.";
return RET_INPUT_PARAM_INVALID;
}
if (!data_pre_process_str.normalize_std.empty() &&
!ConvertDoubleVector(data_pre_process_str.normalize_std, &image_pre_process->normalize_std)) {
MS_LOG(ERROR) << "Convert normalize_std failed.";
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
}
int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.resize_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_width, &image_pre_process->resize_width)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_width <= 0 || image_pre_process->resize_width > 65535) {
MS_LOG(ERROR) << "resize_width must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.resize_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_height, &image_pre_process->resize_height)) {
MS_LOG(ERROR) << "resize_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->resize_height <= 0 || image_pre_process->resize_height > 65535) {
MS_LOG(ERROR) << "resize_height must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.resize_method.empty()) {
image_pre_process->resize_method = preprocess::ConvertResizeMethod(data_pre_process_str.resize_method);
}
return RET_OK;
}
int PreprocessParser::ParseImageCenterCrop(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process) {
if (!data_pre_process_str.center_crop_width.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_width, &image_pre_process->center_crop_width)) {
MS_LOG(ERROR) << "center_crop_width should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_width <= 0 || image_pre_process->center_crop_width > 65535) {
MS_LOG(ERROR) << "center_crop_width must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
if (!data_pre_process_str.center_crop_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.center_crop_height, &image_pre_process->center_crop_height)) {
MS_LOG(ERROR) << "center_crop_height should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (image_pre_process->center_crop_height <= 0 || image_pre_process->center_crop_height > 65535) {
MS_LOG(ERROR) << "center_crop_height must be in (0,65535].";
return RET_INPUT_PARAM_INVALID;
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
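All of the added parsers apply the same bounds rule: a configured dimension must be a valid integer in (0, 65535]. A hypothetical one-line helper expressing that rule (the diff inlines the check per field instead of using a helper):

// hypothetical helper, not part of this change:
static bool IsValidImageDim(int value) { return value > 0 && value <= 65535; }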
@@ -35,6 +35,15 @@ class PreprocessParser {
int ParseImagePreProcess(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageNormalize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageResize(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageCenterCrop(const DataPreProcessString &data_pre_process_str,
preprocess::ImagePreProcessParam *image_pre_process);
int ParseImageToFormat(const std::string &image_to_format_str, preprocess::ImageToFormat *image_to_format);
int ParseCalibratePath(const std::string &str, std::map<std::string, std::string> *value);
@@ -58,13 +58,13 @@ int QuantParamParser::ParseCommonQuant(const CommonQuantString &common_quant_str
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be a valid number.";
return RET_INPUT_PARAM_INVALID;
}
if (common_quant->min_quant_weight_size < 0) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should be greater than or equal to zero." << std::endl;
if (common_quant->min_quant_weight_size < 0 || common_quant->min_quant_weight_size > 65535) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_size should be in [0,65535]." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
if (common_quant->min_quant_weight_channel < 0) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be greater than or equal to zero." << std::endl;
if (common_quant->min_quant_weight_channel < 0 || common_quant->min_quant_weight_size > 65535) {
MS_LOG(ERROR) << "INPUT ILLEGAL: min_quant_weight_channel should be in [0,65535]." << std::endl;
return RET_INPUT_PARAM_INVALID;
}
return RET_OK;
@@ -8,14 +8,14 @@ file(GLOB QUANTIZER
${CMAKE_CURRENT_SOURCE_DIR}/quant_helper/*
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/full_quant_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/huffman_encode.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_decoder.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_bit_stream.cc
${CMAKE_CURRENT_SOURCE_DIR}/fse_encoder.cc
${CMAKE_CURRENT_SOURCE_DIR}/fix_bit_weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/mixed_bit_weight_quantizer.cc
)
set_property(SOURCE ${QUANTIZER} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(quantizer_mid OBJECT ${QUANTIZER})
@@ -19,7 +19,7 @@
#include <vector>
#include "tools/converter/quantizer/fse_bit_stream.h"
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
namespace mindspore::lite::quant {
constexpr int MAX_SYMS = 65534;
constexpr int MAX_TABLE_LOG = 16;
@@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include <dirent.h>
#include <future>
#include <map>
@@ -43,7 +43,6 @@
#include "tools/common/tensor_util.h"
#include "src/common/quant_utils.h"
#include "src/common/utils.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/preprocess/image_preprocess.h"
#include "nnacl/op_base.h"
@ -330,53 +329,6 @@ std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
|
|||
return std::make_pair(this->cnode, zero_point);
|
||||
}
|
||||
|
||||
std::unordered_map<CNodePtr, float> Calibrator::GetScale(
|
||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||
MS_ASSERT(diverg_info != nullptr);
|
||||
std::unordered_map<CNodePtr, float> result;
|
||||
for (auto &iter : *diverg_info) {
|
||||
DivergInfo *info = iter.second.get();
|
||||
auto item = info->GetScale();
|
||||
result.insert(item);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unordered_map<CNodePtr, int32_t> Calibrator::GetZeropoint(
|
||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||
MS_ASSERT(diverg_info != nullptr);
|
||||
std::unordered_map<CNodePtr, int32_t> result;
|
||||
for (auto &iter : *diverg_info) {
|
||||
DivergInfo *info = iter.second.get();
|
||||
auto zeropoint = info->GetZeropoint();
|
||||
result.insert(zeropoint);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<CNodePtr, MaxMin> Calibrator::GetMinMax(
|
||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info) {
|
||||
MS_ASSERT(diverg_info != nullptr);
|
||||
std::map<CNodePtr, MaxMin> result;
|
||||
for (auto &iter : *diverg_info) {
|
||||
DivergInfo *info = iter.second.get();
|
||||
mindspore::lite::quant::MaxMin input_maxmin{};
|
||||
input_maxmin.min = info->min;
|
||||
input_maxmin.max = info->max;
|
||||
result[info->cnode] = input_maxmin;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void Calibrator::Dump() {
|
||||
for (auto &kv : this->inputs_diverg_info_) {
|
||||
auto &infos = kv.second;
|
||||
for (auto &info : infos) {
|
||||
info->DumpHistogram();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
|
||||
return &this->inputs_diverg_info_;
|
||||
}
|
||||
|
@ -479,7 +431,7 @@ STATUS Calibrator::GenerateInputData(const std::string &input_name, size_t image
|
|||
return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
|
||||
}
|
||||
|
||||
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type, bool per_channel)
|
||||
FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type, bool per_channel)
|
||||
: Quantizer(std::move(graph)) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
this->per_channel_ = per_channel;
|
||||
|
@ -501,15 +453,15 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, int bit_num, Ty
|
|||
}
|
||||
}
|
||||
|
||||
PostTrainingQuantizer::~PostTrainingQuantizer() {
|
||||
FullQuantQuantizer::~FullQuantQuantizer() {
|
||||
delete fp32_session_;
|
||||
delete fp32_model_;
|
||||
delete int8_session_;
|
||||
delete int8_model_;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
|
||||
const PrimitivePtr &primitive) const {
|
||||
STATUS FullQuantQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
|
||||
const PrimitivePtr &primitive) const {
|
||||
MS_ASSERT(max_min != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
|
@ -529,8 +481,8 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
|
||||
const PrimitivePtr &primitive) const {
|
||||
STATUS FullQuantQuantizer::DoQuantOutput(double scale, int zeropoint, struct MaxMin *max_min,
|
||||
const PrimitivePtr &primitive) const {
|
||||
MS_ASSERT(max_min != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto quant_param_holder = GetCNodeQuantHolder(primitive);
|
||||
|
@ -550,8 +502,8 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
|
||||
const PrimitivePtr &primitive, bool per_channel) const {
|
||||
STATUS FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
|
||||
const PrimitivePtr &primitive, bool per_channel) const {
|
||||
MS_ASSERT(weight != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
// perlayer
|
||||
|
@ -572,19 +524,9 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An
|
|||
auto bit_num_t = bit_num;
|
||||
auto quant_max_t = quant_max;
|
||||
auto quant_min_t = quant_min;
|
||||
if (calibrator_->full_quant_param_.mixed) {
|
||||
auto opname_iter = opname_bit_.find(op_name);
|
||||
if (opname_iter == opname_bit_.end()) {
|
||||
MS_LOG(WARNING) << op_name << " not in the opname_bit_ map";
|
||||
} else {
|
||||
bit_num_t = opname_iter->second;
|
||||
quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
|
||||
quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
|
||||
}
|
||||
}
|
||||
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
auto status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t, bit_num_t,
|
||||
weight_quant_type, kNumberTypeInt8);
|
||||
auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_PostTraining, quant_max_t, quant_min_t,
|
||||
bit_num_t, weight_quant_type, kNumberTypeInt8);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
return status;
|
||||
|
@ -608,7 +550,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(const std::string &op_name, const An
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive) {
|
||||
STATUS FullQuantQuantizer::DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive) {
|
||||
if (primitive == nullptr || bias == nullptr) {
|
||||
MS_LOG(ERROR) << "null pointer!";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -698,7 +640,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(const AnfNodePtr &bias, const Primitiv
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
|
||||
STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
|
@ -780,7 +722,7 @@ STATUS PostTrainingQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::QuantNode() {
|
||||
STATUS FullQuantQuantizer::QuantNode() {
|
||||
auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
|
||||
auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();
|
||||
|
||||
|
@ -879,7 +821,7 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::UpdateDivergInverval() {
|
||||
STATUS FullQuantQuantizer::UpdateDivergInverval() {
|
||||
this->calibrator_->UpdateDivergInverval(this->calibrator_->GetInputDivergInfo());
|
||||
this->calibrator_->UpdateDivergInverval(this->calibrator_->GetOutputDivergInfo());
|
||||
return RET_OK;
|
||||
|
@ -894,7 +836,7 @@ STATUS PostTrainingQuantizer::UpdateDivergInverval() {
|
|||
* 2.1 parse image files to input tensor
|
||||
* 3. save quantized node
|
||||
**/
|
||||
STATUS PostTrainingQuantizer::PreProcess() {
|
||||
STATUS FullQuantQuantizer::PreProcess() {
|
||||
// 3. collect to be quantized operators
|
||||
// from user input
|
||||
QuantStrategy strategy(kMillisecondsBase);
|
||||
|
@ -920,8 +862,8 @@ STATUS PostTrainingQuantizer::PreProcess() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::CheckFp32TensorVec(const std::string &node_name,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
|
||||
STATUS FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
|
||||
MS_ASSERT(tensor_vec != nullptr);
|
||||
if (tensor_vec.empty()) {
|
||||
MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0";
|
||||
|
@ -942,7 +884,7 @@ STATUS PostTrainingQuantizer::CheckFp32TensorVec(const std::string &node_name,
|
|||
* 2. insert callback to session
|
||||
* 3. run session
|
||||
**/
|
||||
STATUS PostTrainingQuantizer::DoInference() {
|
||||
STATUS FullQuantQuantizer::DoInference() {
|
||||
// get input tensor
|
||||
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
|
||||
if (inputs.size() != calibrator_->GetInputNum()) {
|
||||
|
@ -968,7 +910,7 @@ STATUS PostTrainingQuantizer::DoInference() {
|
|||
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
if ((*diverg_info_map)[callParam.node_name].size() == 1 &&
|
||||
|
@ -999,7 +941,7 @@ STATUS PostTrainingQuantizer::DoInference() {
|
|||
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
if ((*diverg_info_map)[callParam.node_name].size() == 1 && afterOutputs.size() > 1) {
|
||||
|
@ -1029,7 +971,7 @@ STATUS PostTrainingQuantizer::DoInference() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::Int8Inference() {
|
||||
STATUS FullQuantQuantizer::Int8Inference() {
|
||||
// int8 inference
|
||||
vector<mindspore::tensor::MSTensor *> inputs = int8_session_->GetInputs();
|
||||
for (auto input_tensor : inputs) {
|
||||
|
@ -1059,8 +1001,8 @@ STATUS PostTrainingQuantizer::Int8Inference() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
|
||||
std::future<STATUS> int8_inference = std::async(std::launch::async, &PostTrainingQuantizer::Int8Inference, this);
|
||||
STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
|
||||
std::future<STATUS> int8_inference = std::async(std::launch::async, &FullQuantQuantizer::Int8Inference, this);
|
||||
// get input tensor
|
||||
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
|
||||
if (inputs.size() != 1) {
|
||||
|
@ -1115,7 +1057,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
|
|||
return status;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
const auto &bias_diff = op_bias_diff_map[op_name];
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
|
@ -1196,7 +1138,7 @@ STATUS PostTrainingQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, con
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::CollectDataFrequency() {
|
||||
STATUS FullQuantQuantizer::CollectDataFrequency() {
|
||||
// get input tensor
|
||||
vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
|
||||
if (inputs.size() != calibrator_->GetInputNum()) {
|
||||
|
@ -1221,7 +1163,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
|
|||
if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
|
||||
|
@ -1243,7 +1185,7 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
|
|||
if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) {
|
||||
return true;
|
||||
}
|
||||
if (PostTrainingQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
int output_i = 0;
|
||||
|
@ -1267,9 +1209,9 @@ STATUS PostTrainingQuantizer::CollectDataFrequency() {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS PostTrainingQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
|
||||
STATUS FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
|
||||
|
||||
STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||
STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||
MS_LOG(INFO) << "start to parse config file";
|
||||
if (this->calibrator_ == nullptr) {
|
||||
MS_LOG(ERROR) << "calibrator is null!";
|
||||
|
@ -1278,22 +1220,6 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
calibrator_->full_quant_param_ = flags.fullQuantParam;
|
||||
calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
|
||||
STATUS status;
|
||||
if (calibrator_->full_quant_param_.mixed) { // get opname_bit map
|
||||
auto weight_quant_func_graph = CopyFuncGraph(func_graph);
|
||||
if (weight_quant_func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "CopyFuncGraph error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
WeightQuantizer weight_quantizer(weight_quant_func_graph, calibrator_->full_quant_param_);
|
||||
weight_quantizer.flags = flags;
|
||||
status = weight_quantizer.DoQuantize(weight_quant_func_graph);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do mix weight quant error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
opname_bit_ = weight_quantizer.opname_bit_;
|
||||
}
|
||||
|
||||
status = PreProcess();
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "do pre process failed!";
|
||||
|
@ -1366,7 +1292,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
|
||||
bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
|
||||
MS_ASSERT(data != nullptr);
|
||||
std::lock_guard<std::mutex> lg(mutex_op_input);
|
||||
if (type == STORE) {
|
||||
|
@ -1390,8 +1316,7 @@ bool PostTrainingQuantizer::OpInputDataHandle(OperationType type, const string &
|
|||
return false;
|
||||
}
|
||||
|
||||
bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name,
|
||||
std::vector<float> *data) {
|
||||
bool FullQuantQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
|
||||
MS_ASSERT(data != nullptr);
|
||||
std::lock_guard<std::mutex> lg(mutex_op_output);
|
||||
if (type == STORE) {
|
||||
|
@ -1415,14 +1340,14 @@ bool PostTrainingQuantizer::OpOutputChMeanDataHandle(OperationType type, const s
|
|||
return false;
|
||||
}
|
||||
|
||||
KernelCallBack PostTrainingQuantizer::GetBeforeCallBack(bool int8_op) {
|
||||
KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
|
||||
KernelCallBack before_call_back;
|
||||
if (!int8_op) {
|
||||
before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
|
||||
const CallBackParam &callParam) -> bool {
|
||||
if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
|
||||
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeInputs) != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
auto tensor = beforeInputs[0];
|
||||
|
@ -1495,7 +1420,7 @@ KernelCallBack PostTrainingQuantizer::GetBeforeCallBack(bool int8_op) {
|
|||
return before_call_back;
|
||||
}
|
||||
|
||||
KernelCallBack PostTrainingQuantizer::GetAfterCallBack(bool int8_op) {
|
||||
KernelCallBack FullQuantQuantizer::GetAfterCallBack(bool int8_op) {
|
||||
KernelCallBack after_call_back;
|
||||
if (!int8_op) {
|
||||
return GetFloatAfterCallBack();
|
||||
|
@ -1503,7 +1428,7 @@ KernelCallBack PostTrainingQuantizer::GetAfterCallBack(bool int8_op) {
|
|||
return GetInt8AfterCallBack();
|
||||
}
|
||||
|
||||
KernelCallBack PostTrainingQuantizer::GetInt8AfterCallBack() {
|
||||
KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
|
||||
KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
|
||||
const CallBackParam &callParam) -> bool {
|
||||
|
@ -1580,12 +1505,12 @@ KernelCallBack PostTrainingQuantizer::GetInt8AfterCallBack() {
|
|||
return after_call_back;
|
||||
}
|
||||
|
||||
KernelCallBack PostTrainingQuantizer::GetFloatAfterCallBack() {
|
||||
KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
|
||||
KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
|
||||
const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
|
||||
const CallBackParam &callParam) -> bool {
|
||||
if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
|
||||
if (PostTrainingQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
|
||||
if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
|
||||
return false;
|
||||
}
|
||||
auto tensor = afterOutputs[0];
|
|
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
#include <string>
#include <memory>
@@ -31,7 +31,6 @@
#include "tools/converter/converter.h"
#include "include/ms_tensor.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/preprocess/preprocess_param.h"
@@ -46,10 +45,10 @@ struct MaxMin {
constexpr int kDefaultBinNumber = 2048;
class PostTrainingQuantizer : public Quantizer {
class FullQuantQuantizer : public Quantizer {
public:
PostTrainingQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8, bool per_channel = true);
~PostTrainingQuantizer() override;
FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8, bool per_channel = true);
~FullQuantQuantizer() override;
STATUS DoQuantize(FuncGraphPtr func_graph) override;
@ -201,19 +200,9 @@ class Calibrator {
|
|||
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);
|
||||
|
||||
static STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);
|
||||
void Dump();
|
||||
|
||||
STATUS ComputeThreshold();
|
||||
|
||||
static std::unordered_map<CNodePtr, float> GetScale(
|
||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
||||
|
||||
static std::unordered_map<CNodePtr, int32_t> GetZeropoint(
|
||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
||||
|
||||
static std::map<CNodePtr, MaxMin> GetMinMax(
|
||||
std::unordered_map<std::string, std::unique_ptr<DivergInfo>> *diverg_info);
|
||||
|
||||
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetInputDivergInfo();
|
||||
|
||||
std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();
|
||||
|
@@ -232,4 +221,4 @@ class Calibrator {
int quant_min_;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_POSTRAINING_QUANTIZER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
@ -14,15 +14,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
|
||||
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
|
||||
#include <cmath>
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
// the error is currently measured per channel.
|
||||
// it could be measured per layer but it would be less good.
|
||||
// the `preferred` dim should point to the output channels dimension.
|
||||
float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
|
||||
float scale) {
|
||||
float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
|
||||
float scale) {
|
||||
MS_ASSERT(weights != nullptr);
|
||||
MS_ASSERT(shape != nullptr);
|
||||
int numel = 1;
|
||||
|
@ -61,13 +61,13 @@ float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int
|
|||
}
|
||||
variance_dequant = std::sqrt(variance_dequant / numel);
|
||||
variance_raw = std::sqrt(variance_raw / numel);
|
||||
var_corr = variance_raw / variance_dequant;
|
||||
mean_corr = average_raw - average_dequant * var_corr;
|
||||
var_corr_ = variance_raw / variance_dequant;
|
||||
mean_corr_ = average_raw - average_dequant * var_corr_;
|
||||
|
||||
for (int i = 0; i < numel; i++) {
|
||||
int bucket = (i / bucket_volume) % bucket_count;
|
||||
norms2[bucket] += weights[i] * weights[i];
|
||||
float dequant = var_corr * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr;
|
||||
float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr_;
|
||||
float d = weights[i] - dequant;
|
||||
dnorms2[bucket] += d * d;
|
||||
}
|
||||
|
@ -82,7 +82,7 @@ float FixBitWeightQuantizer::MeasureQuantizationError(float *weights, const int
|
|||
return t / (c + 1e-7);
|
||||
}
|
||||
|
||||
MinMax FixBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
|
||||
MinMax MixedBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
|
||||
MS_ASSERT(arr != nullptr);
|
||||
MinMax min_max = {INFINITY, -INFINITY};
|
||||
for (int i = 0; i < arrc; i++)
|
||||
|
@ -93,9 +93,9 @@ MinMax FixBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
|
|||
return min_max;
|
||||
}
|
||||
|
||||
BinarySearchResult FixBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
|
||||
int preferred_dim, int max_iters,
|
||||
float target_err, float rel_tol) {
|
||||
BinarySearchResult MixedBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
|
||||
int preferred_dim, int max_iters,
|
||||
float target_err, float rel_tol) {
|
||||
MS_ASSERT(weights != nullptr);
|
||||
MS_ASSERT(shape != nullptr);
|
||||
int element_num = 1;
|
||||
|
@ -136,9 +136,9 @@ BinarySearchResult FixBitWeightQuantizer::BinarySearchForQuantizationScale(float
|
|||
}
|
||||
}
|
||||
|
||||
int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
|
||||
std::vector<schema::QuantParamT> *quant_params,
|
||||
std::vector<int16_t> *quant_datas) {
|
||||
int MixedBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
|
||||
std::vector<schema::QuantParamT> *quant_params,
|
||||
std::vector<int16_t> *quant_datas) {
|
||||
MS_ASSERT(weights != nullptr);
|
||||
int weight_count = 1;
|
||||
int dims = shape.size();
|
||||
|
@ -148,8 +148,8 @@ int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> s
|
|||
input_shape[i] = shape[i];
|
||||
}
|
||||
|
||||
BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters,
|
||||
target_relative_err, target_search_tolerance);
|
||||
BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters_,
|
||||
target_relative_err_, target_search_tolerance_);
|
||||
if (br.status != 0) {
|
||||
MS_LOG(ERROR) << "reached_max_iters";
|
||||
return RET_ERROR;
|
||||
|
@ -166,15 +166,15 @@ int FixBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> s
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int FixBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
|
||||
schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
|
||||
int MixedBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
|
||||
schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
|
||||
MS_ASSERT(weights != nullptr);
|
||||
for (int i = 0; i < weightsc; i++) {
|
||||
auto q = static_cast<int>(floorf(weights[i] / scale + 0.5));
|
||||
quant_datas->at(i) = q;
|
||||
}
|
||||
quant_params->meanCorr = mean_corr;
|
||||
quant_params->varCorr = var_corr;
|
||||
quant_params->meanCorr = mean_corr_;
|
||||
quant_params->varCorr = var_corr_;
|
||||
quant_params->scale = scale;
|
||||
quant_params->zeroPoint = 0;
|
||||
quant_params->numBits = 0;
|
|
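In QuantizeByScale above, each weight is rounded with floorf(weights[i] / scale + 0.5), and the meanCorr/varCorr fields written to QuantParamT are the correction terms computed in MeasureQuantizationError. A two-line sketch of the round trip implied by those hunks (illustrative only; w is a single weight value):

auto q = static_cast<int>(floorf(w / scale + 0.5f));   // value stored in quant_datas
float w_hat = var_corr_ * (scale * q) + mean_corr_;    // reconstruction used when measuring quantization error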
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIX_BIT_WEIGHT_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIX_BIT_WEIGHT_QUANTIZER_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZER_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZER_H
#include <cstdint>
#include <vector>
#include <cmath>
@@ -34,14 +34,14 @@ typedef struct {
float max;
} MinMax;
class FixBitWeightQuantizer {
class MixedBitWeightQuantizer {
public:
explicit FixBitWeightQuantizer(float target_relative_err = 0.01, float target_search_tolerance = 0.01,
int max_search_iters = 100)
: target_relative_err(target_relative_err),
target_search_tolerance(target_search_tolerance),
max_search_iters(max_search_iters) {}
~FixBitWeightQuantizer() = default;
explicit MixedBitWeightQuantizer(float target_relative_err = 0.01, float target_search_tolerance = 0.01,
int max_search_iters = 100)
: target_relative_err_(target_relative_err),
target_search_tolerance_(target_search_tolerance),
max_search_iters_(max_search_iters) {}
~MixedBitWeightQuantizer() = default;
int DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
std::vector<schema::QuantParamT> *quant_params, std::vector<int16_t> *quant_datas);
@@ -58,11 +58,11 @@ class FixBitWeightQuantizer {
int max_iters, float target_err, float rel_tol);
private:
float var_corr{1};
float mean_corr{0};
float target_relative_err;
float target_search_tolerance;
int max_search_iters;
float var_corr_{1};
float mean_corr_{0};
float target_relative_err_;
float target_search_tolerance_;
int max_search_iters_;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIX_BIT_WEIGHT_QUANTIZER_H
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZER_H
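Given the constructor defaults shown above (target_relative_err = 0.01, target_search_tolerance = 0.01, max_search_iters = 100), a minimal usage sketch of the renamed class; the weight buffer and shape here are invented for illustration:

std::vector<int64_t> shape = {8, 16};              // hypothetical weight shape, preferred_dim = 0
std::vector<float> weights(8 * 16, 0.5f);          // hypothetical weight data
std::vector<schema::QuantParamT> quant_params;
std::vector<int16_t> quant_data(weights.size());
mindspore::lite::quant::MixedBitWeightQuantizer quantizer(0.02f);
int ret = quantizer.DoQuantization(weights.data(), shape, 0, &quant_params, &quant_data);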
@@ -42,8 +42,6 @@ struct FullQuantParam {
ActivationQuantizedMethod activation_quant_method = MAX_MIN;
bool bias_correction = true;
int thread_num = 1;
bool mixed = false;
float mean_error_threshold = 0.04;
};
} // namespace mindspore::lite::quant
@ -726,8 +726,8 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
|
|||
}
|
||||
}
|
||||
|
||||
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
|
||||
WeightQuantType weight_quant_type, TypeId quant_data_type, int index) {
|
||||
STATUS MixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
|
||||
WeightQuantType weight_quant_type, TypeId quant_data_type, double init_scale, int index) {
|
||||
MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
|
||||
MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
|
||||
auto dims = weight->shape();
|
||||
|
@ -748,13 +748,13 @@ STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitiv
|
|||
std::vector<int16_t> quant_data(elem_count);
|
||||
int ret = RET_OK;
|
||||
if (weight_quant_type == MIXED_BIT_PER_LAYER) {
|
||||
FixBitWeightQuantizer quantizer(0.02);
|
||||
MixedBitWeightQuantizer quantizer(init_scale);
|
||||
quantizer.DoQuantization(static_cast<float *>(weight->data_c()), weight->shape_c(), 0, &quant_params, &quant_data);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
|
||||
}
|
||||
auto status =
|
||||
UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), TypeId::kNumberTypeInt16);
|
||||
UpdateTensorDataAndSize(weight, quant_data.data(), quant_data.size() * sizeof(int16_t), quant_data_type);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -42,7 +42,7 @@
|
|||
#include "abstract/dshape.h"
|
||||
#include "tools/converter/quantizer/huffman_encode.h"
|
||||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
#include "tools/converter/quantizer/fix_bit_weight_quantizer.h"
|
||||
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
|
||||
#include "src/lite_session.h"
|
||||
#include "tools/converter/graphdef_transform.h"
|
||||
#include "src/common/file_utils.h"
|
||||
|
@ -147,13 +147,13 @@ STATUS DoBitPack(const tensor::TensorPtr &weight, const size_t &bit_num, const s
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
|
||||
WeightQuantType weight_quant_type, TypeId quant_data_type, int index = 1);
|
||||
STATUS MixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
|
||||
WeightQuantType weight_quant_type, TypeId quant_data_type, double init_scale, int index);
|
||||
|
||||
template <typename T>
|
||||
STATUS QuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type, int quant_max,
|
||||
int quant_min, size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type,
|
||||
int index = 1, bool k_means = false) {
|
||||
STATUS FixedBitQuantFilter(const tensor::TensorPtr &weight, const PrimitivePtr &primitive, QuantType quant_type,
|
||||
int quant_max, int quant_min, size_t bit_num, WeightQuantType weight_quant_type,
|
||||
TypeId quant_data_type, int index = 1, bool k_means = false) {
|
||||
MS_ASSERT(weight != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto dims = weight->shape();
|
||||
|
@@ -240,7 +240,5 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
int ConvertInputShapeMapToVector(FullQuantParam *config_param_, const std::vector<tensor::MSTensor *> &inputs,
std::vector<std::vector<int>> *shapes);
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANTIZE_UTIL_H_
@ -1,31 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "mindspore/lite/tools/converter/quantizer/quantizer.h"
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
STATUS Quantizer::GenerateQuantParam() { return RET_OK; }
|
||||
|
||||
STATUS Quantizer::RemoveFakeQuant() { return RET_OK; }
|
||||
|
||||
STATUS Quantizer::DetermineNodeQuantType() { return RET_OK; }
|
||||
|
||||
STATUS FbQuantizer::GenerateQuantParam() { return RET_OK; }
|
||||
|
||||
STATUS FbQuantizer::RemoveFakeQuant() { return RET_OK; }
|
||||
|
||||
STATUS FbQuantizer::DetermineNodeQuantType() { return RET_OK; }
|
||||
} // namespace mindspore::lite::quant
|
|
@ -36,12 +36,6 @@ class Quantizer {
|
|||
|
||||
virtual ~Quantizer() = default;
|
||||
|
||||
virtual STATUS RemoveFakeQuant();
|
||||
|
||||
virtual STATUS GenerateQuantParam();
|
||||
|
||||
virtual STATUS DetermineNodeQuantType();
|
||||
|
||||
virtual STATUS DoQuantize(FuncGraphPtr func_graph) = 0;
|
||||
|
||||
converter::Flags flags;
|
||||
|
@ -49,23 +43,5 @@ class Quantizer {
|
|||
protected:
|
||||
FuncGraphPtr funcGraph = nullptr;
|
||||
};
|
||||
|
||||
class FbQuantizer {
|
||||
public:
|
||||
explicit FbQuantizer(schema::MetaGraphT *graph) : graph(graph) {}
|
||||
|
||||
virtual ~FbQuantizer() = default;
|
||||
|
||||
virtual STATUS RemoveFakeQuant();
|
||||
|
||||
virtual STATUS GenerateQuantParam();
|
||||
|
||||
virtual STATUS DetermineNodeQuantType();
|
||||
|
||||
virtual STATUS DoQuantize() = 0;
|
||||
|
||||
protected:
|
||||
schema::MetaGraphT *graph = nullptr;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
|
|
|
@ -27,11 +27,6 @@ using std::string;
|
|||
using std::vector;
|
||||
|
||||
namespace mindspore::lite::quant {
|
||||
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const FullQuantParam &config) : Quantizer(std::move(graph)) {
|
||||
quant_strategy_ = std::make_unique<QuantStrategy>(0, 0);
|
||||
config_param_ = config;
|
||||
}
|
||||
|
||||
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(std::move(graph)) {
|
||||
auto quant_size = config.commonQuantParam.min_quant_weight_size;
|
||||
this->bit_num_ = config.commonQuantParam.bit_num;
|
||||
|
@ -39,17 +34,19 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &con
|
|||
type_id_ = kNumberTypeInt16;
|
||||
this->is_mixed_bit_ = true;
|
||||
}
|
||||
auto convQuantWeightChannelThreshold = config.commonQuantParam.min_quant_weight_channel;
|
||||
quant_strategy_ = std::make_unique<QuantStrategy>(quant_size, convQuantWeightChannelThreshold);
|
||||
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
|
||||
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
|
||||
// parse type_id_
|
||||
if (this->bit_num_ > 0 && this->bit_num_ <= kMaxBit) {
|
||||
type_id_ = kNumberTypeInt8;
|
||||
} else if (this->bit_num_ <= (kMaxBit * 2)) {
|
||||
type_id_ = kNumberTypeInt16;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "invalid input bits";
|
||||
quant_strategy_ = std::make_unique<QuantStrategy>(quant_size, config.commonQuantParam.min_quant_weight_channel);
|
||||
// parse param for fixed bit quant.
|
||||
if (!this->is_mixed_bit_) {
|
||||
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
|
||||
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
|
||||
// parse type_id_
|
||||
if (this->bit_num_ > 0 && this->bit_num_ <= kMaxBit) {
|
||||
type_id_ = kNumberTypeInt8;
|
||||
} else if (this->bit_num_ <= (kMaxBit * 2)) {
|
||||
type_id_ = kNumberTypeInt16;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "invalid input bits";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
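The constructor hunk above now derives quant_max_, quant_min_ and type_id_ only on the fixed-bit path; when mixed-bit quantization is selected the quantizer keeps int16 storage and skips the range setup. A compact restatement of the assumed mapping (kMaxBit taken to be 8; sketch, not the converter source):

// bit_num == 0            -> is_mixed_bit_ = true, type_id_ = kNumberTypeInt16 (mixed-bit weight quant)
// 1 <= bit_num <= kMaxBit -> type_id_ = kNumberTypeInt8,  range [-(1 << (bit_num - 1)), (1 << (bit_num - 1)) - 1]
// kMaxBit < bit_num <= 16 -> type_id_ = kNumberTypeInt16, same range formula
// anything larger         -> "invalid input bits" error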
@ -116,18 +113,19 @@ STATUS WeightQuantizer::DoConvQuantize(const CNodePtr &cnode) {
|
|||
auto status = RET_ERROR;
|
||||
if (is_mixed_bit_) {
|
||||
type_id_ = kNumberTypeInt16;
|
||||
status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_);
|
||||
status = MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
|
||||
type_id_, flags.mixedBitWeightQuantParam.init_scale, 1);
|
||||
} else if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
return RET_OK;
|
||||
} else if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(tensor_info, param_node, primitive);
|
||||
|
@ -160,19 +158,21 @@ STATUS WeightQuantizer::DoMulQuantize(const CNodePtr &cnode) {
|
|||
weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
}
|
||||
if (is_mixed_bit_) {
|
||||
status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
|
||||
type_id_, i - 1);
|
||||
status =
|
||||
MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
|
||||
type_id_, flags.mixedBitWeightQuantParam.init_scale, i - 1);
|
||||
} else if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, weight_quant_type, type_id_, i - 1);
|
||||
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, weight_quant_type, type_id_, i - 1);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, weight_quant_type, type_id_, i - 1);
|
||||
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, weight_quant_type, type_id_, i - 1);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
continue;
|
||||
} else if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << i << " QuantFilter failed : " << status;
|
||||
MS_LOG(ERROR) << cnode->fullname_with_scope() << " input " << i
|
||||
<< " MixedBitQuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(tensor_info, param_node, primitive);
|
||||
|
@ -241,19 +241,19 @@ STATUS WeightQuantizer::DoGatherQuantize(const CNodePtr &cnode) {
|
|||
|
||||
auto status = RET_ERROR;
|
||||
if (is_mixed_bit_) {
|
||||
status =
|
||||
QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, 0);
|
||||
status = MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
|
||||
type_id_, flags.mixedBitWeightQuantParam.init_scale, 0);
|
||||
} else if (type_id_ == kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
|
||||
status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
|
||||
status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
|
||||
bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, 0);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
return RET_OK;
|
||||
} else if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(tensor_info, param_node, primitive);
|
||||
|
@ -293,14 +293,14 @@ STATUS WeightQuantizer::DoOptimizerQuantize(const CNodePtr &cnode) {

    auto status = RET_ERROR;
    if (type_id_ == kNumberTypeInt8) {
      status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                   WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
      status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
                                           bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
    } else if (type_id_ == kNumberTypeInt16) {
      status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                    WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
      status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
                                            bit_num_, WeightQuantType::FIXED_BIT_PER_LAYER, type_id_, idx - 1);
    }
    if (status != RET_OK && status != RET_CONTINUE) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
      return status;
    }
    status = SetAbstract(tensor_info, param_node, primitive);
@ -367,19 +367,19 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
  }
  auto status = RET_ERROR;
  if (is_mixed_bit_) {
    status = QuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER, type_id_,
                         index - 1);
    status = MixedBitQuantFilter(tensor_info, primitive, QuantType_WeightQuant, WeightQuantType::MIXED_BIT_PER_LAYER,
                                 type_id_, flags.mixedBitWeightQuantParam.init_scale, index - 1);
  } else if (type_id_ == kNumberTypeInt8) {
    status = QuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                 WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
    status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
                                         bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
  } else if (type_id_ == kNumberTypeInt16) {
    status = QuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
                                  WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
    status = FixedBitQuantFilter<int16_t>(tensor_info, primitive, QuantType_WeightQuant, quant_max_, quant_min_,
                                          bit_num_, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_, index - 1);
  }
  if (status == RET_CONTINUE) {
    return RET_OK;
  } else if (status != RET_OK) {
    MS_LOG(ERROR) << "QuantFilter failed : " << status;
    MS_LOG(ERROR) << "MixedBitQuantFilter failed : " << status;
    return status;
  }
  status = SetAbstract(tensor_info, param_node, primitive);
@ -390,347 +390,24 @@ STATUS WeightQuantizer::ProcessLstmWeightByIndex(const CNodePtr &cnode, const Pr
  return RET_OK;
}

constexpr float relative_tolerance = 1e-5;
constexpr float abs_tolerance = 1e-4;

template <typename T>
float CompareOutputData(const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &expected_tensor,
                        const std::unordered_map<std::string, mindspore::tensor::MSTensor *> &compare_tensor) {
  auto valid_data = [](T data) -> bool { return (!std::isnan(data) && !std::isinf(data)); };

  float total_mean_error = 0.0f;
  int tensor_cnt = expected_tensor.size();
  if (tensor_cnt <= 0) {
    MS_LOG(ERROR) << "unexpected tensor_cnt: " << tensor_cnt;
    return RET_ERROR;
  }

  for (const auto &exp_tensor_pair : expected_tensor) {
    float mean_error = 0.0f;
    int error_cnt = 0;

    auto exp_tensor_name = exp_tensor_pair.first;
    auto exp_tensor = exp_tensor_pair.second;
    auto cmp_tensor_find_iter = compare_tensor.find(exp_tensor_name);
    if (cmp_tensor_find_iter == compare_tensor.end()) {
      MS_LOG(ERROR) << "can not find: " << exp_tensor_name;
      return RET_ERROR;
    }
    auto cmp_tensor = cmp_tensor_find_iter->second;

    auto exp_tensor_shape = exp_tensor->shape();
    auto cmp_tensor_shape = cmp_tensor->shape();
    if (exp_tensor_shape != cmp_tensor_shape) {
      MS_LOG(ERROR) << "exp tensor shape not equal to cmp. exp_tensor_elem_cnt: " << exp_tensor->ElementsNum()
                    << " cmp_tensor_elem_cnt: " << cmp_tensor->ElementsNum();
      return RET_ERROR;
    }
    auto exp_data = static_cast<T *>(exp_tensor->MutableData());
    auto cmp_data = static_cast<T *>(cmp_tensor->MutableData());
    auto elem_cnt = exp_tensor->ElementsNum();
    for (int i = 0; i < elem_cnt; i++) {
      if (!valid_data(exp_data[i]) || !valid_data(cmp_data[i])) {
        MS_LOG(ERROR) << "data is not valid. exp: " << exp_data[i] << " cmp: " << cmp_data[i] << " index: " << i;
        return RET_ERROR;
      }
      auto tolerance = abs_tolerance + relative_tolerance * fabs(exp_data[i]);
      auto abs_error = std::fabs(exp_data[i] - cmp_data[i]);
      if (abs_error > tolerance) {
        if (exp_data[i] == 0) {
          if (abs_error > 1e-5) {
            mean_error += abs_error;
            error_cnt++;
          } else {
            // it is ok, very close to 0
            continue;
          }
        } else {
          mean_error += abs_error / (fabs(exp_data[i]) + FLT_MIN);
          error_cnt++;
        }
      } else {
        // it is ok, no error
        continue;
      }
    }  // end one tensor data loop
    total_mean_error += mean_error / elem_cnt;
  }  // end tensor loop
  return total_mean_error / tensor_cnt;
}
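The deleted CompareOutputData helper above scores a quantized graph against the fp32 reference: an element only counts as an error when |expected - actual| exceeds abs_tolerance + relative_tolerance * |expected|, and the contribution is normalized by |expected| (with a small guard when the expected value is exactly zero). A self-contained sketch of that accumulation for one pair of float buffers; the MeanRelativeError name and the vector-based signature are illustrative, the tolerance constants mirror the ones above:

#include <cfloat>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Mirrors the tolerance rule of the deleted CompareOutputData helper:
// an element counts as an error only when it exceeds a mixed absolute /
// relative tolerance, and the error is normalized by the expected magnitude.
float MeanRelativeError(const std::vector<float> &expected, const std::vector<float> &actual) {
  constexpr float kRelativeTolerance = 1e-5f;
  constexpr float kAbsTolerance = 1e-4f;
  float mean_error = 0.0f;
  for (size_t i = 0; i < expected.size() && i < actual.size(); ++i) {
    float tolerance = kAbsTolerance + kRelativeTolerance * std::fabs(expected[i]);
    float abs_error = std::fabs(expected[i] - actual[i]);
    if (abs_error <= tolerance) {
      continue;  // within tolerance: contributes no error
    }
    if (expected[i] == 0.0f) {
      if (abs_error > 1e-5f) {  // deviations very close to zero are still accepted
        mean_error += abs_error;
      }
    } else {
      mean_error += abs_error / (std::fabs(expected[i]) + FLT_MIN);
    }
  }
  return expected.empty() ? 0.0f : mean_error / expected.size();
}

int main() {
  std::vector<float> expected = {1.0f, 0.0f, -2.0f};
  std::vector<float> actual = {1.1f, 0.0f, -2.0f};
  std::cout << MeanRelativeError(expected, actual) << std::endl;  // ~0.0333: one element off by 10%
  return 0;
}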
|
STATUS WeightQuantizer::RunFp32Graph(const FuncGraphPtr &func_graph) {
  auto image_cnt = images_.at(0).size();
  // 0.1 Create Fp32 Session
  flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
  auto fp32_sm = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
  auto fp32_session = fp32_sm.session;
  auto fp32_model = fp32_sm.model;
  if (fp32_session == nullptr || fp32_model == nullptr) {
    MS_LOG(ERROR) << "CreateSession fail";
    delete fp32_model;
    return RET_ERROR;
  }
  auto fp32_inputs = fp32_session->GetInputs();
  fp32_output_tensors_.resize(image_cnt);
  // 0.3 save fp32 output
  for (size_t i = 0; i < image_cnt; i++) {
    for (size_t input_index = 0; input_index < fp32_inputs.size(); input_index++) {
      auto status = preprocess::PreProcess(flags.dataPreProcessParam, fp32_inputs[input_index]->tensor_name(), i,
                                           fp32_inputs[input_index]);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "generate input data from images failed!";
        delete fp32_sm.session;
        delete fp32_sm.model;
        return RET_ERROR;
      }
    }
    auto status = fp32_session->RunGraph();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "RunGraph fail";
      delete fp32_sm.session;
      delete fp32_sm.model;
      return RET_ERROR;
    }
    auto fp32_outputs = fp32_session->GetOutputs();
    for (const auto &kv : fp32_outputs) {
      auto *tensor = kv.second;
      auto *lite_tensor = reinterpret_cast<lite::Tensor *>(tensor);
      if (lite_tensor == nullptr) {
        MS_LOG(ERROR) << "not lite tensor";
        delete fp32_sm.session;
        delete fp32_sm.model;
        return RET_ERROR;
      }
      auto *new_tensor = Tensor::CopyTensor(*lite_tensor, true);
      fp32_output_tensors_[i][kv.first] = new_tensor;
    }
  }
  delete fp32_sm.session;
  delete fp32_sm.model;
  return RET_OK;
}
|
STATUS WeightQuantizer::DoMixedQuantize(const FuncGraphPtr &func_graph) {
STATUS WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  auto cnodes = func_graph->GetOrderedCnodes();
  int status = RET_OK;
  for (auto &cnode : cnodes) {
    if (opt::CheckPrimitiveType(cnode, prim::kPrimLstm)) {
      status = DoLstmQuantize(cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "DoLstmQuantize error";
        return RET_ERROR;
      }
    } else if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
      status = DoGatherQuantize(cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "DoGatherQuantize error";
        return RET_ERROR;
      }
    }
  }
  return status;
}
|
STATUS WeightQuantizer::GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
                                             ParameterPtr *param_node, tensor::TensorPtr *tensor_info) {
  if (!input_node->isa<Parameter>()) {
    MS_LOG(WARNING) << op_name << " the second input is not parameter";
    return RET_CONTINUE;
  }
  *param_node = input_node->cast<ParameterPtr>();
  if (!(*param_node)->has_default()) {
    MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
    return RET_CONTINUE;
  }
  *tensor_info = std::static_pointer_cast<tensor::Tensor>((*param_node)->default_param());
  if (*tensor_info == nullptr) {
    MS_LOG(WARNING) << op_name << " the second input can not convert to parameter";
    return RET_CONTINUE;
  }
  if ((*tensor_info)->data_type() != TypeId::kNumberTypeFloat32) {
    MS_LOG(WARNING) << op_name << " the second input type is not float";
    return RET_CONTINUE;
  }
  return RET_OK;
}
STATUS WeightQuantizer::TryQuant(const int &bit_num_t, const ParameterPtr &param_node,
                                 const tensor::TensorPtr &tensor_info, const PrimitivePtr &primitive) {
  MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
  MS_CHECK_TRUE_RET(tensor_info != nullptr, RET_NULL_PTR);
  MS_CHECK_TRUE_RET(param_node != nullptr, RET_NULL_PTR);
  int status;
  type_id_ = TypeId::kNumberTypeInt8;
  int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
  int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));

  if (type_id_ == TypeId::kNumberTypeInt8) {
    status = QuantFilter<int8_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
                                 bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
  } else if (type_id_ == TypeId::kNumberTypeInt16) {
    status = QuantFilter<int16_t>(tensor_info, primitive, QuantType::QuantType_WeightQuant, quant_max_t, quant_min_t,
                                  bit_num_t, WeightQuantType::FIXED_BIT_PER_CHANNEL, type_id_);
  } else {
    MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
    return RET_ERROR;
  }
  if (status == RET_CONTINUE) {
    return RET_OK;
  } else if (status != RET_OK) {
    MS_LOG(ERROR) << "quant filter failed.";
    return RET_ERROR;
  }
  status = SetAbstract(tensor_info, param_node, primitive);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "SetAbstract failed : " << status;
    return RET_ERROR;
  }
  return status;
}
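TryQuant above re-quantizes one weight tensor at a candidate bit width through QuantFilter<int8_t> with FIXED_BIT_PER_CHANNEL, so the search below can measure how much accuracy each bit width costs. A rough standalone sketch of the kind of symmetric per-tensor round trip such a trial performs; the max-abs scale choice and the RoundTripError helper are illustrative simplifications, not the converter's exact per-channel algorithm:

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

// Symmetric round trip: quantize float weights into [quant_min, quant_max]
// for a candidate bit width, dequantize, and report the worst absolute error.
float RoundTripError(const std::vector<float> &weights, unsigned int bit_num) {
  const int quant_max = (1 << (bit_num - 1)) - 1;
  const int quant_min = -(1 << (bit_num - 1));
  float max_abs = 0.0f;
  for (float w : weights) max_abs = std::max(max_abs, std::fabs(w));
  if (max_abs == 0.0f) return 0.0f;
  const float scale = max_abs / quant_max;

  float max_err = 0.0f;
  for (float w : weights) {
    int q = static_cast<int>(std::round(w / scale));
    q = std::min(std::max(q, quant_min), quant_max);  // clamp to the signed range
    float dequant = q * scale;
    max_err = std::max(max_err, std::fabs(w - dequant));
  }
  return max_err;
}

int main() {
  std::vector<float> weights = {0.02f, -0.37f, 0.91f, -1.24f, 0.55f};
  for (unsigned int bits = 2; bits <= 8; ++bits) {
    std::cout << bits << " bits -> max abs error " << RoundTripError(weights, bits) << std::endl;
  }
  return 0;
}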
|
STATUS WeightQuantizer::EvaluateQuant(const FuncGraphPtr &func_graph, size_t image_cnt, float *mean_error) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  if (mean_error == nullptr) {
    MS_LOG(ERROR) << "mean_error is nullptr.";
    return RET_ERROR;
  }
  // 2.1 create quant session, get input, output tensor
  flags.commonQuantParam.quant_type = schema::QuantType_WeightQuant;
  auto quant_sm = CreateSessionByFuncGraph(func_graph, flags, config_param_.thread_num);
  auto quant_session = std::unique_ptr<session::LiteSession>(quant_sm.session);
  int status;
  if (quant_session == nullptr) {
    MS_LOG(ERROR) << "create session error.";
    delete quant_sm.model;
    return RET_ERROR;
  }
  auto quant_inputs = quant_session->GetInputs();

  for (size_t i = 0; i < image_cnt; i++) {
    // set multi-input data
    for (size_t input_index = 0; input_index < quant_inputs.size(); input_index++) {
      status = preprocess::PreProcess(flags.dataPreProcessParam, quant_inputs[input_index]->tensor_name(), i,
                                      quant_inputs[input_index]);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "generate input data from images failed!";
        delete quant_sm.model;
        return RET_ERROR;
      }
    }
    status = quant_session->RunGraph();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "quant session run error";
      delete quant_sm.model;
      return RET_ERROR;
    }
    // 3. compare between quant and fp32
    auto quant_outputs = quant_session->GetOutputs();
    (*mean_error) += CompareOutputData<float>(fp32_output_tensors_[i], quant_outputs);
  }  // end_for: calib data loop
  delete quant_sm.model;
  (*mean_error) = (*mean_error) / image_cnt;
  return RET_OK;
}
|
STATUS WeightQuantizer::DoQuantSearch(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  auto cnodes = func_graph->GetOrderedCnodes();
  size_t image_cnt = images_.at(0).size();
  int status = RET_OK;
  for (auto iter = cnodes.end(); iter != cnodes.begin();) {
    auto cnode = *(--iter);
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "primitive is null.";
      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
      continue;
    }
    auto status = DoMarkWeightQuantizeIfQuantized(cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "MarkWeightQuantizationInNodes error marking " << cnode->fullname_with_scope();
      return RET_ERROR;
    }
    auto op_name = cnode->fullname_with_scope();
    MS_LOG(DEBUG) << "process node: " << op_name << " type: " << primitive->name();
    if (quant_strategy_->CanConvOpQuantized(cnode) || quant_strategy_->CanMulOpQuantized(cnode)) {
      auto input_node = cnode->input(2);
      ParameterPtr param_node;
      tensor::TensorPtr tensor_info;
      status = GetParamNodeAndValue(input_node, op_name, &param_node, &tensor_info);
      if (status == RET_CONTINUE) {
        continue;
      }
      // copy origin data in case to recover
      auto *raw_data = static_cast<float *>(tensor_info->data_c());
      auto elem_count = tensor_info->DataSize();
      auto origin_data = std::make_unique<float[]>(elem_count);
      MS_CHECK_TRUE_RET(origin_data != nullptr, RET_NULL_PTR);
      auto ret = memcpy_s(origin_data.get(), sizeof(float) * elem_count, raw_data, tensor_info->Size());
      if (ret != EOK) {
        MS_LOG(ERROR) << "memcpy fail: "
                      << " dst size: " << sizeof(float) * elem_count << " src size: " << tensor_info->Size();
        return RET_ERROR;
      }
      // 1. try quant
      for (size_t bit_num_t = 2; bit_num_t <= kMaxBit; bit_num_t++) {
        status = TryQuant(bit_num_t, param_node, tensor_info, primitive);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "TryQuant failed.";
          return RET_ERROR;
        }
        // 2. evaluate the quant
        float mean_error = 0.0f;
        status = EvaluateQuant(func_graph, image_cnt, &mean_error);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "EvaluateQuant failed.";
          return RET_ERROR;
        }
        if (mean_error <= config_param_.mean_error_threshold) {
          MS_LOG(DEBUG) << "op: " << op_name << " got mixed bit: " << bit_num_t << " mean_error: " << mean_error;
          opname_bit_[op_name] = bit_num_t;
          break;
        } else if (bit_num_t != kMaxBit) {
          MS_LOG(DEBUG) << "op: " << op_name << " intermediate bit: " << bit_num_t << " mean_error: " << mean_error
                        << " [recover]";
          // recover
          status =
            UpdateTensorDataAndSize(tensor_info, origin_data.get(), sizeof(float) * elem_count, kNumberTypeFloat32);
          if (status != RET_OK) {
            MS_LOG(ERROR) << "UpdateTensorDataAndSize fail";
            return RET_ERROR;
          }
        } else {
          MS_LOG(DEBUG) << "op: " << op_name << " set bit: " << bit_num_t << " mean_error: " << mean_error;
          opname_bit_[op_name] = bit_num_t;
        }
      }  // end bit loop
    }  // if: conv and matmul
  }  // end loop: all cnode
  return status;
}
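DoQuantSearch above is the heart of the deleted mixed-bit search: for each conv/matmul weight it tries bit widths from 2 up to kMaxBit, runs the calibration images through the quantized graph, accepts the first width whose mean error stays under mean_error_threshold, and otherwise restores the fp32 weights before the next attempt. A compact sketch of just that control flow; quantize_at, evaluate and restore are hypothetical stand-ins for TryQuant, EvaluateQuant and UpdateTensorDataAndSize:

#include <functional>
#include <iostream>

// Control flow mirroring the deleted DoQuantSearch bit-width search:
// try increasing bit widths, keep the first one whose calibration error
// stays under the threshold, otherwise restore fp32 weights and retry.
int SearchBitWidth(int max_bit, float mean_error_threshold,
                   const std::function<bool(int)> &quantize_at,
                   const std::function<float()> &evaluate,
                   const std::function<void()> &restore) {
  for (int bit = 2; bit <= max_bit; ++bit) {
    if (!quantize_at(bit)) return -1;  // quantization itself failed
    float mean_error = evaluate();     // run calibration data, compare with fp32 outputs
    if (mean_error <= mean_error_threshold) {
      return bit;                      // first bit width that is accurate enough
    }
    if (bit != max_bit) {
      restore();                       // recover fp32 weights before the next trial
    }
  }
  return max_bit;                      // fall back to the widest bit width
}

int main() {
  // Toy stand-ins: the error halves with each extra bit.
  float error = 0.2f;
  int chosen = SearchBitWidth(
    8, 0.02f, [](int) { return true; }, [&]() { return error /= 2.0f; }, [] {});
  std::cout << "chosen bit width: " << chosen << std::endl;  // 5 (0.1, 0.05, 0.025, 0.0125)
  return 0;
}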
|
STATUS WeightQuantizer::DoMixedQuant(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  auto status = RunFp32Graph(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "RunFp32Graph failed.";
    return RET_ERROR;
  }

  status = DoMixedQuantize(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoMixedQuantize failed.";
    return RET_ERROR;
  }

  status = DoQuantSearch(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "DoQuantSearch failed.";
    return RET_ERROR;
  }

  for (const auto &kv : opname_bit_) {
    MS_LOG(INFO) << "op: " << kv.first << " bit:" << kv.second;
  }
  return RET_OK;
}

STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  weight_quantized_tensors_.clear();

@ -779,35 +456,4 @@ STATUS WeightQuantizer::DoFixedQuant(const FuncGraphPtr &func_graph) {
  }
  return MarkWeightQuantizationInNodes(func_graph);
}

STATUS WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
      continue;
    }
    auto status = DoMarkWeightQuantizeIfQuantized(cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "MarkWeightQuantizationInNodes error marking " << cnode->fullname_with_scope();
      return RET_ERROR;
    }
  }
  return RET_OK;
}

STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  if (config_param_.mixed) {
    bit_num_ = kMaxBit;
    quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
    quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
    type_id_ = kNumberTypeInt8;
    MS_LOG(INFO) << "Do mixed bit quantization";
    return DoMixedQuant(func_graph);
  }

  return DoFixedQuant(func_graph);
}
}  // namespace mindspore::lite::quant

@ -39,7 +39,6 @@ namespace mindspore::lite::quant {
class WeightQuantizer : public Quantizer {
 public:
  WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
  WeightQuantizer(FuncGraphPtr graph, const FullQuantParam &config);
  ~WeightQuantizer() override;

  STATUS DoQuantize(FuncGraphPtr func_graph) override;
@ -54,33 +53,18 @@ class WeightQuantizer : public Quantizer {
  int quant_max_{127};
  int quant_min_{-128};
  TypeId type_id_{kNumberTypeInt8};
  std::map<std::string, int> opname_bit_;

 private:
  std::unique_ptr<QuantStrategy> quant_strategy_;
  size_t bit_num_{8};
  std::map<tensor::TensorPtr, ParameterPtr> weight_quantized_tensors_;
  FullQuantParam config_param_;
  std::map<std::string, std::vector<std::string>> images_;  // multi_input, [[model_input_0], [model_input_1]...]
  std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
  bool is_mixed_bit_ = false;

  STATUS DoMixedQuant(const FuncGraphPtr &);
  STATUS SetAbstract(const tensor::TensorPtr &tensor_info, const ParameterPtr &param_node,
                     const PrimitivePtr &primitive);
  STATUS DoFixedQuant(const FuncGraphPtr &);
  STATUS MarkWeightQuantizationInNodes(const FuncGraphPtr &);
  STATUS DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
  STATUS RunFp32Graph(const FuncGraphPtr &);

  STATUS DoMixedQuantize(const FuncGraphPtr &func_graph);
  STATUS CheckImageCnt();
  static STATUS GetParamNodeAndValue(const std::shared_ptr<AnfNode> &input_node, const std::string &op_name,
                                     ParameterPtr *param_node, tensor::TensorPtr *tensor_info);
  STATUS TryQuant(const int &bit_num_t, const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info,
                  const PrimitivePtr &primitive);
  STATUS DoQuantSearch(const FuncGraphPtr &func_graph);
  STATUS EvaluateQuant(const FuncGraphPtr &func_graph, size_t image_cnt, float *mean_error);
};
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_WEIGHT_QUANTIZER_H_