forked from mindspore-Ecosystem/mindspore
!28185 support r1.1 weight quant
Merge pull request !28185 from yeyunpeng2020/quant
This commit is contained in:
commit
78a4537cbb
|
@ -549,6 +549,20 @@ LiteModel *LiteImportFromPath(const char *model_path) {
|
|||
return model;
|
||||
}
|
||||
|
||||
bool LiteModel::CheckQuantAllInit(
|
||||
const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params) {
|
||||
if (quant_params == nullptr) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < quant_params->size(); i++) {
|
||||
auto quant_param = quant_params->Get(i);
|
||||
if (quant_param != nullptr && quant_param->inited() == false) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Thin public entry point: forwards file-path model import to the lite loader.
Model *ImportFromPath(const char *model_path) {
  return LiteImportFromPath(model_path);
}
|
||||
|
||||
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
|
||||
|
|
|
@ -72,6 +72,57 @@ class LiteModel : public Model {
|
|||
|
||||
bool PrepareInnerTensors();
|
||||
|
||||
bool CheckQuantAllInit(const flatbuffers::Vector<flatbuffers::Offset<mindspore::schema::QuantParam>> *quant_params);
|
||||
|
||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||
int SetQuantType(const T &meta_graph, const U *c_node, Node *node) {
|
||||
node->quant_type_ = c_node->quantType();
|
||||
if (node->quant_type_ < schema::QuantType_MIN || node->quant_type_ > schema::QuantType_MAX) {
|
||||
MS_LOG(ERROR) << "node->quant_type_:" << node->quant_type_ << " is invalid.";
|
||||
delete node;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
|
||||
SetNodeDeviceType(node, *c_node);
|
||||
}
|
||||
#ifdef ENABLE_V0
|
||||
if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
|
||||
SetNodeDeviceType(node, *c_node);
|
||||
}
|
||||
#endif
|
||||
bool old_version_weight_quant =
|
||||
((meta_graph.version() == nullptr || meta_graph.version()->str() < "1.3.0") &&
|
||||
node->quant_type_ == schema::QuantType_QUANT_NONE && CheckNeedWeightQuant(meta_graph, c_node->inputIndex()));
|
||||
if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
|
||||
node->quant_type_ = schema::QuantType_QUANT_ALL;
|
||||
} else if (node->quant_type_ == schema::QuantType_WeightQuant || old_version_weight_quant) {
|
||||
node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CheckNeedWeightQuant(const T &meta_graph, const flatbuffers::Vector<uint32_t> *in_tensor_index) {
|
||||
const size_t min_quant_size = 2;
|
||||
if (in_tensor_index->size() < min_quant_size) {
|
||||
return false;
|
||||
}
|
||||
bool global_init_flag = false;
|
||||
for (size_t i = 0; i < in_tensor_index->size(); i++) {
|
||||
auto index = size_t(in_tensor_index->template GetAs<uint32_t>(i));
|
||||
auto tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(index);
|
||||
bool cur_tensor_init_flag = CheckQuantAllInit(tensor->quantParams());
|
||||
global_init_flag = global_init_flag || cur_tensor_init_flag;
|
||||
if (tensor->data() == nullptr && cur_tensor_init_flag) {
|
||||
MS_LOG(DEBUG) << tensor->name()->c_str()
|
||||
<< " is a non-const tensor, but there are quantization parameters, which may belong to full "
|
||||
"quantization.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return global_init_flag;
|
||||
}
|
||||
|
||||
template <typename T = schema::MetaGraph, typename U = schema::CNode>
|
||||
bool ConvertNodes(const T &meta_graph) {
|
||||
MS_CHECK_TRUE_MSG(meta_graph.nodes() != nullptr, false, "meta_graph is invalid, please check your model file.");
|
||||
|
@ -105,25 +156,10 @@ class LiteModel : public Model {
|
|||
#else
|
||||
node->primitive_ = c_node->primitive();
|
||||
#endif
|
||||
node->quant_type_ = c_node->quantType();
|
||||
if (node->quant_type_ < schema::QuantType_MIN || node->quant_type_ > schema::QuantType_MAX) {
|
||||
MS_LOG(ERROR) << "node->quant_type_:" << node->quant_type_ << " is invalid.";
|
||||
delete node;
|
||||
auto status = SetQuantType(meta_graph, c_node, node);
|
||||
if (status == RET_ERROR) {
|
||||
return false;
|
||||
}
|
||||
if (schema_version_ == SCHEMA_VERSION::SCHEMA_CUR) {
|
||||
SetNodeDeviceType(node, *c_node);
|
||||
}
|
||||
#ifdef ENABLE_V0
|
||||
if (schema_version_ == SCHEMA_VERSION::SCHEMA_V0) {
|
||||
SetNodeDeviceType(node, *c_node);
|
||||
}
|
||||
#endif
|
||||
if (node->quant_type_ == schema::QuantType_PostTraining || node->quant_type_ == schema::QuantType_AwareTraining) {
|
||||
node->quant_type_ = schema::QuantType_QUANT_ALL;
|
||||
} else if (node->quant_type_ == schema::QuantType_WeightQuant) {
|
||||
node->quant_type_ = schema::QuantType_QUANT_WEIGHT;
|
||||
}
|
||||
if (c_node->name() == nullptr) {
|
||||
node->name_ = "";
|
||||
} else {
|
||||
|
|
|
@ -23,39 +23,6 @@ namespace mindspore::lite {
|
|||
namespace {
|
||||
// Bits per byte; used e.g. by StringToBitVector to size the unpacked bit vector.
constexpr int kBit8 = 8;
// 32-bit element width — presumably for 32-bit (float/int) unpacking; its use
// is outside this view, confirm against the rest of the file.
constexpr int kBit32 = 32;
|
||||
bool HasInitQuantParam(const std::vector<LiteQuantParam> &quant_params) {
|
||||
if (quant_params.empty()) {
|
||||
return false;
|
||||
}
|
||||
if (std::all_of(quant_params.cbegin(), quant_params.cend(),
|
||||
[](const LiteQuantParam &quant_param) { return quant_param.inited; })) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Decides at runtime whether an op's weights must be dequantized.
// Explicitly weight-quantized ops always qualify; for models converted before
// r1.3 the op is untagged (QUANT_NONE), so fall back to inspecting the input
// tensors. Guard clauses protect against null op_parameter / tensor pointers.
bool CheckNeedWeightQuant(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors) {
  if (op_parameter == nullptr) {
    return false;
  }
  if (op_parameter->quant_type_ == schema::QuantType_QUANT_WEIGHT) {
    return true;
  }
  // compatible with r1.1
  if (op_parameter->quant_type_ != schema::QuantType_QUANT_NONE) {
    return false;
  }
  const size_t min_quant_size = 2;  // need at least activation + weight inputs
  if (in_tensors.size() < min_quant_size) {
    return false;
  }
  for (auto tensor : in_tensors) {
    if (tensor == nullptr) {
      return false;
    }
    // Initialized quant params on a non-const tensor indicate full
    // quantization, which this weight-only path must not touch.
    if (!tensor->IsConst() && HasInitQuantParam(tensor->quant_params())) {
      MS_LOG(DEBUG) << tensor->tensor_name()
                    << " is a non-const tensor, but there are quantization parameters, which may belong to full "
                       "quantization.";
      return false;
    }
  }
  return true;
}
|
||||
} // namespace
|
||||
std::vector<bool> StringToBitVector(const std::string &str) {
|
||||
std::vector<bool> vec(str.size() * kBit8);
|
||||
|
@ -390,7 +357,7 @@ int WeightDecoder::UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *d
|
|||
|
||||
int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
|
||||
TypeId dst_data_type) {
|
||||
if (!CheckNeedWeightQuant(op_parameter, in_tensors)) {
|
||||
if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
|
||||
return RET_OK;
|
||||
}
|
||||
int index = 0;
|
||||
|
|
Loading…
Reference in New Issue