forked from mindspore-Ecosystem/mindspore
optimize the Quantizer base class and optimize min_quant_weight_channel strategy
This commit is contained in:
parent 35e0526b51
commit 8aaff90f7e
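
At a glance: the `Quantizer` base class no longer owns the `FuncGraph` and no longer exposes a public `flags` member. It is constructed from `converter::Flags` alone, the flags live in a protected `flags_` field, and the graph is passed explicitly to each `DoQuantize` call. Below is a condensed, self-contained sketch of that shape; `Flags` and `FuncGraph` are simplified stand-ins for the real `converter::Flags` and graph types so the example compiles on its own.

```cpp
// Minimal sketch of the refactored ownership model, assuming simplified
// stand-ins for converter::Flags and the FuncGraph types; only the shape
// of the classes in this diff is mirrored here.
#include <iostream>
#include <memory>

struct Flags {        // stand-in for converter::Flags
  int bit_num = 8;
};
struct FuncGraph {};  // stand-in for the graph type behind FuncGraphPtr
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

// Before: explicit Quantizer(FuncGraphPtr graph), a public `flags` member and
// a protected `funcGraph`. After: flags-only construction; the graph becomes
// an argument of DoQuantize.
class Quantizer {
 public:
  explicit Quantizer(const Flags &config) : flags_(config) {}
  virtual ~Quantizer() = default;
  virtual int DoQuantize(FuncGraphPtr func_graph) = 0;

 protected:
  Flags flags_;
};

class WeightQuantizer : public Quantizer {
 public:
  explicit WeightQuantizer(const Flags &flags) : Quantizer(flags) {}
  int DoQuantize(FuncGraphPtr /*func_graph*/) override {
    std::cout << "quantizing with bit_num=" << flags_.bit_num << '\n';
    return 0;  // RET_OK
  }
};

int main() {
  Flags flags;
  std::unique_ptr<Quantizer> quantizer = std::make_unique<WeightQuantizer>(flags);
  return quantizer->DoQuantize(std::make_shared<FuncGraph>());
}
```
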
@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 
 [data_preprocess_param]
 calibrate_path=blob1:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_face_landmark_2_calibration_data

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 
 [data_preprocess_param]
 calibrate_path=blob1:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_face_mnet_image

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 
 [data_preprocess_param]
 calibrate_path=blob1:/home/workspace/mindspore_dataset/mslite/quantTraining/ml_face_mnet_calibration_data

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 
 [data_preprocess_param]
 calibrate_path=conv2d_input:/home/workspace/mindspore_dataset/mslite/quantTraining/mnist_calibration_data

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 
 [data_preprocess_param]
 calibrate_path=decoder_buffer_in_0:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_0,decoder_buffer_in_1:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_1,decoder_buffer_in_2:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_2,decoder_buffer_in_3:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_3,decoder_buffer_in_4:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_4,decoder_buffer_in_5:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_5,decoder_buffer_in_6:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_6,decoder_buffer_in_7:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_7,decoder_buffer_in_8:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_8,joint_in_encoder_deploy:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_9,decoder_in_deploy:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_decoder_calibration_data/input_10

@@ -1,8 +1,6 @@
 [common_quant_param]
 quant_type=FULL_QUANT
 bit_num=8
-min_quant_weight_size=0
-min_quant_weight_channel=16
 
 [data_preprocess_param]
 calibrate_path=buffer_in_extreme_0:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_0,buffer_in_extreme_1:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_1,buffer_in_extreme_2:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_2,buffer_in_extreme_3:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_3,buffer_in_extreme_4:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_4,buffer_in_extreme_5:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_5,buffer_in_extreme_6:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_6,buffer_in_extreme_7:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_7,buffer_in_extreme_8:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_8,buffer_in_extreme_9:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_9,buffer_in_extreme_10:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_10,buffer_in_extreme_11:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_11,buffer_in_extreme_12:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_12,buffer_in_extreme_13:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_13,buffer_in_extreme_14:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_14,buffer_in_extreme_15:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_15,buffer_in_extreme_16:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_16,buffer_in_extreme_17:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_17,buffer_in_extreme_18:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_18,buffer_in_extreme_19:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_19,buffer_in_extreme_20:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_20,buffer_in_extreme_21:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_21,buffer_in_extreme_22:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_22,buffer_in_extreme_23:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_23,buffer_in_extreme_24:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_24,buffer_in_extreme_25:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_25,buffer_in_extreme_26:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_26,buffer_in_extreme_27:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_27,buffer_in_extreme_28:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_28,buffer_in_extreme_29:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_29,buffer_in_extreme_30:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_30,buffer_in_extreme_31:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_31,buffer_in_extreme_32:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_32,buffer_in_extreme_33:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_33,buffer_in_extreme_34:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_34,encoder_in_deploy:/home/workspace/mindspore_dataset/mslite/quantTraining/transformer_0831_encoder_calibration_data/input_35

@@ -19,6 +19,7 @@
 #include <string>
 #include <unordered_map>
 #include <deque>
+#include <map>
 #include "nnacl/op_base.h"
 #include "src/common/log_adapter.h"
 #include "tools/converter/optimizer_manager.h"
@@ -363,29 +364,29 @@ void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGr
 }
 
 int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
-  constexpr int thread_num = 2;
   // quant
   if (config->commonQuantParam.quant_type != schema::QuantType_QUANT_ALL &&
       config->commonQuantParam.quant_type != schema::QuantType_QUANT_WEIGHT) {
     return RET_OK;
   }
   int status;
+  std::unique_ptr<quant::Quantizer> quantizer;
+
   quant::SessionModel origin;
   quant::SessionModel quant;
   if (config->commonQuantParam.is_debug) {
     converter::Flags new_flag = *config;
     new_flag.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
-    origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, thread_num);
+    origin = quant::CreateSessionByFuncGraph(old_graph, new_flag, config->commonQuantParam.thread_num);
   }
   if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
-    this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
-    if (m_quantizer_ == nullptr) {
+    quantizer = std::make_unique<quant::FullQuantQuantizer>(*config);
+    if (quantizer == nullptr) {
       MS_LOG(ERROR) << "New FullQuantQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
       return RET_ERROR;
     }
-    m_quantizer_->flags = *config;
-    status = m_quantizer_->DoQuantize(old_graph);
+    status = quantizer->DoQuantize(old_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DoQuantization failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -400,38 +401,35 @@ int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const con
       MS_LOG(ERROR) << "Grid search with scale failed.";
       return status;
     }
-    this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config);
-    if (m_quantizer_ == nullptr) {
+    quantizer = std::make_unique<quant::WeightQuantizer>(*config);
+    if (quantizer == nullptr) {
       MS_LOG(ERROR) << "New WeightQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
       return RET_ERROR;
     }
-    status = static_cast<quant::WeightQuantizer *>(m_quantizer_.get())->DoQuantize(old_graph, init_scale);
+    status = static_cast<quant::WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph, init_scale);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "DoQuantization failed " << status;
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
       return RET_ERROR;
     }
   } else {
-    this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config);
-    if (m_quantizer_ == nullptr) {
+    quantizer = std::make_unique<quant::WeightQuantizer>(*config);
+    if (quantizer == nullptr) {
       MS_LOG(ERROR) << "New WeightQuantizer failed";
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
       return RET_ERROR;
     }
-    if (m_quantizer_ != nullptr) {
-      m_quantizer_->flags = *config;
-      status = m_quantizer_->DoQuantize(old_graph);
-      if (status != RET_OK) {
-        MS_LOG(ERROR) << "DoQuantization failed " << status;
-        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-        return RET_ERROR;
-      }
-    }
+    status = quantizer->DoQuantize(old_graph);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "DoQuantization failed " << status;
+      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+      return RET_ERROR;
+    }
   }
   if (config->commonQuantParam.is_debug) {
-    quant = quant::CreateSessionByFuncGraph(old_graph, *config, thread_num);
+    quant = quant::CreateSessionByFuncGraph(old_graph, *config, config->commonQuantParam.thread_num);
     std::map<std::string, OpParameter *> op_parameters;
     FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
     DebugInfoManager manager;

@@ -37,8 +37,6 @@ class AnfTransform {
   FuncGraphPtr Transform(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
 
  private:
-  std::unique_ptr<quant::Quantizer> m_quantizer_ = nullptr;
-
   FuncGraphPtr TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config = nullptr);
 
   static int RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config);

@@ -4,10 +4,6 @@ quant_type=FULL_QUANT
 # Weight quantization support the number of bits [0,16], Set to 0 is mixed bit quantization, otherwise it is fixed bit quantization
 # Full quantization support the number of bits [1,8]
 bit_num=8
-# Layers with size of weights exceeds threshold `min_quant_weight_size` will be quantized.
-min_quant_weight_size=0
-# Layers with channel size of weights exceeds threshold `min_quant_weight_channel` will be quantized.
-min_quant_weight_channel=16
 
 [data_preprocess_param]
 # Calibration dataset path, the format is input_name_1:input_1_dir,input_name_2:input_2_dir
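
To make the documented format concrete, here is a minimal hypothetical `[data_preprocess_param]` section; the input names and directories are placeholders, not paths from this repository:

```ini
[data_preprocess_param]
# one input_name:calibration_dir pair per graph input, comma separated
calibrate_path=input_name_1:/path/to/input_1_dir,input_name_2:/path/to/input_2_dir
```
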
@@ -405,11 +405,11 @@ int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
   return RET_OK;
 }
 
-int FullQuantQuantizer::QuantNode() {
+int FullQuantQuantizer::QuantNode(const FuncGraphPtr &func_graph) {
   auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
   auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();
 
-  auto cnodes = funcGraph->GetOrderedCnodes();
+  auto cnodes = func_graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
     auto op_name = cnode->fullname_with_scope();
     auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
@@ -510,17 +510,17 @@ void FullQuantQuantizer::InitQMinMax() {
   }
 }
 
-int FullQuantQuantizer::MarkQuantNode() {
-  auto cnodes = funcGraph->GetOrderedCnodes();
+int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
+  auto cnodes = func_graph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
     auto anode = cnode->cast<AnfNodePtr>();
     if (anode == nullptr) {
       MS_LOG(ERROR) << cnode->fullname_with_scope() << " cnode is null";
       return RET_NULL_PTR;
     }
-    auto quant_strategy = std::make_unique<QuantStrategy>(flags.commonQuantParam.min_quant_weight_size,
-                                                          flags.commonQuantParam.min_quant_weight_channel,
-                                                          flags.commonQuantParam.skip_quant_node);
+    auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
+                                                          flags_.commonQuantParam.min_quant_weight_channel,
+                                                          flags_.commonQuantParam.skip_quant_node);
     // Mark quantifiable nodes
     auto is_support_op = quant_strategy->CanOpFullQuantized(anode);
     auto is_skip_op = quant_strategy->IsSkipOp(anode);
@@ -546,7 +546,7 @@ int FullQuantQuantizer::MarkQuantNode() {
   return RET_OK;
 }
 
-int FullQuantQuantizer::PreProcess() {
+int FullQuantQuantizer::PreProcess(const FuncGraphPtr &func_graph) {
   switch (device_) {
     case CPU:
       InitCpuConfig();
@@ -558,10 +558,10 @@ int FullQuantQuantizer::PreProcess() {
   }
   InitQMinMax();
   calibrator_ =
-    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags.fullQuantParam.activation_quant_method,
-                                 this->flags.dataPreProcessParam, activation_symmetry_);
+    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags_.fullQuantParam.activation_quant_method,
+                                 this->flags_.dataPreProcessParam, activation_symmetry_);
   MSLITE_CHECK_PTR(calibrator_);
-  auto ret = MarkQuantNode();
+  auto ret = MarkQuantNode(func_graph);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Mark quant node failed.";
     return ret;
@@ -804,21 +804,21 @@ int FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeTh
 
 int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   MS_LOG(INFO) << "start to parse config file";
-  if (flags.dataPreProcessParam.calibrate_path.empty()) {
+  if (flags_.dataPreProcessParam.calibrate_path.empty()) {
     MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
     return RET_INPUT_PARAM_INVALID;
   }
 
-  int status = PreProcess();
+  int status = PreProcess(func_graph);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "do pre process failed!";
     return status;
   }
 
   // anf -- fb
-  flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
+  flags_.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
   MS_LOG(INFO) << "start create session";
-  auto sm = CreateSessionByFuncGraph(func_graph, flags, this->flags.commonQuantParam.thread_num);
+  auto sm = CreateSessionByFuncGraph(func_graph, flags_, this->flags_.commonQuantParam.thread_num);
   fp32_session_ = sm.session;
   fp32_model_ = sm.model;
   if (fp32_session_ == nullptr || fp32_model_ == nullptr) {
@@ -850,7 +850,7 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     return status;
   }
   MS_LOG(INFO) << "start to generate quant param and quantize tensor's data";
-  status = QuantNode();
+  status = QuantNode(func_graph);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "Quant node failed.";
     return status;
@@ -865,11 +865,11 @@ int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     return RET_ERROR;
   }
   SessionModel int8_sm;
-  if (this->flags.fullQuantParam.bias_correction) {
+  if (this->flags_.fullQuantParam.bias_correction) {
     // init in8 session
     MS_LOG(INFO) << "create quant session";
-    flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
-    int8_sm = CreateSessionByFuncGraph(func_graph, flags, this->flags.commonQuantParam.thread_num);
+    flags_.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
+    int8_sm = CreateSessionByFuncGraph(func_graph, flags_, this->flags_.commonQuantParam.thread_num);
     int8_session_ = int8_sm.session;
     int8_model_ = int8_sm.model;
     if (int8_session_ == nullptr || int8_model_ == nullptr) {

@@ -43,14 +43,11 @@ enum OperationType {
   FETCH,
 };
 
-enum QuantRuntimeDevice {
-  CPU,
-  KIRIN,
-};
-
 class FullQuantQuantizer : public Quantizer {
  public:
-  FullQuantQuantizer(FuncGraphPtr graph, int bit_num) : Quantizer(std::move(graph)), bit_num_(bit_num) {}
+  explicit FullQuantQuantizer(const converter::Flags &flags) : Quantizer(flags) {
+    bit_num_ = flags.commonQuantParam.bit_num;
+  }
 
   ~FullQuantQuantizer() override;
 
@@ -60,7 +57,7 @@ class FullQuantQuantizer : public Quantizer {
   bool OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data);
   bool OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data);
 
-  int PreProcess();
+  int PreProcess(const FuncGraphPtr &func_graph);
 
   int CheckFp32TensorVec(const std::string &node_name, const std::vector<mindspore::tensor::MSTensor *> &tensor_vec);
 
@@ -72,7 +69,7 @@ class FullQuantQuantizer : public Quantizer {
 
   int QuantNodeSimpleOp(const CNodePtr &cnode);
 
-  int QuantNode();
+  int QuantNode(const FuncGraphPtr &func_graph);
 
   int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
                          const PrimitivePtr &primitive, bool is_input, size_t index) const;
@@ -92,7 +89,7 @@ class FullQuantQuantizer : public Quantizer {
   KernelCallBack GetFloatAfterCallBack();
   void InitQMinMax();
   void InitCpuConfig();
-  int MarkQuantNode();
+  int MarkQuantNode(const FuncGraphPtr &func_graph);
 
  private:
   // Config

@@ -202,8 +202,8 @@ int HuffmanEncode::DoHuffmanCompress(const int8_t *input_datas, const size_t dat
   out_c = 0;
   for (size_t i = 0; i < code_str.length(); i++) {
     auto tmp_c = code_str[i] == '0' ? 0 : 1;
-    out_c += tmp_c << ((quant::kMaxBit - 1) - (i % quant::kMaxBit));
-    if ((i + 1) % quant::kMaxBit == 0 || i == code_str.length() - 1) {
+    out_c += tmp_c << ((quant::k8Bit - 1) - (i % quant::k8Bit));
+    if ((i + 1) % quant::k8Bit == 0 || i == code_str.length() - 1) {
       encode_str[2] += out_c;
       out_c = 0;
     }

@@ -84,9 +84,8 @@ int ParameterOptimizer::WeightQuantModelInference(const FuncGraphPtr &func_graph
   }
 
   // quant
-  auto quantizer = std::make_unique<quant::WeightQuantizer>(func_graph_bak, *flags);
+  auto quantizer = std::make_unique<quant::WeightQuantizer>(*flags);
   CHECK_NULL_RETURN(quantizer);
-  quantizer->flags = *flags;
   auto status = quantizer->DoQuantize(func_graph_bak, scale);
   if (status != RET_OK) {
     MS_LOG(WARNING) << "DoQuantization failed " << status;

@@ -28,6 +28,11 @@ enum ActivationQuantizedMethod {
   REMOVAL_OUTLIER = 2,
 };
 
+enum QuantRuntimeDevice {
+  CPU,
+  KIRIN,
+};
+
 struct CommonQuantParam {
   schema::QuantType quant_type = schema::QuantType_QUANT_NONE;
   int bit_num = 8;
@@ -36,6 +41,7 @@ struct CommonQuantParam {
   bool is_debug = false;
   std::string debug_info_save_path;
   std::set<std::string> skip_quant_node;
+  QuantRuntimeDevice device = CPU;
   int thread_num = 4;
 };
 
@@ -24,7 +24,7 @@
 #include "nnacl/op_base.h"
 
 namespace mindspore::lite::quant {
-bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim) {
+bool QuantStrategy::CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node, int preferred_dim) {
   if (input_node == nullptr) {
     MS_LOG(INFO) << "CanTensorQuantized input is nullptr!";
     return false;
@@ -65,8 +65,9 @@ bool QuantStrategy::CanTensorQuantized(const AnfNodePtr &input_node, int preferr
     return false;
   }
 
-  // min_quant_weight_channel_ only supports convolution
-  if (weight_shape.size() > DIMENSION_2D &&
+  static const std::set<PrimitivePtr> check_channel_ops = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion};
+
+  if (CheckNodeInSet(cnode, check_channel_ops) && weight_shape.size() >= DIMENSION_2D &&
       weight_shape[preferred_dim] <= static_cast<int>(min_quant_weight_channel_)) {
     MS_LOG(INFO) << "preferred_dim shape:" << weight_shape[preferred_dim] << " less min_quant_weight_channel_ "
                  << min_quant_weight_channel_;
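
The behavioral change in this hunk: the `min_quant_weight_channel` threshold now applies only to `Conv2DFusion` and `Conv2dTransposeFusion` nodes, instead of to every op whose weight has more than two dimensions. A self-contained sketch under those assumptions (op types reduced to strings, `CheckNodeInSet` approximated by a set lookup):

```cpp
// Sketch of the tightened channel check; not the real MindSpore API.
#include <iostream>
#include <set>
#include <string>
#include <vector>

bool CanTensorQuantized(const std::string &op_type, const std::vector<int> &weight_shape,
                        int preferred_dim, int min_quant_weight_channel) {
  // Only convolution-like ops are subject to the channel threshold now;
  // previously any weight with more than two dimensions was checked.
  static const std::set<std::string> check_channel_ops = {"Conv2DFusion", "Conv2dTransposeFusion"};
  if (check_channel_ops.count(op_type) != 0 && weight_shape.size() >= 2 &&
      weight_shape[preferred_dim] <= min_quant_weight_channel) {
    return false;  // too few channels along the preferred dim to be worth quantizing
  }
  return true;
}

int main() {
  // A conv weight with 8 output channels stays float (8 <= 16) ...
  std::cout << CanTensorQuantized("Conv2DFusion", {8, 3, 3, 3}, 0, 16) << '\n';  // 0
  // ... while a 3-D non-conv weight the old rule skipped is now quantized.
  std::cout << CanTensorQuantized("MatMulFusion", {8, 16, 16}, 0, 16) << '\n';   // 1
  return 0;
}
```
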
@@ -80,7 +81,7 @@ bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
   if (!node->isa<mindspore::CNode>()) {
     return false;
   }
-  const auto cnode = std::dynamic_pointer_cast<mindspore::CNode>(node);
+  const auto cnode = node->cast<mindspore::CNodePtr>();
   MS_ASSERT(cnode != nullptr);
   auto type = NodePrimitiveType(cnode);
   static const std::set<PrimitivePtr> support_int8_ops = {prim::kPrimAddFusion, prim::kPrimActivation,

@@ -33,7 +33,7 @@ class QuantStrategy {
   ~QuantStrategy() = default;
 
   bool CanOpFullQuantized(const AnfNodePtr &node);
-  bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim);
+  bool CanTensorQuantized(const CNodePtr &cnode, const AnfNodePtr &input_node, int preferred_dim);
   bool IsSkipOp(const AnfNodePtr &input_node);
 
  private:

@@ -56,7 +56,8 @@ enum WeightQuantType {
   MIXED_BIT_PER_LAYER = 2,
 };
 constexpr size_t kUint8Quantization = 8;
-constexpr size_t kMaxBit = 8;
+constexpr size_t k8Bit = 8;
+constexpr size_t k16Bit = 16;
 constexpr size_t kMaxNum1024 = 1024;
 constexpr float kPercentBase = 100.0;
 constexpr size_t kMillisecondsBase = 10;

@@ -31,16 +31,14 @@
 namespace mindspore::lite::quant {
 class Quantizer {
  public:
-  explicit Quantizer(FuncGraphPtr graph) : funcGraph(std::move(graph)) {}
+  explicit Quantizer(const converter::Flags &config) : flags_(config) {}
 
   virtual ~Quantizer() = default;
 
   virtual int DoQuantize(FuncGraphPtr func_graph) = 0;
 
-  converter::Flags flags;
-
  protected:
-  FuncGraphPtr funcGraph = nullptr;
+  converter::Flags flags_;
 };
 } // namespace mindspore::lite::quant
 #endif

@@ -22,32 +22,6 @@
 #include "tools/optimizer/common/gllo_utils.h"
 
 namespace mindspore::lite::quant {
-WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config) : Quantizer(std::move(graph)) {
-  this->flags = config;
-  this->bit_num_ = config.commonQuantParam.bit_num;
-  if (this->bit_num_ == 0) {
-    type_id_ = kNumberTypeInt16;
-    this->is_mixed_bit_ = true;
-    mixed_bit_init_scale_ = config.mixedBitWeightQuantParam.init_scale;
-  }
-  quant_strategy_ = std::make_unique<QuantStrategy>(config.commonQuantParam.min_quant_weight_size,
-                                                    config.commonQuantParam.min_quant_weight_channel,
-                                                    config.commonQuantParam.skip_quant_node);
-  // parse param for fixed bit quant.
-  if (!this->is_mixed_bit_) {
-    quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
-    quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
-    // parse type_id_
-    if (this->bit_num_ > 0 && this->bit_num_ <= kMaxBit) {
-      type_id_ = kNumberTypeInt8;
-    } else if (this->bit_num_ <= (kMaxBit * 2)) {
-      type_id_ = kNumberTypeInt16;
-    } else {
-      MS_LOG(ERROR) << "invalid input bits";
-    }
-  }
-}
-
 WeightQuantizer::~WeightQuantizer() {
   for (const auto &fp32_output_tensor : fp32_output_tensors_) {
     for (const auto &kv : fp32_output_tensor) {
@@ -56,12 +30,12 @@ WeightQuantizer::~WeightQuantizer() {
   }
 }
 
-int WeightQuantizer::DoWeightQuantize(const CNodePtr &cnode) {
+int WeightQuantizer::DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
   CHECK_NULL_RETURN(cnode);
   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
   CHECK_NULL_RETURN(primitive);
   WeightQuantType weight_quant_type = WeightQuantType::FIXED_BIT_PER_CHANNEL;
-  auto manager = api::FuncGraphManager::Manage(funcGraph, true);
+  auto manager = api::FuncGraphManager::Manage(func_graph, true);
   CHECK_NULL_RETURN(manager);
   std::set<PrimitivePtr> per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
   if (CheckNodeInSet(cnode, per_layer_primitive_types)) {
@@ -89,11 +63,14 @@ int WeightQuantizer::DoWeightQuantize(const CNodePtr &cnode) {
       continue;
     }
     int preferred_dim = GetPreferredDim(primitive, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
-    if (!quant_strategy_->CanTensorQuantized(input, preferred_dim)) {
-      MS_LOG(INFO) << "Input " << idx << "of Optimizer is not quantizable";
+    auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
+                                                          flags_.commonQuantParam.min_quant_weight_channel,
+                                                          flags_.commonQuantParam.skip_quant_node);
+    if (!quant_strategy->CanTensorQuantized(cnode, input, preferred_dim)) {
+      MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable";
       continue;
     }
-    // support for shared weight
+    // support for matmul shared weight
     auto node_map = manager->node_users();
     auto node_user = node_map[input];
     auto tmp_weight_quant_type = weight_quant_type;
@@ -170,12 +147,8 @@ int WeightQuantizer::MarkWeightQuantizationInNodes(const FuncGraphPtr &func_grap
   return RET_OK;
 }
 
-int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph, double init_scale) {
+int WeightQuantizer::DoQuantize(const FuncGraphPtr &func_graph, double init_scale) {
   mixed_bit_init_scale_ = init_scale;
-  return DoQuantize(std::move(func_graph));
-}
-
-int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
   weight_quantized_tensors_.clear();
 
@@ -192,7 +165,7 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
                                                         prim::kPrimAdam, prim::kPrimSGD,
                                                         prim::kPrimApplyMomentum};
     if (CheckNodeInSet(cnode, support_primitive_types)) {
-      auto status = DoWeightQuantize(cnode);
+      auto status = DoWeightQuantize(func_graph, cnode);
       if (status != RET_OK) {
         MS_LOG(ERROR) << "DoWeightQuantize error";
         return RET_ERROR;
@@ -203,4 +176,6 @@ int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
   }
   return MarkWeightQuantizationInNodes(func_graph);
 }
+
+int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) { return DoQuantize(func_graph, mixed_bit_init_scale_); }
 } // namespace mindspore::lite::quant

@@ -23,6 +23,7 @@
 #include <map>
 #include <list>
 #include <string>
 #include <utility>
 #include <vector>
 #include <set>
 #include "tools/converter/quantizer/quantizer.h"
@@ -36,27 +37,49 @@
 #include "base/base.h"
 #include "abstract/dshape.h"
 #include "src/lite_session.h"
+#include "src/common/quant_utils.h"
 
 namespace mindspore::lite::quant {
 class WeightQuantizer : public Quantizer {
  public:
-  WeightQuantizer(FuncGraphPtr graph, const converter::Flags &config);
+  explicit WeightQuantizer(const converter::Flags &flags) : Quantizer(flags) {
+    bit_num_ = flags.commonQuantParam.bit_num;
+    if (bit_num_ == 0) {
+      type_id_ = kNumberTypeInt16;
+      is_mixed_bit_ = true;
+      mixed_bit_init_scale_ = flags.mixedBitWeightQuantParam.init_scale;
+    }
+    // parse param for fixed bit quant.
+    if (!is_mixed_bit_) {
+      quant_max_ = QuantMax(bit_num_, false);
+      quant_min_ = QuantMin(bit_num_, false, false);
+      // parse type_id_
+      if (bit_num_ > 0 && bit_num_ <= k8Bit) {
+        type_id_ = kNumberTypeInt8;
+      } else if (bit_num_ <= k16Bit) {
+        type_id_ = kNumberTypeInt16;
+      } else {
+        MS_LOG(ERROR) << "invalid input bits";
+      }
+    }
+  }
   ~WeightQuantizer() override;
 
   int DoQuantize(FuncGraphPtr func_graph) override;
-  int DoQuantize(FuncGraphPtr func_graph, double init_scale);
-  int DoWeightQuantize(const CNodePtr &cnode);
+  int DoQuantize(const FuncGraphPtr &func_graph, double init_scale);
 
+ private:
+  int DoWeightQuantize(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
+  int MarkWeightQuantizationInNodes(const FuncGraphPtr &);
+  int DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
+
  private:
-  std::unique_ptr<QuantStrategy> quant_strategy_;
   size_t bit_num_{8};
   // delete it in the future.
   std::set<tensor::TensorPtr> weight_quantized_tensors_;
   std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
   bool is_mixed_bit_ = false;
   double mixed_bit_init_scale_ = 0.02;
-  int MarkWeightQuantizationInNodes(const FuncGraphPtr &);
-  int DoMarkWeightQuantizeIfQuantized(const CNodePtr &);
-
   int quant_max_{127};
   int quant_min_{-128};
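
For illustration, a standalone trace of the `bit_num` handling this constructor performs. The `[quant_min, quant_max]` formulas below follow the old inline code that the `QuantMax`/`QuantMin` helpers replace; that the helpers compute the same symmetric signed range is an assumption, and the type ids are stubbed.

```cpp
// Hypothetical trace of bit_num -> storage type / quant range mapping.
#include <cstdio>

enum TypeId { kNumberTypeInt8, kNumberTypeInt16 };

int main() {
  for (int bit_num : {0, 4, 8, 12, 16}) {
    bool is_mixed_bit = (bit_num == 0);  // bit_num 0 selects mixed-bit quantization
    TypeId type_id = kNumberTypeInt16;   // mixed bit defaults to int16 storage
    int quant_min = 0, quant_max = 0;    // unused for mixed bit
    if (!is_mixed_bit) {
      quant_max = (1 << (bit_num - 1)) - 1;  // e.g. 127 for 8 bits
      quant_min = -(1 << (bit_num - 1));     // e.g. -128 for 8 bits
      type_id = (bit_num <= 8) ? kNumberTypeInt8 : kNumberTypeInt16;
    }
    std::printf("bit_num=%2d mixed=%d store=int%d range=[%d,%d]\n", bit_num,
                static_cast<int>(is_mixed_bit), type_id == kNumberTypeInt8 ? 8 : 16,
                quant_min, quant_max);
  }
  return 0;
}
```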