Optimize the Quantizer base class and the min_quant_weight_channel strategy

yeyunpeng2020 2021-11-25 10:54:16 +08:00
parent 35e0526b51
commit 8aaff90f7e
20 changed files with 103 additions and 123 deletions

View File

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 [data_preprocess_param]
 calibrate_path=blob1:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_face_landmark_2_calibration_data

View File

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 [data_preprocess_param]
 calibrate_path=blob1:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_face_mnet_image

View File

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 [data_preprocess_param]
 calibrate_path=blob1:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_face_mnet_calibration_data

View File

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 [data_preprocess_param]
 calibrate_path=conv2d_input:/home/workspace/mindspore_dataset/mslite/quantTraining/mnist_calibration_data

View File

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 [data_preprocess_param]
 calibrate_path=decoder_buffer_in_0:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_0,decoder_buffer_in_1:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_1,decoder_buffer_in_2:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_2,decoder_buffer_in_3:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_3,decoder_buffer_in_4:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_4,decoder_buffer_in_5:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_5,decoder_buffer_in_6:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_6,decoder_buffer_in_7:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_7,decoder_buffer_in_8:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_8,joint_in_encoder_deploy:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_9,decoder_in_deploy:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_10

View File

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 [data_preprocess_param]
 calibrate_path=buffer_in_extreme_0:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_0,buffer_in_extreme_1:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_1,buffer_in_extreme_2:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_2,buffer_in_extreme_3:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_3,buffer_in_extreme_4:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_4,buffer_in_extreme_5:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_5,buffer_in_extreme_6:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_6,buffer_in_extreme_7:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_7,buffer_in_extreme_8:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_8,buffer_in_extreme_9:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_9,buffer_in_extreme_10:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_10,buffer_in_extreme_11:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_11,buffer_in_extreme_12:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_12,buffer_in_extreme_13:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_13,buffer_in_extreme_14:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_14,buffer_in_extreme_15:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_15,buffer_in_extreme_16:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_16,buffer_in_extreme_17:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_17,buffer_in_extreme_18:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_18,buffer_in_extreme_19:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_19,buffer_in_extreme_20:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_20,buffer_in_extreme_21:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_21,buffer_in_extreme_22:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_22,buffer_in_extreme_23:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_23,buffer_in_extreme_24:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_24,buffer_in_extreme_25:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_25,buffer_in_extreme_26:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_26,buffer_in_extreme_27:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_27,buffer_in_extreme_28:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_28,buffer_in_extreme_29:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_29,buffer_in_extreme_30:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_30,buffer_in_extreme_31:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_31,buffer_in_extreme_32:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_32,buffer_in_extreme_33:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_33,buffer_in_extreme_34:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_34,encoder_in_deploy:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_35

View File

@@ -19,6 +19,7 @@
 #include <string>
 #include <unordered_map>
 #include <deque>
+#include <map>
 #include "nnacl/op_base.h"
 #include "src/common/log_adapter.h"
 #include "tools/converter/optimizer_manager.h"
@@ -363,29 +364,29 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
 }
 int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
-  constexpr int thread_num = 2;
   // quant
   if (config->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL &&
       config->commonQuantParam.quant_type != schema::QuantType_QUANT_WEIGHT) {
     return RET_OK;
   }
   int status;
+  std::unique_ptr<quant::Quantizer> quantizer;
   quant::SessionModel origin;
   quant::SessionModel quant;
   if (config->commonQuantParam.is_debug) {
     converter::Flags new_flag = *config;
     new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
-    origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, thread_num);
+    origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, config->commonQuantParam.thread_num);
   }
   if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
-    this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
-    if (m_quantizer_ == nullptr) {
+    quantizer = std::make_unique<quant::FullQuantQuantizer>(*config);
+    if (quantizer == nullptr) {
       MS_LOG(ERROR) << "New FullQuantQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
       return RET_ERROR;
     }
-    m_quantizer_->flags = *config;
-    status = m_quantizer_->DoQuantize(old_graph);
+    status = quantizer->DoQuantize(old_graph);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "DoQuantization failed " << status;
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -400,38 +401,35 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
       MS_LOG(ERROR) << "Grid search with scale failed.";
       return status;
     }
-    this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config);
-    if (m_quantizer_ == nullptr) {
+    quantizer = std::make_unique<quant::WeightQuantizer>(*config);
+    if (quantizer == nullptr) {
       MS_LOG(ERROR) << "New WeightQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
       return RET_ERROR;
     }
-    status = static_cast<quant::WeightQuantizer *>(m_quantizer_.get())->DoQuantize(old_graph, init_scale);
+    status = static_cast<quant::WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "DoQuantization failed " << status;
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
       return RET_ERROR;
     }
   } else {
-    this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config);
-    if (m_quantizer_ == nullptr) {
+    quantizer = std::make_unique<quant::WeightQuantizer>(*config);
+    if (quantizer == nullptr) {
       MS_LOG(ERROR) << "New WeightQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
       return RET_ERROR;
     }
-    if (m_quantizer_ != nullptr) {
-      m_quantizer_->flags = *config;
-      status = m_quantizer_->DoQuantize(old_graph);
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "DoQuantization failed " << status;
-        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-        return RET_ERROR;
-      }
-    }
+    status = quantizer->DoQuantize(old_graph);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "DoQuantization failed " << status;
+      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+      return RET_ERROR;
+    }
   }
   if (config->commonQuantParam.is_debug) {
-    quant = quant::CreateSessionByFuncGraph(old_graph, *config, thread_num);
+    quant = quant::CreateSessionByFuncGraph(old_graph, *config, config->commonQuantParam.thread_num);
     std::map<std::string, OpParameter *> op_parameters;
     FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
     DebugInfoManager manager;
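
Net effect of this function after the commit: the long-lived m_quantizer_ member becomes a function-local std::unique_ptr, and every quantizer receives its configuration through the constructor rather than a post-construction flags assignment. A condensed sketch of the new pattern (error paths elided; not verbatim):

  // Condensed sketch of the new ownership model in DoSingleGraphQuantize.
  std::unique_ptr<quant::Quantizer> quantizer;
  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
    quantizer = std::make_unique<quant::FullQuantQuantizer>(*config);  // flags_ set by the base constructor
  } else {
    quantizer = std::make_unique<quant::WeightQuantizer>(*config);
  }
  int status = quantizer->DoQuantize(old_graph);  // graph passed explicitly, no cached state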

View File

@@ -37,8 +37,6 @@ class AnfTransform {
   FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
  private:
-  std::unique_ptr<quant::Quantizer> m_quantizer_ = nullptr;
   FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
   static int RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

View File

@@ -4,10 +4,6 @@ quant_type=FULL_QUANT
 # Weight quantization support the number of bits [0,16], Set to 0 is mixed bit quantization, otherwise it is fixed bit quantization
 # Full quantization support the number of bits [1,8]
 bit_num=8
-# Layers with size of weights exceeds threshold `min_quant_weight_size` will be quantized.
-min_quant_weight_size=0
-# Layers with channel size of weights exceeds threshold `min_quant_weight_channel` will be quantized.
-min_quant_weight_channel=16
 [data_preprocess_param]
 # Calibration dataset path, the format is input_name_1:input_1_dir,input_name_2:input_2_dir
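
For reference, the two removed keys gate which weight tensors get quantized, as the deleted comments above describe; dropping them from the sample configs leaves the built-in defaults in force. A simplified sketch of that gating (illustrative only, not the verbatim QuantStrategy implementation):

  // Illustrative sketch of the thresholds the removed keys control.
  bool WeightPassesThresholds(const std::vector<int> &weight_shape, int preferred_dim,
                              size_t min_quant_weight_size, size_t min_quant_weight_channel) {
    size_t elem_count = 1;
    for (auto dim : weight_shape) elem_count *= static_cast<size_t>(dim);
    if (elem_count <= min_quant_weight_size) {
      return false;  // weight too small: quantizing it saves little
    }
    // After this commit, the channel test applies only to convolution weights
    // (see the quant_strategy.cc hunk below).
    if (weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel)) {
      return false;
    }
    return true;
  }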

View File

@@ -405,11 +405,11 @@ int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
   return RET_OK;
 }
-int FullQuantQuantizer::QuantNode() {
+int FullQuantQuantizer::QuantNode(const FuncGraphPtr &func_graph) {
   auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
   auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();
-  auto cnodes = funcGraph->GetOrderedCnodes();
+  auto cnodes = func_graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
     auto op_name = cnode->fullname_with_scope();
     auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
@@ -510,17 +510,17 @@ void FullQuantQuantizer::InitQMinMax() {
   }
 }
-int FullQuantQuantizer::MarkQuantNode() {
-  auto cnodes = funcGraph->GetOrderedCnodes();
+int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
+  auto cnodes = func_graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
     auto anode = cnode->cast<AnfNodePtr>();
     if (anode == nullptr) {
       MS_LOG(ERROR) << cnode->fullname_with_scope() << " cnode is null";
       return RET_NULL_PTR;
     }
-    auto quant_strategy = std::make_unique<QuantStrategy>(flags.commonQuantParam.min_quant_weight_size,
-                                                          flags.commonQuantParam.min_quant_weight_channel,
-                                                          flags.commonQuantParam.skip_quant_node);
+    auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
+                                                          flags_.commonQuantParam.min_quant_weight_channel,
+                                                          flags_.commonQuantParam.skip_quant_node);
     // Mark quantifiable nodes
     auto is_support_op = quant_strategy->CanOpFullQuantized(anode);
     auto is_skip_op = quant_strategy->IsSkipOp(anode);
@@ -546,7 +546,7 @@ int FullQuantQuantizer::MarkQuantNode() {
   return RET_OK;
 }
-int FullQuantQuantizer::PreProcess() {
+int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
   switch (device_) {
     case CPU:
       InitCpuConfig();
@@ -558,10 +558,10 @@ int FullQuantQuantizer::PreProcess() {
   }
   InitQMinMax();
   calibrator_ =
-    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags.fullQuantParam.activation_quant_method,
-                                 this->flags.dataPreProcessParam, activation_symmetry_);
+    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags_.fullQuantParam.activation_quant_method,
+                                 this->flags_.dataPreProcessParam, activation_symmetry_);
   MSLITE_CHECK_PTR(calibrator_);
-  auto ret = MarkQuantNode();
+  auto ret = MarkQuantNode(func_graph);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Mark quant node failed.";
     return ret;
@@ -804,21 +804,21 @@ int FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeTh
 int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   MS_LOG(INFO) << "start to parse config file";
-  if (flags.dataPreProcessParam.calibrate_path.empty()) {
+  if (flags_.dataPreProcessParam.calibrate_path.empty()) {
     MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
     return RET_INPUT_PARAM_INVALID;
   }
-  int status = PreProcess();
+  int status = PreProcess(func_graph);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "do pre process failed!";
     return status;
   }
   // anf -- fb
-  flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
+  flags_.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
   MS_LOG(INFO) << "start create session";
-  auto sm = CreateSessionByFuncGraph(func_graph, flags, this->flags.commonQuantParam.thread_num);
+  auto sm = CreateSessionByFuncGraph(func_graph, flags_, this->flags_.commonQuantParam.thread_num);
   fp32_session_ = sm.session;
   fp32_model_ = sm.model;
   if (fp32_session_ == nullptr || fp32_model_ == nullptr) {
@@ -850,7 +850,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     return status;
   }
   MS_LOG(INFO) << "start to generate quant param and quantize tensor's data";
-  status = QuantNode();
+  status = QuantNode(func_graph);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Quant node failed.";
     return status;
@@ -865,11 +865,11 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     return RET_ERROR;
   }
   SessionModel int8_sm;
-  if (this->flags.fullQuantParam.bias_correction) {
+  if (this->flags_.fullQuantParam.bias_correction) {
     // init in8 session
     MS_LOG(INFO) << "create quant session";
-    flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
-    int8_sm = CreateSessionByFuncGraph(func_graph, flags, this->flags.commonQuantParam.thread_num);
+    flags_.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
+    int8_sm = CreateSessionByFuncGraph(func_graph, flags_, this->flags_.commonQuantParam.thread_num);
     int8_session_ = int8_sm.session;
     int8_model_ = int8_sm.model;
     if (int8_session_ == nullptr || int8_model_ == nullptr) {

View File

@@ -43,14 +43,11 @@ enum OperationType {
   FETCH,
 };
-enum QuantRuntimeDevice {
-  CPU,
-  KIRIN,
-};
 class FullQuantQuantizer : public Quantizer {
  public:
-  FullQuantQuantizer(FuncGraphPtr graph, int bit_num) : Quantizer(std::move(graph)), bit_num_(bit_num) {}
+  explicit FullQuantQuantizer(const converter::Flags &flags) : Quantizer(flags) {
+    bit_num_ = flags.commonQuantParam.bit_num;
+  }
   ~FullQuantQuantizer() override;
~FullQuantQuantizer() override;
@@ -60,7 +57,7 @@ class FullQuantQuantizer : public Quantizer {
   bool OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data);
   bool OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data);
-  int PreProcess();
+  int PreProcess(const FuncGraphPtr &func_graph);
   int CheckFp32TensorVec(const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensor_vec);
@@ -72,7 +69,7 @@ class FullQuantQuantizer : public Quantizer {
   int QuantNodeSimpleOp(const CNodePtr &cnode);
-  int QuantNode();
+  int QuantNode(const FuncGraphPtr &func_graph);
   int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
                          const PrimitivePtr &primitive, bool is_input, size_t index) const;
@@ -92,7 +89,7 @@ class FullQuantQuantizer : public Quantizer {
   KernelCallBack GetFloatAfterCallBack();
   void InitQMinMax();
   void InitCpuConfig();
-  int MarkQuantNode();
+  int MarkQuantNode(const FuncGraphPtr &func_graph);
  private:
   // Config

View File

@@ -202,8 +202,8 @@ int HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t dat
     out_c = 0;
     for (size_t i = 0; i < code_str.length(); i++) {
       auto tmp_c = code_str[i] == '0' ? 0 : 1;
-      out_c += tmp_c << ((quant::kMaxBit - 1) - (i % quant::kMaxBit));
-      if ((i + 1) % quant::kMaxBit == 0 || i == code_str.length() - 1) {
+      out_c += tmp_c << ((quant::k8Bit - 1) - (i % quant::k8Bit));
+      if ((i + 1) % quant::k8Bit == 0 || i == code_str.length() - 1) {
         encode_str[2] += out_c;
         out_c = 0;
       }
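
Renaming kMaxBit to k8Bit is behavior-preserving (both equal 8); the loop still packs the Huffman code string into bytes MSB-first. A self-contained illustration of that packing (PackBits is a hypothetical helper, not part of the diff):

  // Hypothetical standalone version of the MSB-first bit packing above.
  #include <cstddef>
  #include <cstdint>
  #include <string>
  #include <vector>

  std::vector<uint8_t> PackBits(const std::string &code_str) {
    constexpr size_t k8Bit = 8;
    std::vector<uint8_t> out;
    uint8_t out_c = 0;
    for (size_t i = 0; i < code_str.length(); i++) {
      uint8_t tmp_c = code_str[i] == '0' ? 0 : 1;
      out_c += tmp_c << ((k8Bit - 1) - (i % k8Bit));  // fill bit 7 first, then bit 6, ...
      if ((i + 1) % k8Bit == 0 || i == code_str.length() - 1) {
        out.push_back(out_c);  // flush a full byte, or the final partial byte
        out_c = 0;
      }
    }
    return out;
  }
  // PackBits("10000001") yields {0x81}; PackBits("111") yields {0xE0}.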

View File

@@ -84,9 +84,8 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
   }
   // quant
-  auto quantizer = std::make_unique<quant::WeightQuantizer>(func_graph_bak, *flags);
+  auto quantizer = std::make_unique<quant::WeightQuantizer>(*flags);
   CHECK_NULL_RETURN(quantizer);
-  quantizer->flags = *flags;
   auto status = quantizer->DoQuantize(func_graph_bak, scale);
   if (status != RET_OK) {
     MS_LOG(WARNING) << "DoQuantization failed " << status;

View File

@@ -28,6 +28,11 @@ enum ActivationQuantizedMethod {
   REMOVAL_OUTLIER = 2,
 };
+enum QuantRuntimeDevice {
+  CPU,
+  KIRIN,
+};
 struct CommonQuantParam {
   schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
   int bit_num = 8;
@@ -36,6 +41,7 @@ struct CommonQuantParam {
   bool is_debug = false;
   std::string debug_info_save_path;
   std::set<std::string> skip_quant_node;
+  QuantRuntimeDevice device = CPU;
   int thread_num = 4;
 };
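
Moving QuantRuntimeDevice out of full_quant_quantizer.h lets the target device travel inside CommonQuantParam (defaulting to CPU), which is what FullQuantQuantizer::PreProcess switches on. A sketch of that dispatch; the KIRIN branch body is an assumption, since only the CPU path appears in this diff:

  // Sketch of device-based dispatch; only the CPU path is visible in this commit.
  int InitByDevice(QuantRuntimeDevice device) {
    switch (device) {
      case CPU:
        // InitCpuConfig();  // as in FullQuantQuantizer::PreProcess
        return RET_OK;
      case KIRIN:
        // Kirin-specific initialization would go here (not part of this diff).
        return RET_OK;
      default:
        return RET_ERROR;
    }
  }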

View File

@@ -24,7 +24,7 @@
 #include "nnacl/op_base.h"
 namespace mindspore::lite::quant {
-bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) {
+bool QuantStrategy::CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node, int preferred_dim) {
   if (input_node == nullptr) {
     MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
     return false;
@@ -65,8 +65,9 @@ bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferr
     return false;
   }
   // min_quant_weight_channel_ only supports convolution
-  if (weight_shape.size() > DIMENSION_2D &&
+  static const std::set<PrimitivePtr> check_channel_ops = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion};
+  if (CheckNodeInSet(cnode, check_channel_ops) && weight_shape.size() >= DIMENSION_2D &&
       weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel_)) {
     MS_LOG(INFO) << "preferred_dim shape:" << weight_shape[preferred_dim] << " less min_quant_weight_channel_ "
                  << min_quant_weight_channel_;
@@ -80,7 +81,7 @@ bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
   if (!node->isa<mindspore::CNode>()) {
     return false;
   }
-  const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
+  const auto cnode = node->cast<mindspore::CNodePtr>();
   MS_ASSERT(cnode != nullptr);
   auto type = NodePrimitiveType(cnode);
   static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion, prim::kPrimActivation,
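
This file carries the strategy change named in the commit title: previously the min_quant_weight_channel_ threshold applied to any weight with more than two dimensions as a proxy for "convolution-like"; now it applies only to Conv2DFusion and Conv2dTransposeFusion nodes (and covers 2-D shapes too, via >=), so small-channel weights of other ops are no longer exempted from quantization. A reduced sketch of the new gate, reusing names from the diff with the surrounding checks elided:

  // Reduced sketch of the conv-only channel gate; logic mirrors the hunk above.
  bool ChannelBelowThreshold(const CNodePtr &cnode, const std::vector<int> &weight_shape,
                             int preferred_dim, size_t min_quant_weight_channel) {
    static const std::set<PrimitivePtr> check_channel_ops = {prim::kPrimConv2DFusion,
                                                             prim::kPrimConv2dTransposeFusion};
    if (!CheckNodeInSet(cnode, check_channel_ops)) {
      return false;  // non-convolution weights skip the channel threshold entirely
    }
    return weight_shape.size() >= DIMENSION_2D &&
           weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel);
  }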

View File

@@ -33,7 +33,7 @@ class QuantStrategy {
   ~QuantStrategy() = default;
   bool CanOpFullQuantized(const AnfNodePtr &node);
-  bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim);
+  bool CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node, int preferred_dim);
   bool IsSkipOp(const AnfNodePtr &input_node);
  private:

View File

@@ -56,7 +56,8 @@ enum WeightQuantType {
   MIXED_BIT_PER_LAYER = 2,
 };
 constexpr size_t kUint8Quantization = 8;
-constexpr size_t kMaxBit = 8;
+constexpr size_t k8Bit = 8;
+constexpr size_t k16Bit = 16;
 constexpr size_t kMaxNum1024 = 1024;
 constexpr float kPercentBase = 100.0;
 constexpr size_t kMillisecondsBase = 10;

View File

@@ -31,16 +31,14 @@
 namespace mindspore::lite::quant {
 class Quantizer {
  public:
-  explicit Quantizer(FuncGraphPtr graph) : funcGraph(std::move(graph)) {}
+  explicit Quantizer(const converter::Flags &config) : flags_(config) {}
   virtual ~Quantizer() = default;
   virtual int DoQuantize(FuncGraphPtr func_graph) = 0;
-  converter::Flags flags;
  protected:
-  FuncGraphPtr funcGraph = nullptr;
+  converter::Flags flags_;
 };
 }  // namespace mindspore::lite::quant
 #endif
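
The base class now owns only configuration: the constructor takes converter::Flags, the public flags member and the cached funcGraph are gone, and every DoQuantize receives the graph explicitly. A minimal sketch of a subclass against the new interface (MyQuantizer is hypothetical, not part of this commit):

  // Minimal hypothetical subclass of the reworked Quantizer interface.
  class MyQuantizer : public Quantizer {
   public:
    explicit MyQuantizer(const converter::Flags &config) : Quantizer(config) {}
    ~MyQuantizer() override = default;
    int DoQuantize(FuncGraphPtr func_graph) override {
      // The graph arrives as a parameter instead of cached funcGraph state,
      // and configuration is read from the protected flags_ member.
      MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
      MS_LOG(INFO) << "quantizing with bit_num " << flags_.commonQuantParam.bit_num;
      return RET_OK;
    }
  };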

View File

@@ -22,32 +22,6 @@
 #include "tools/optimizer/common/gllo_utils.h"
 namespace mindspore::lite::quant {
-WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(std::move(graph)) {
-  this->flags = config;
-  this->bit_num_ = config.commonQuantParam.bit_num;
-  if (this->bit_num_ == 0) {
-    type_id_ = kNumberTypeInt16;
-    this->is_mixed_bit_ = true;
-    mixed_bit_init_scale_ = config.mixedBitWeightQuantParam.init_scale;
-  }
-  quant_strategy_ = std::make_unique<QuantStrategy>(config.commonQuantParam.min_quant_weight_size,
-                                                    config.commonQuantParam.min_quant_weight_channel,
-                                                    config.commonQuantParam.skip_quant_node);
-  // parse param for fixed bit quant.
-  if (!this->is_mixed_bit_) {
-    quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
-    quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
-    // parse type_id_
-    if (this->bit_num_ > 0 && this->bit_num_ <= kMaxBit) {
-      type_id_ = kNumberTypeInt8;
-    } else if (this->bit_num_ <= (kMaxBit * 2)) {
-      type_id_ = kNumberTypeInt16;
-    } else {
-      MS_LOG(ERROR) << "invalid input bits";
-    }
-  }
-}
 WeightQuantizer::~WeightQuantizer() {
   for (const auto &fp32_output_tensor : fp32_output_tensors_) {
     for (const auto &kv : fp32_output_tensor) {
@@ -56,12 +30,12 @@ WeightQuantizer::~WeightQuantizer() {
     }
   }
-int WeightQuantizer::DoWeightQuantize(const CNodePtr &cnode) {
+int WeightQuantizer::DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   CHECK_NULL_RETURN(cnode);
   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
   CHECK_NULL_RETURN(primitive);
   WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
-  auto manager = api::FuncGraphManager::Manage(funcGraph, true);
+  auto manager = api::FuncGraphManager::Manage(func_graph, true);
   CHECK_NULL_RETURN(manager);
   std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
   if (CheckNodeInSet(cnode, per_layer_primitive_types)) {
@@ -89,11 +63,14 @@ int WeightQuantizer::DoWeightQuantize(const CNodePtr &cnode) {
       continue;
     }
     int preferred_dim = GetPreferredDim(primitive, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
-    if (!quant_strategy_->CanTensorQuantized(input, preferred_dim)) {
-      MS_LOG(INFO) << "Input " << idx << "of Optimizer is not quantizable";
+    auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
+                                                          flags_.commonQuantParam.min_quant_weight_channel,
+                                                          flags_.commonQuantParam.skip_quant_node);
+    if (!quant_strategy->CanTensorQuantized(cnode, input, preferred_dim)) {
+      MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable";
       continue;
     }
-    // support for shared weight
+    // support for matmul shared weight
     auto node_map = manager->node_users();
     auto node_user = node_map[input];
     auto tmp_weight_quant_type = weight_quant_type;
@@ -170,12 +147,8 @@ int WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_grap
   return RET_OK;
 }
-int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph, double init_scale) {
+int WeightQuantizer::DoQuantize(const FuncGraphPtr &func_graph, double init_scale) {
   mixed_bit_init_scale_ = init_scale;
-  return DoQuantize(std::move(func_graph));
-}
-int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
   weight_quantized_tensors_.clear();
@@ -192,7 +165,7 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
                                                          prim::kPrimAdam, prim::kPrimSGD,
                                                          prim::kPrimApplyMomentum};
     if (CheckNodeInSet(cnode, support_primitive_types)) {
-      auto status = DoWeightQuantize(cnode);
+      auto status = DoWeightQuantize(func_graph, cnode);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "DoWeightQuantize error";
         return RET_ERROR;
@@ -203,4 +176,6 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   }
   return MarkWeightQuantizationInNodes(func_graph);
 }
+int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { return DoQuantize(func_graph, mixed_bit_init_scale_); }
 }  // namespace mindspore::lite::quant

View File

@@ -23,6 +23,7 @@
 #include <map>
 #include <list>
 #include <string>
+#include <utility>
 #include <vector>
 #include <set>
 #include "tools/converter/quantizer/quantizer.h"
@@ -36,27 +37,49 @@
 #include "base/base.h"
 #include "abstract/dshape.h"
 #include "src/lite_session.h"
+#include "src/common/quant_utils.h"
 namespace mindspore::lite::quant {
 class WeightQuantizer : public Quantizer {
  public:
-  WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
+  explicit WeightQuantizer(const converter::Flags &flags) : Quantizer(flags) {
+    bit_num_ = flags.commonQuantParam.bit_num;
+    if (bit_num_ == 0) {
+      type_id_ = kNumberTypeInt16;
+      is_mixed_bit_ = true;
+      mixed_bit_init_scale_ = flags.mixedBitWeightQuantParam.init_scale;
+    }
+    // parse param for fixed bit quant.
+    if (!is_mixed_bit_) {
+      quant_max_ = QuantMax(bit_num_, false);
+      quant_min_ = QuantMin(bit_num_, false, false);
+      // parse type_id_
+      if (bit_num_ > 0 && bit_num_ <= k8Bit) {
+        type_id_ = kNumberTypeInt8;
+      } else if (bit_num_ <= k16Bit) {
+        type_id_ = kNumberTypeInt16;
+      } else {
+        MS_LOG(ERROR) << "invalid input bits";
+      }
+    }
+  }
   ~WeightQuantizer() override;
   int DoQuantize(FuncGraphPtr func_graph) override;
-  int DoQuantize(FuncGraphPtr func_graph, double init_scale);
-  int DoWeightQuantize(const CNodePtr &cnode);
+  int DoQuantize(const FuncGraphPtr &func_graph, double init_scale);
+ private:
+  int DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
+  int MarkWeightQuantizationInNodes(const FuncGraphPtr &);
+  int DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
  private:
-  std::unique_ptr<QuantStrategy> quant_strategy_;
   size_t bit_num_{8};
   // delete it in the future.
   std::set<tensor::TensorPtr> weight_quantized_tensors_;
   std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
   bool is_mixed_bit_ = false;
   double mixed_bit_init_scale_ = 0.02;
-  int MarkWeightQuantizationInNodes(const FuncGraphPtr &);
-  int DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
+  int quant_max_{127};
+  int quant_min_{-128};
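
In the new inline constructor, the hand-rolled shift arithmetic is replaced with the QuantMax/QuantMin helpers pulled in via src/common/quant_utils.h; assuming they compute the usual signed symmetric range, they agree with the old inlined expressions and with the new member defaults. A self-contained check of that arithmetic:

  // Signed range arithmetic the old constructor inlined; QuantMax/QuantMin from
  // src/common/quant_utils.h are assumed to yield the same values.
  #include <cassert>

  int SignedQuantMax(unsigned int bit_num) { return (1 << (bit_num - 1)) - 1; }
  int SignedQuantMin(unsigned int bit_num) { return -(1 << (bit_num - 1)); }

  int main() {
    assert(SignedQuantMax(8) == 127 && SignedQuantMin(8) == -128);      // matches quant_max_{127}/quant_min_{-128}
    assert(SignedQuantMax(16) == 32767 && SignedQuantMin(16) == -32768);  // 16-bit (kNumberTypeInt16) range
    return 0;
  }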