forked from mindspore-Ecosystem/mindspore
!26723 full quant support skip node
Merge pull request !26723 from yeyunpeng2020/quant_bak2
commit d461433648

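This change renames the converter option skip_node to skip_quant_node and threads it through both full quantization (via the new QuantStrategy::IsSkipOp) and weight quantization, so individual ops can be excluded from quantization by name. A rough illustration of how the option would be set in a converter config file (the [common_quant_param] section name is inferred from kCommonQuantParam in the parser below; node names and surrounding values are placeholders, not from this commit):

    [common_quant_param]
    bit_num=8
    min_quant_weight_size=0
    min_quant_weight_channel=16
    # comma-separated full node names to keep un-quantized (illustrative)
    skip_quant_node=conv2d_1,fc_2
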
@@ -112,7 +112,7 @@ int ConfigFileParser::ParseCommonQuantString(const std::map<std::string, std::ma
     {"bit_num", common_quant_string_.bit_num},
     {"min_quant_weight_size", common_quant_string_.min_quant_weight_size},
     {"min_quant_weight_channel", common_quant_string_.min_quant_weight_channel},
-    {"skip_node", common_quant_string_.skip_node},
+    {"skip_quant_node", common_quant_string_.skip_quant_node},
     {"debug_info_save_path", common_quant_string_.debug_info_save_path},
   };
   return SetMapData(map, parse_map, kCommonQuantParam);

@@ -41,7 +41,7 @@ struct CommonQuantString {
   std::string bit_num;
   std::string min_quant_weight_size;
   std::string min_quant_weight_channel;
-  std::string skip_node;
+  std::string skip_quant_node;
   std::string debug_info_save_path;
 };

@@ -49,10 +49,10 @@ int QuantParamParser::ParseFilter(const CommonQuantString &common_quant_string,
       return RET_INPUT_PARAM_INVALID;
     }
   }
-  if (!common_quant_string.skip_node.empty()) {
-    std::vector<std::string> nodes = SplitStringToVector(common_quant_string.skip_node, ',');
+  if (!common_quant_string.skip_quant_node.empty()) {
+    std::vector<std::string> nodes = SplitStringToVector(common_quant_string.skip_quant_node, ',');
     for (const auto &node : nodes) {
-      common_quant->skip_node.insert(node);
+      common_quant->skip_quant_node.insert(node);
     }
   }
   return RET_OK;

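ParseFilter turns the comma-separated skip_quant_node string into a std::set. A minimal standalone sketch of that step, using std::getline in place of the project's SplitStringToVector (whose exact signature is assumed here):

    #include <set>
    #include <sstream>
    #include <string>

    // Split "conv2d_1,fc_2" on ',' and collect unique node names,
    // mirroring what ParseFilter does for skip_quant_node.
    std::set<std::string> ParseSkipQuantNode(const std::string &value) {
      std::set<std::string> skip;
      std::istringstream ss(value);
      std::string node;
      while (std::getline(ss, node, ',')) {
        if (!node.empty()) skip.insert(node);
      }
      return skip;
    }
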
@@ -180,7 +180,7 @@ float DataDistribution::CalculateScale(float min_value, float max_value) {
   this->encode_max_ = max_value;
   // Optimize Handle 0.
   MS_ASSERT(quant_max_ - quant_min_ > 0);
-  return (max_value - min_value) / (quant_max_ - quant_min_);
+  return (encode_max_ - encode_min_) / (quant_max_ - quant_min_);
 }

 float DataDistribution::CalculateKLScale() { return CalculateScale(this->best_T_, this->real_max_); }

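The fix computes the scale from the stored encode range (encode_min_/encode_max_) rather than from the raw arguments, presumably so that any adjustment applied to the encode range (such as the zero handling the comment refers to) is reflected in the returned scale. A quick worked check with illustrative values, not from the source:

    #include <cstdio>

    int main() {
      // Illustrative calibrated range and the symmetric int8 limits used by the CPU config.
      float encode_min = -2.54f, encode_max = 2.54f;
      int quant_min = -127, quant_max = 127;
      float scale = (encode_max - encode_min) / (quant_max - quant_min);
      std::printf("scale = %f\n", scale);  // (2.54 - -2.54) / 254 = 0.02
      return 0;
    }
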
@@ -40,6 +40,7 @@
 #include "src/common/utils.h"
 #include "tools/common/node_util.h"
 #include "nnacl/op_base.h"
+#include "src/common/log_util.h"

 using std::string;
 using std::vector;

@@ -126,22 +127,6 @@ int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const s
 }
 }  // namespace

-FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type)
-    : Quantizer(std::move(graph)) {
-  MS_ASSERT(graph != nullptr);
-  this->bit_num_ = bit_num;
-  this->target_data_type_ = target_type;
-  if (target_type == kNumberTypeInt8) {
-    q_max_ = (1 << (this->bit_num_ - 1)) - 1;  // 127
-    q_min_ = -q_max_;                          // -127
-  } else if (target_type == kNumberTypeUInt8) {
-    q_max_ = (1 << this->bit_num_) - 1;  // 255
-    q_min_ = 0;
-  } else {
-    MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
-  }
-}
-
 FullQuantQuantizer::~FullQuantQuantizer() {
   delete fp32_session_;
   delete fp32_model_;

@@ -507,41 +492,83 @@ int FullQuantQuantizer::UpdateDivergeInterval() {
   return RET_OK;
 }

-int FullQuantQuantizer::PreProcess() {
-  calibrator_ =
-    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags.fullQuantParam.activation_quant_method,
-                                 this->flags.dataPreProcessParam);
-  if (calibrator_ == nullptr) {
-    MS_LOG(ERROR) << "create calibrator failed!";
-    return RET_NULL_PTR;
-  }
+void FullQuantQuantizer::InitCpuConfig() {
+  this->target_data_type_ = kNumberTypeInt8;
+  activation_symmetry_ = true;
+  weight_symmetry_ = false;
+}
+
+void FullQuantQuantizer::InitQMinMax() {
+  if (this->target_data_type_ == kNumberTypeInt8) {
+    q_max_ = QuantMax(this->bit_num_, false);        // 127
+    q_min_ = QuantMin(this->bit_num_, false, true);  // -127
+  } else if (target_data_type_ == kNumberTypeUInt8) {
+    q_max_ = QuantMax(this->bit_num_, true);         // 255
+    q_min_ = QuantMin(this->bit_num_, true, false);  // 0
+  } else {
+    MS_LOG(ERROR) << "unsupported quant value type: " << target_data_type_;
+  }
+}
+
+int FullQuantQuantizer::MarkQuantNode() {
   auto cnodes = funcGraph->GetOrderedCnodes();
   for (auto &cnode : cnodes) {
     auto anode = cnode->cast<AnfNodePtr>();
     if (anode == nullptr) {
-      MS_LOG(ERROR) << " cnode is null";
+      MS_LOG(ERROR) << cnode->fullname_with_scope() << " cnode is null";
       return RET_NULL_PTR;
     }
+    auto quant_strategy = std::make_unique<QuantStrategy>(flags.commonQuantParam.min_quant_weight_size,
+                                                          flags.commonQuantParam.min_quant_weight_channel,
+                                                          flags.commonQuantParam.skip_quant_node);
     // Mark quantifiable nodes
-    if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anode)) {
+    auto is_support_op = quant_strategy->CanOpFullQuantized(anode);
+    auto is_skip_op = quant_strategy->IsSkipOp(anode);
+    if (is_support_op && !is_skip_op) {
       auto ret = calibrator_->AddQuantizedOp(cnode);
       if (ret != RET_OK) {
-        MS_LOG(ERROR) << "Add Quantized Op failed.";
+        MS_LOG(ERROR) << cnode->fullname_with_scope() << " add quantized op failed.";
         return ret;
       }
     }
     auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
     if (primitive == nullptr) {
       MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is null";
-      continue;
+      return RET_ERROR;
     }
     auto quant_param_holder = GetCNodeQuantHolder(primitive);
-    MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
+    if (quant_param_holder == nullptr) {
+      MS_LOG(ERROR) << cnode->fullname_with_scope() << " quant_param_holder is null";
+      return RET_ERROR;
+    }
     quant_param_holder->ClearInputOutputQuantParam();
   }
   return RET_OK;
 }

+int FullQuantQuantizer::PreProcess() {
+  switch (device_) {
+    case CPU:
+      InitCpuConfig();
+      break;
+    default:
+      MS_LOG(ERROR) << " Unsupported device " << device_;
+      return RET_ERROR;
+      break;
+  }
+  InitQMinMax();
+  calibrator_ =
+    std::make_unique<Calibrator>(this->bit_num_, q_max_, q_min_, this->flags.fullQuantParam.activation_quant_method,
+                                 this->flags.dataPreProcessParam, activation_symmetry_);
+  MSLITE_CHECK_PTR(calibrator_);
+  auto ret = MarkQuantNode();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Mark quant node failed.";
+    return ret;
+  }
+  return RET_OK;
+}
+
 int FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
                                            const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
   if (tensor_vec.empty()) {

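InitQMinMax replaces the bit arithmetic that lived in the deleted constructor with QuantMax/QuantMin helpers (declared in src/common/quant_utils.h per the header hunk below). Their exact signatures are not shown in this diff; a plausible sketch consistent with the old inline code and the expected values in the comments:

    // Hypothetical helpers inferred from the removed constructor:
    // int8:  q_max = (1 << (bits - 1)) - 1 = 127, symmetric q_min = -q_max = -127
    // uint8: q_max = (1 << bits) - 1 = 255, q_min = 0
    int QuantMax(int bit_num, bool is_unsigned) {
      return is_unsigned ? (1 << bit_num) - 1 : (1 << (bit_num - 1)) - 1;
    }

    int QuantMin(int bit_num, bool is_unsigned, bool is_symmetric) {
      if (is_unsigned) {
        return 0;
      }
      int q_min = -(1 << (bit_num - 1));        // -128 for 8 bits
      return is_symmetric ? q_min + 1 : q_min;  // symmetric int8 uses -127
    }
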
@@ -943,9 +970,10 @@ KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
     }
     auto tensor = before_inputs[0];
     MS_ASSERT(tensor != nullptr);
+    // op can be skipped.
     if (tensor->data_type() != kNumberTypeInt8) {
-      MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
-      return false;
+      MS_LOG(INFO) << "tensor type is " << tensor->data_type();
+      return true;
     }
     // do quantization: activation is always per layer quantized
     std::vector<int8_t> quant_datas;

@@ -999,9 +1027,10 @@ KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
     }
     auto tensor = afterOutputs[0];
     MS_ASSERT(tensor != nullptr);
+    // op can be skipped.
     if (tensor->data_type() != kNumberTypeInt8) {
-      MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
-      return false;
+      MS_LOG(INFO) << "tensor type is " << tensor->data_type();
+      return true;
     }
     const int8_t *tensor_data = static_cast<int8_t *>(tensor->data());
     size_t elem_count = tensor->ElementsNum();

@@ -35,15 +35,23 @@
 #include "tools/converter/preprocess/preprocess_param.h"
 #include "tools/converter/quantizer/calibrator.h"
 #include "tools/converter/quantizer/data_distribution.h"
+#include "src/common/quant_utils.h"

 namespace mindspore::lite::quant {
 enum OperationType {
   STORE,
   FETCH,
 };
+
+enum QuantRuntimeDevice {
+  CPU,
+  KIRIN,
+};

 class FullQuantQuantizer : public Quantizer {
  public:
-  FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8);
+  FullQuantQuantizer(FuncGraphPtr graph, int bit_num) : Quantizer(std::move(graph)), bit_num_(bit_num) {}

   ~FullQuantQuantizer() override;

   int DoQuantize(FuncGraphPtr func_graph) override;

@@ -82,9 +90,20 @@ class FullQuantQuantizer : public Quantizer {
   KernelCallBack GetAfterCallBack(bool int8_op);
   KernelCallBack GetInt8AfterCallBack();
   KernelCallBack GetFloatAfterCallBack();
+  void InitQMinMax();
+  void InitCpuConfig();
+  int MarkQuantNode();

  private:
+  // Config
+  TypeId target_data_type_{kNumberTypeInt8};
+  size_t bit_num_{8};
+  int q_max_{INT8_MAX};
+  int q_min_{INT8_MIN};
+  bool activation_symmetry_{true};
+  bool weight_symmetry_{true};
+  QuantRuntimeDevice device_ = CPU;
+
   std::unique_ptr<Calibrator> calibrator_{nullptr};
   session::LiteSession *fp32_session_{nullptr};
   Model *fp32_model_{nullptr};

@@ -96,10 +115,6 @@ class FullQuantQuantizer : public Quantizer {
   std::map<std::string, std::vector<float>> op_bias_diff_map_;  // only use by int8 model
   std::mutex mutex_op_input_;
   std::mutex mutex_op_output_;
-
-  size_t bit_num_;
-  int q_max_{INT8_MAX};
-  int q_min_{INT8_MIN};
 };
 }  // namespace mindspore::lite::quant
 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H

@@ -35,7 +35,7 @@ struct CommonQuantParam {
   int min_quant_weight_channel = 16;
   bool is_debug = false;
   std::string debug_info_save_path;
-  std::set<std::string> skip_node;
+  std::set<std::string> skip_quant_node;
   int thread_num = 4;
 };

@@ -120,4 +120,11 @@ bool QuantStrategy::CanOpFullQuantized(const AnfNodePtr &node) {
   }
   return is_data_type_fp32;
 }
+
+bool QuantStrategy::IsSkipOp(const AnfNodePtr &input_node) {
+  if (skip_node_.find(input_node->fullname_with_scope()) == skip_node_.end()) {
+    return false;
+  }
+  return true;
+}
 }  // namespace mindspore::lite::quant

@@ -17,22 +17,29 @@
 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_QUANT_STRATEGY_H
 #include <cstddef>
+#include <utility>
+#include <set>
+#include <string>
 #include "ir/anf.h"

 namespace mindspore::lite::quant {
 class QuantStrategy {
  public:
-  QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel)
-      : min_quant_weight_size_(min_quant_weight_size), min_quant_weight_channel_(min_quant_weight_channel) {}
+  QuantStrategy(size_t min_quant_weight_size, size_t min_quant_weight_channel, std::set<std::string> skip_node)
+      : min_quant_weight_size_(min_quant_weight_size),
+        min_quant_weight_channel_(min_quant_weight_channel),
+        skip_node_(std::move(skip_node)) {}

   ~QuantStrategy() = default;

-  static bool CanOpFullQuantized(const AnfNodePtr &node);
+  bool CanOpFullQuantized(const AnfNodePtr &node);
   bool CanTensorQuantized(const AnfNodePtr &input_node, int preferred_dim);
+  bool IsSkipOp(const AnfNodePtr &input_node);

  private:
   size_t min_quant_weight_size_;
   size_t min_quant_weight_channel_;
+  std::set<std::string> skip_node_;
 };
 }  // namespace mindspore::lite::quant

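Putting the new header together, a hedged usage sketch mirroring MarkQuantNode above (it assumes AnfNodePtr and QuantStrategy are in scope; the node names and parameter values are invented):

    #include <memory>
    #include <set>
    #include <string>

    // Quantize only ops that are supported and not explicitly skipped.
    bool ShouldFullQuantize(const AnfNodePtr &anode) {
      std::set<std::string> skip = {"conv2d_1", "fc_2"};  // illustrative names
      auto strategy = std::make_unique<QuantStrategy>(/*min_quant_weight_size=*/0,
                                                      /*min_quant_weight_channel=*/16, skip);
      return strategy->CanOpFullQuantized(anode) && !strategy->IsSkipOp(anode);
    }
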
@@ -31,7 +31,8 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const converter::Flags &con
     mixed_bit_init_scale_ = config.mixedBitWeightQuantParam.init_scale;
   }
   quant_strategy_ = std::make_unique<QuantStrategy>(config.commonQuantParam.min_quant_weight_size,
-                                                    config.commonQuantParam.min_quant_weight_channel);
+                                                    config.commonQuantParam.min_quant_weight_channel,
+                                                    config.commonQuantParam.skip_quant_node);
   // parse param for fixed bit quant.
   if (!this->is_mixed_bit_) {
     quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;