!26723 full quant support skip node

Merge pull request !26723 from yeyunpeng2020/quant_bak2
This commit is contained in:
i-robot 2021-11-24 07:16:05 +00:00 committed by Gitee
commit d461433648
10 changed files with 107 additions and 48 deletions

View File

@ -112,7 +112,7 @@ int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::ma
{"bit_num", common_quant_string_.bit_num},
{"min_quant_weight_size", common_quant_string_.min_quant_weight_size},
{"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
{"skip_node", common_quant_string_.skip_node},
{"skip_quant_node", common_quant_string_.skip_quant_node},
{"debug_info_save_path", common_quant_string_.debug_info_save_path},
};
return SetMapData(map, parse_map, kCommonQuantParam);

View File

@ -41,7 +41,7 @@ struct CommonQuantString {
std::string bit_num;
std::string min_quant_weight_size;
std::string min_quant_weight_channel;
std::string skip_node;
std::string skip_quant_node;
std::string debug_info_save_path;
};

View File

@ -49,10 +49,10 @@ int QuantParamParser::ParseFilter(const CommonQuantString &common_quant_string,
return RET_INPUT_PARAM_INVALID;
}
}
if (!common_quant_string.skip_node.empty()) {
std::vector<std::string> nodes = SplitStringToVector(common_quant_string.skip_node, ',');
if (!common_quant_string.skip_quant_node.empty()) {
std::vector<std::string> nodes = SplitStringToVector(common_quant_string.skip_quant_node, ',');
for (const auto &node : nodes) {
common_quant->skip_node.insert(node);
common_quant->skip_quant_node.insert(node);
}
}
return RET_OK;

View File

@ -180,7 +180,7 @@ float DataDistribution::CalculateScale(float min_value, float max_value) {
this->encode_max_ = max_value;
// Optimize Handle 0.
MS_ASSERT(quant_max_ - quant_min_ > 0);
return (max_value - min_value) / (quant_max_ - quant_min_);
return (encode_max_ - encode_min_) / (quant_max_ - quant_min_);
}
float DataDistribution::CalculateKLScale() { return CalculateScale(this->best_T_, this->real_max_); }

View File

@ -40,6 +40,7 @@
#include "src/common/utils.h"
#include "tools/common/node_util.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
using std::string;
using std::vector;
@ -126,22 +127,6 @@ int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const s
}
} // namespace
FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type)
: Quantizer(std::move(graph)) {
MS_ASSERT(graph != nullptr);
this->bit_num_ = bit_num;
this->target_data_type_ = target_type;
if (target_type == kNumberTypeInt8) {
q_max_ = (1 << (this->bit_num_ - 1)) - 1; // 127
q_min_ = -q_max_; // -127
} else if (target_type == kNumberTypeUInt8) {
q_max_ = (1 << this->bit_num_) - 1; // 255
q_min_ = 0;
} else {
MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
}
}
FullQuantQuantizer::~FullQuantQuantizer() {
delete fp32_session_;
delete fp32_model_;
@ -507,41 +492,83 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
return RET_OK;
}
int FullQuantQuantizer::PreProcess() {
calibrator_ =
std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags.fullQuantParam.activation_quant_method,
this->flags.dataPreProcessParam);
if (calibrator_ == nullptr) {
MS_LOG(ERROR) << "create calibrator failed!";
return RET_NULL_PTR;
void FullQuantQuantizer::InitCpuConfig() {
this->target_data_type_ = kNumberTypeInt8;
activation_symmetry_ = true;
weight_symmetry_ = false;
}
void FullQuantQuantizer::InitQMinMax() {
if (this->target_data_type_ == kNumberTypeInt8) {
q_max_ = QuantMax(this->bit_num_, false); // 127
q_min_ = QuantMin(this->bit_num_, false, true); // -127
} else if (target_data_type_ == kNumberTypeUInt8) {
q_max_ = QuantMax(this->bit_num_, true); // 255
q_min_ = QuantMin(this->bit_num_, true, false); // 0
} else {
MS_LOG(ERROR) << "unsupported quant value type: " << target_data_type_;
}
}
int FullQuantQuantizer::MarkQuantNode() {
auto cnodes = funcGraph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto anode = cnode->cast<AnfNodePtr>();
if (anode == nullptr) {
MS_LOG(ERROR) << " cnode is null";
MS_LOG(ERROR) << cnode->fullname_with_scope() << " cnode is null";
return RET_NULL_PTR;
}
auto quant_strategy = std::make_unique<QuantStrategy>(flags.commonQuantParam.min_quant_weight_size,
flags.commonQuantParam.min_quant_weight_channel,
flags.commonQuantParam.skip_quant_node);
// Mark quantifiable nodes
if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anode)) {
auto is_support_op = quant_strategy->CanOpFullQuantized(anode);
auto is_skip_op = quant_strategy->IsSkipOp(anode);
if (is_support_op && !is_skip_op) {
auto ret = calibrator_->AddQuantizedOp(cnode);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Add Quantized Op failed.";
MS_LOG(ERROR) << cnode->fullname_with_scope() << " add quantized op failed.";
return ret;
}
}
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
if (primitive == nullptr) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is null";
continue;
return RET_ERROR;
}
auto quant_param_holder = GetCNodeQuantHolder(primitive);
MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
if (quant_param_holder == nullptr) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " quant_param_holder is null";
return RET_ERROR;
}
quant_param_holder->ClearInputOutputQuantParam();
}
return RET_OK;
}
int FullQuantQuantizer::PreProcess() {
switch (device_) {
case CPU:
InitCpuConfig();
break;
default:
MS_LOG(ERROR) << " Unsupported device " << device_;
return RET_ERROR;
break;
}
InitQMinMax();
calibrator_ =
std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags.fullQuantParam.activation_quant_method,
this->flags.dataPreProcessParam, activation_symmetry_);
MSLITE_CHECK_PTR(calibrator_);
auto ret = MarkQuantNode();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Mark quant node failed.";
return ret;
}
return RET_OK;
}
int FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
if (tensor_vec.empty()) {
@ -943,9 +970,10 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
}
auto tensor = before_inputs[0];
MS_ASSERT(tensor != nullptr);
// op can be skipped.
if (tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
return false;
MS_LOG(INFO) << "tensor type is " << tensor->data_type();
return true;
}
// do quantization: activation is always per layer quantized
std::vector<int8_t> quant_datas;
@ -999,9 +1027,10 @@ KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
}
auto tensor = afterOutputs[0];
MS_ASSERT(tensor != nullptr);
// op can be skipped.
if (tensor->data_type() != kNumberTypeInt8) {
MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
return false;
MS_LOG(INFO) << "tensor type is " << tensor->data_type();
return true;
}
const int8_t *tensor_data = static_cast<int8_t *>(tensor->data());
size_t elem_count = tensor->ElementsNum();

