!28185 support r1.1 weight quant

Merge pull request !28185 from yeyunpeng2020/quant
This commit is contained in:
i-robot 2021-12-27 01:19:41 +00:00 committed by Gitee
commit 78a4537cbb
3 changed files with 68 additions and 51 deletions

View File

@ -549,6 +549,20 @@ LiteModel *LiteImportFromPath(const char *model_path) {
return model;
}
// Reports whether a tensor's quantization parameters are all initialised.
//
// quant_params: flatbuffer vector of QuantParam tables taken from a tensor; may be nullptr.
// Returns true only when the vector is non-null, non-empty, and no non-null entry reports
// inited() == false. Null entries are skipped rather than treated as uninitialised.
bool LiteModel::CheckQuantAllInit(
  const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params) {
  // A missing or empty vector carries no quantization information, so it must not count as
  // "all initialised"; without the emptiness check the loop below would be vacuously true.
  if (quant_params == nullptr || quant_params->size() == 0) {
    return false;
  }
  for (size_t i = 0; i < quant_params->size(); i++) {
    const auto *quant_param = quant_params->Get(i);
    if (quant_param != nullptr && !quant_param->inited()) {
      return false;
    }
  }
  return true;
}
// Imports a model from `model_path` by delegating to LiteImportFromPath and handing the result
// back through the base Model pointer type.
Model *ImportFromPath(const char *model_path) {
  Model *imported = LiteImportFromPath(model_path);
  return imported;
}
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {

View File

@ -72,6 +72,57 @@ class LiteModel : public Model {
bool PrepareInnerTensors();
// Checks the initialisation state of a tensor's quantization parameters. Returns false when
// `quant_params` is nullptr or when any non-null entry reports inited() == false.
bool CheckQuantAllInit(const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params);
// Copies the quantization type of `c_node` into `node`, sets the node's device type for the
// supported schema versions, and maps the serialized quantization type onto the runtime
// categories (QUANT_ALL / QUANT_WEIGHT).
//
// meta_graph: graph that owns `c_node`; its version() string and input tensors are consulted to
//             detect weight-quantized nodes that are still marked QuantType_QUANT_NONE.
// c_node:     serialized node to read from.
// node:       node being populated. NOTE: it is deleted here when the quantization type is out of
//             range, so the caller must not use or free it after RET_ERROR is returned.
// Returns RET_OK on success, RET_ERROR when the quantization type is invalid.
template <typename T = schema::MetaGraph, typename U = schema::CNode>
int SetQuantType(const T &meta_graph, const U *c_node, Node *node) {
  node->quant_type_ = c_node->quantType();
  if (node->quant_type_ < schema::QuantType_MIN || node->quant_type_ > schema::QuantType_MAX) {
    MS_LOG(ERROR) << "node->quant_type_:" << node->quant_type_ << " is invalid.";
    delete node;
    return RET_ERROR;
  }
  if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
    SetNodeDeviceType(node, *c_node);
  }
#ifdef ENABLE_V0
  if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
    SetNodeDeviceType(node, *c_node);
  }
#endif
  // True when `version` denotes a release older than 1.3.0. The leading "major.minor.patch"
  // components are compared numerically, because a plain string comparison orders "1.10.0"
  // before "1.3.0". Strings that do not start with three dotted numeric components keep the
  // legacy lexicographic comparison so their classification is unchanged.
  const auto is_before_1_3_0 = [](const std::string &version) -> bool {
    const int baseline[] = {1, 3, 0};
    const int component_cap = 100000;  // saturate absurdly long digit runs instead of overflowing
    size_t pos = 0;
    for (const int expected : baseline) {
      if (pos >= version.size() || version[pos] < '0' || version[pos] > '9') {
        return version < "1.3.0";
      }
      int component = 0;
      while (pos < version.size() && version[pos] >= '0' && version[pos] <= '9') {
        if (component < component_cap) {
          component = component * 10 + (version[pos] - '0');
        }
        ++pos;
      }
      if (component != expected) {
        return component < expected;
      }
      if (pos < version.size() && version[pos] == '.') {
        ++pos;
      }
    }
    return false;  // exactly 1.3.0 (possibly with a suffix) is not "before" 1.3.0
  };
  // Old models may carry weight-quantized tensors while the node is still marked QUANT_NONE.
  bool old_version_weight_quant =
    ((meta_graph.version() == nullptr || is_before_1_3_0(meta_graph.version()->str())) &&
     node->quant_type_ == schema::QuantType_QUANT_NONE && CheckNeedWeightQuant(meta_graph, c_node->inputIndex()));
  if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
    node->quant_type_ = schema::QuantType_QUANT_ALL;
  } else if (node->quant_type_ == schema::QuantType_WeightQuant || old_version_weight_quant) {
    node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
  }
  return RET_OK;
}
// Decides whether a node should be treated as weight-quantized, based on the quantization
// parameters attached to its input tensors.
//
// meta_graph:      graph whose allTensors() table the indices refer to.
// in_tensor_index: indices of the node's input tensors; may be nullptr.
// Returns true when the node has at least two inputs, at least one input has fully initialised
// quantization parameters, and no input without data has them (that combination is logged as a
// possible full-quantization case). Returns false for missing tables, out-of-range indices or
// null tensors, since nothing can be concluded from a malformed graph.
template <typename T>
bool CheckNeedWeightQuant(const T &meta_graph, const flatbuffers::Vector<uint32_t> *in_tensor_index) {
  const size_t min_quant_size = 2;
  if (in_tensor_index == nullptr || in_tensor_index->size() < min_quant_size) {
    return false;
  }
  const auto *all_tensors = meta_graph.allTensors();
  if (all_tensors == nullptr) {
    return false;
  }
  bool global_init_flag = false;
  for (flatbuffers::uoffset_t i = 0; i < in_tensor_index->size(); i++) {
    const auto index = static_cast<size_t>(in_tensor_index->Get(i));
    // The index comes straight from the model file, so validate it before using it.
    if (index >= all_tensors->size()) {
      MS_LOG(ERROR) << "input tensor index " << index << " is out of range, tensor count: " << all_tensors->size();
      return false;
    }
    const auto *tensor = all_tensors->template GetAs<schema::Tensor>(static_cast<flatbuffers::uoffset_t>(index));
    if (tensor == nullptr) {
      return false;
    }
    bool cur_tensor_init_flag = CheckQuantAllInit(tensor->quantParams());
    global_init_flag = global_init_flag || cur_tensor_init_flag;
    if (tensor->data() == nullptr && cur_tensor_init_flag) {
      // The tensor name is optional in the schema; never dereference a null name for logging.
      MS_LOG(DEBUG) << (tensor->name() != nullptr ? tensor->name()->c_str() : "")
                    << " is a non-const tensor, but there are quantization parameters, which may belong to full "
                       "quantization.";
      return false;
    }
  }
  return global_init_flag;
}
template <typename T = schema::MetaGraph, typename U = schema::CNode>
bool ConvertNodes(const T &meta_graph) {
MS_CHECK_TRUE_MSG(meta_graph.nodes() != nullptr, false, "meta_graph is invalid, please check your model file.");
@ -105,25 +156,10 @@ class LiteModel : public Model {
#else
node->primitive_ = c_node->primitive();
#endif
node->quant_type_ = c_node->quantType();
if (node->quant_type_ < schema::QuantType_MIN || node->quant_type_ > schema::QuantType_MAX) {
MS_LOG(ERROR) << "node->quant_type_:" << node->quant_type_ << " is invalid.";
delete node;
auto status = SetQuantType(meta_graph, c_node, node);
if (status == RET_ERROR) {
return false;
}
if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
SetNodeDeviceType(node, *c_node);
}
#ifdef ENABLE_V0
if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
SetNodeDeviceType(node, *c_node);
}
#endif
if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
node->quant_type_ = schema::QuantType_QUANT_ALL;
} else if (node->quant_type_ == schema::QuantType_WeightQuant) {
node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
}
if (c_node->name() == nullptr) {
node->name_ = "";
} else {

View File

@ -23,39 +23,6 @@ namespace mindspore::lite {
namespace {
constexpr int kBit8 = 8;
constexpr int kBit32 = 32;
bool HasInitQuantParam(const std::vector<LiteQuantParam> &quant_params) {
if (quant_params.empty()) {
return false;
}
if (std::all_of(quant_params.cbegin(), quant_params.cend(),
[](const LiteQuantParam &quant_param) { return quant_param.inited; })) {
return true;
}
return false;
}
// Tells whether the inputs of an operator need weight dequantization.
//
// op_parameter: operator parameter whose quant_type_ is inspected.
// in_tensors:   input tensors of the operator.
// Returns true for QuantType_QUANT_WEIGHT. For QuantType_QUANT_NONE (kept for r1.1 models) it
// returns true when there are at least two inputs and no non-const input has initialised
// quantization parameters. Returns false for every other quantization type.
bool CheckNeedWeightQuant(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors) {
  const auto quant_type = op_parameter->quant_type_;
  if (quant_type == schema::QuantType_QUANT_WEIGHT) {
    return true;
  }
  if (quant_type != schema::QuantType_QUANT_NONE) {
    return false;
  }
  // compatible with r1.1
  const size_t required_inputs = 2;
  if (in_tensors.size() < required_inputs) {
    return false;
  }
  for (size_t pos = 0; pos < in_tensors.size(); ++pos) {
    auto *input = in_tensors[pos];
    const bool is_const = input->IsConst();
    if (is_const) {
      continue;
    }
    if (HasInitQuantParam(input->quant_params())) {
      MS_LOG(DEBUG) << input->tensor_name()
                    << " is a non-const tensor, but there are quantization parameters, which may belong to full "
                       "quantization.";
      return false;
    }
  }
  return true;
}
} // namespace
std::vector<bool> StringToBitVector(const std::string &str) {
std::vector<bool> vec(str.size() * kBit8);
@ -390,7 +357,7 @@ int WeightDecoder::UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *d
int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
TypeId dst_data_type) {
if (!CheckNeedWeightQuant(op_parameter, in_tensors)) {
if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
return RET_OK;
}
int index = 0;