View File

@ -35,15 +35,23 @@
#include "tools/converter/preprocess/preprocess_param.h"
#include "tools/converter/quantizer/calibrator.h"
#include "tools/converter/quantizer/data_distribution.h"
#include "src/common/quant_utils.h"
namespace mindspore::lite::quant {
enum OperationType {
STORE,
FETCH,
};
enum QuantRuntimeDevice {
CPU,
KIRIN,
};
class FullQuantQuantizer : public Quantizer {
public:
FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8);
FullQuantQuantizer(FuncGraphPtr graph, int bit_num) : Quantizer(std::move(graph)), bit_num_(bit_num) {}
~FullQuantQuantizer() override;
int DoQuantize(FuncGraphPtr func_graph) override;
@ -82,9 +90,20 @@ class FullQuantQuantizer : public Quantizer {
KernelCallBack GetAfterCallBack(bool int8_op);
KernelCallBack GetInt8AfterCallBack();
KernelCallBack GetFloatAfterCallBack();
void InitQMinMax();
void InitCpuConfig();
int MarkQuantNode();
private:
// Config
TypeId target_data_type_{kNumberTypeInt8};
size_t bit_num_{8};
int q_max_{INT8_MAX};
int q_min_{INT8_MIN};
bool activation_symmetry_{true};
bool weight_symmetry_{true};
QuantRuntimeDevice device_ = CPU;
std::unique_ptr<Calibrator> calibrator_{nullptr};
session::LiteSession *fp32_session_{nullptr};
Model *fp32_model_{nullptr};
@ -96,10 +115,6 @@ class FullQuantQuantizer : public Quantizer {
std::map<std::string, std::vector<float>> op_bias_diff_map_; // only use by int8 model
std::mutex mutex_op_input_;
std::mutex mutex_op_output_;
size_t bit_num_;
int q_max_{INT8_MAX};
int q_min_{INT8_MIN};
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H

View File

@ -35,7 +35,7 @@ struct CommonQuantParam {
int min_quant_weight_channel = 16;
bool is_debug = false;
std::string debug_info_save_path;
std::set<std::string> skip_node;
std::set<std::string> skip_quant_node;
int thread_num = 4;
};

View File

@ -120,4 +120,11 @@ bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
}
return is_data_type_fp32;
}
bool QuantStrategy::IsSkipOp(const AnfNodePtr &input_node) {
if (skip_node_.find(input_node->fullname_with_scope()) == skip_node_.end()) {
return false;
}
return true;
}
} // namespace mindspore::lite::quant

View File

@ -17,22 +17,29 @@
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
#include <cstddef>
#include <utility>
#include <set>
#include <string>
#include "ir/anf.h"
namespace mindspore::lite::quant {
class QuantStrategy {
public:
QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel)
: min_quant_weight_size_(min_quant_weight_size), min_quant_weight_channel_(min_quant_weight_channel) {}
QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel, std::set<std::string> skip_node)
: min_quant_weight_size_(min_quant_weight_size),
min_quant_weight_channel_(min_quant_weight_channel),
skip_node_(std::move(skip_node)) {}
~QuantStrategy() = default;
static bool CanOpFullQuantized(const AnfNodePtr &node);
bool CanOpFullQuantized(const AnfNodePtr &node);
bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim);
bool IsSkipOp(const AnfNodePtr &input_node);
private:
size_t min_quant_weight_size_;
size_t min_quant_weight_channel_;
std::set<std::string> skip_node_;
};
} // namespace mindspore::lite::quant

View File

@ -31,7 +31,8 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &con
mixed_bit_init_scale_ = config.mixedBitWeightQuantParam.init_scale;
}
quant_strategy_ = std::make_unique<QuantStrategy>(config.commonQuantParam.min_quant_weight_size,
config.commonQuantParam.min_quant_weight_channel);
config.commonQuantParam.min_quant_weight_channel,
config.commonQuantParam.skip_quant_node);
// parse param for fixed bit quant.
if (!this->is_mixed_bit_) {
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;