!29945 fix gather preference axis bug

Merge pull request !29945 from yeyunpeng2020/dynamic_quant_success
This commit is contained in:
i-robot 2022-02-14 01:30:56 +00:00 committed by Gitee
commit 33bc6978a0
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
10 changed files with 83 additions and 41 deletions

View File

@ -19,6 +19,7 @@
#include "src/huffman_decode.h"
#include "tools/converter/quantizer/fse_decoder.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/gather_parameter.h"
namespace mindspore::lite {
namespace {
@ -430,6 +431,12 @@ int WeightDecoder::GetDeConvPreferredDim(const OpParameter *op_parameter, const
}
}
int WeightDecoder::GetGatherPreferredDim(const OpParameter *op_parameter) {
MS_ASSERT(op_parameter != nullptr);
const auto *param = reinterpret_cast<const GatherParameter *>(op_parameter);
return param->axis_;
}
bool IsChannelFirst(int index, const OpParameter *op_parameter) {
MS_ASSERT(op_parameter != nullptr);
if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
@ -443,7 +450,7 @@ bool IsChannelFirst(int index, const OpParameter *op_parameter) {
return true;
}
int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
int WeightDecoder::GetPreferredDim(const OpParameter *op_parameter, int index, const std::vector<int> &dims,
const std::string &model_version) {
const int first_version_offset = 5;
if (model_version.empty() ||
@ -454,6 +461,8 @@ int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const s
return GetMatMulPreferredDim(op_parameter, index, dims);
} else if (op_parameter->type_ == schema::PrimitiveType_Conv2dTransposeFusion) {
return 0;
} else if (op_parameter->type_ == schema::PrimitiveType_Gather) {
return GetGatherPreferredDim(op_parameter);
}
// The first index.
return 0;

View File

@ -137,7 +137,7 @@ class WeightDecoder {
static int UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor);
static int GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
static int GetPreferredDim(const OpParameter *op_parameter, int index, const std::vector<int> &dims,
const std::string &model_version);
template <typename ST, typename DT = float>
@ -244,6 +244,7 @@ class WeightDecoder {
static int GetMatMulPreferredDim(const OpParameter *op_parameter, int input_index, const std::vector<int> &dims);
static int GetDeConvPreferredDim(const OpParameter *op_parameter, const std::vector<int> &dims);
static int GetGatherPreferredDim(const OpParameter *op_parameter);
static int DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type = kNumberTypeFloat32);
template <typename T1, typename T2>

View File

@ -54,7 +54,7 @@ std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT
int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT) {
MS_ASSERT(graph_defT != nullptr);
// quantization
if (ctx.commonQuantParam.quant_type != schema::QuantType_QUANT_ALL) {
if (ctx.commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
{
// quantization
// init old node indices

View File

@ -25,7 +25,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
flags_.commonQuantParam.min_quant_weight_size = 0;
auto quantizer = WeightQuantizer(flags_);
const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
const std::set<PrimitivePtr> symmetric_nodes = {};
const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
auto ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Weight Quant failed.";

View File

@ -81,8 +81,9 @@ int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const s
return RET_OK;
}
int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const PrimitivePtr &primitive,
bool per_channel, int input_index) const {
int FullQuantQuantizer::DoParameterWeightQuant(const CNodePtr &cnode, const ParameterPtr &weight,
const PrimitivePtr &primitive, bool per_channel, int input_index) const {
CHECK_NULL_RETURN(cnode);
CHECK_NULL_RETURN(weight);
CHECK_NULL_RETURN(primitive);
auto tensor_info = weight->default_param()->cast<tensor::TensorPtr>();
@ -90,10 +91,12 @@ int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
return RET_NULL_PTR;
}
int preferred_dim =
GetPreferredDim(cnode, primitive, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
auto status =
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_,
weight_q_min_, bit_num_, weight_quant_type, kNumberTypeInt8,
input_index - 1, preferred_dim, weight_symmetry_, true);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;
@ -171,13 +174,13 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const Parame
return ret;
}
} else if (CheckNodeInSet(cnode, per_channel_ops_)) {
ret = DoParameterWeightQuant(input_node, primitive, true, input_index);
ret = DoParameterWeightQuant(cnode, input_node, primitive, true, input_index);
if (ret != RET_OK) {
MS_LOG(ERROR) << op_name << " Do bias quant failed.";
return ret;
}
} else {
ret = DoParameterWeightQuant(input_node, primitive, false, input_index);
ret = DoParameterWeightQuant(cnode, input_node, primitive, false, input_index);
if (ret != RET_OK) {
MS_LOG(ERROR) << op_name << " Do bias quant failed.";
return ret;

View File

@ -57,8 +57,8 @@ class FullQuantQuantizer : public Quantizer {
int QuantNode(const FuncGraphPtr &func_graph);
int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
const PrimitivePtr &primitive, bool is_input, size_t index) const;
int DoParameterWeightQuant(const ParameterPtr &weight, const PrimitivePtr &primitive, bool per_channel,
int input_index) const;
int DoParameterWeightQuant(const CNodePtr &cnode, const ParameterPtr &weight, const PrimitivePtr &primitive,
bool per_channel, int input_index) const;
int DoValueNodeWeightQuant(const ValueNodePtr &weight, const PrimitivePtr &primitive, bool per_channel,
int input_index) const;
int DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index);

View File

@ -53,7 +53,7 @@ class InsertQuantNodeManager {
private:
TypeId dst_type_ = kNumberTypeInt8;
bool symmetric_ = true;
bool symmetric_ = false;
};
} // namespace mindspore::lite::quant
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H

View File

@ -24,6 +24,7 @@
#include "include/version.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "ops/fusion/conv2d_transpose_fusion.h"
#include "ops/gather.h"
#include "tools/converter/ops/ops_def.h"
#include "tools/anf_exporter/anf_exporter.h"
#include "tools/converter/quantizer/bitpacking.h"
@ -32,6 +33,7 @@
#include "abstract/abstract_value.h"
#include "securec/include/securec.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
using std::string;
using std::vector;
@ -45,6 +47,7 @@ constexpr int kSingleDirBiasTensorSize = 4;
constexpr int kLstmBiasShapeSize = 2;
constexpr int kLstmBiasIndex = 3;
constexpr size_t kBitNumPerByte = 8;
constexpr size_t kGatherAxisIndex = 3;
int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
@ -414,10 +417,30 @@ int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int>
return 0;
}
int GetGatherPreferredDim(const CNodePtr &cnode) {
if (cnode->size() < kGatherAxisIndex + 1) {
MS_LOG(WARNING) << "gather cnode size < 4.";
return 0;
}
auto axis = cnode->input(kGatherAxisIndex);
tensor::TensorPtr tensor_info;
ParameterPtr parameter;
GetLiteParameter(axis, &parameter, &tensor_info);
size_t elem_count = tensor_info->DataSize();
if (elem_count != 1) {
MS_LOG(WARNING) << "gather axis data elem_count" << elem_count << " != 1.";
return 0;
} else {
auto *axis_data = static_cast<int *>(tensor_info->data_c());
return axis_data[0];
}
return 0;
}
int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first) {
auto channels = dims[0];
if (!(*channel_at_first)) {
if (dims.size() != 2) {
if (dims.size() != DIMENSION_2D) {
MS_LOG(WARNING) << "unexpected dims size: " << dims.size();
*channel_at_first = true;
} else {
@ -429,11 +452,14 @@ int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_
return channels;
}
int GetPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims) {
int GetPreferredDim(const CNodePtr &cnode, const PrimitivePtr &primitive, int input_index,
const std::vector<int> &dims) {
if (primitive->name() == ops::kNameMatMulFusion) {
return GetMatMulPreferredDim(primitive, input_index, dims);
} else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
return 0;
} else if (primitive->name() == ops::kNameGather) {
return GetGatherPreferredDim(cnode);
}
// The first index.
return 0;
@ -485,9 +511,9 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
}
}
int MixedBitQuantFilter(const AnfNodePtr &node, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
QuantType quant_type, WeightQuantType weight_quant_type, TypeId quant_data_type,
double init_scale, int index) {
int MixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetry) {
MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
auto dims = weight->shape();
@ -517,17 +543,17 @@ int MixedBitQuantFilter(const AnfNodePtr &node, const tensor::TensorPtr &weight,
const int quant_min = QuantMin(k8Bit, false, false); // -128
const int quant_max = QuantMax(k8Bit); // 127
MS_LOG(WARNING)
<< node->fullname_with_scope()
<< parameter_node->fullname_with_scope()
<< " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization.";
return FixedBitQuantFilter<int8_t>(node, weight, primitive, QuantType_QUANT_WEIGHT, quant_max, quant_min, k8Bit,
FIXED_BIT_PER_CHANNEL, kNumberTypeInt8, index);
return FixedBitQuantFilter<int8_t>(parameter_node, weight, primitive, QuantType_QUANT_WEIGHT, quant_max, quant_min,
k8Bit, FIXED_BIT_PER_CHANNEL, kNumberTypeInt8, index, preferred_dim, symmetry);
}
if (ret != RET_OK) {
return ret;
}
auto status =
UpdateTensorDataAndSize(node, weight, quant_data.data(), quant_data.size() * sizeof(int16_t), quant_data_type);
auto status = UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(int16_t),
quant_data_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
return RET_ERROR;

View File

@ -86,13 +86,14 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
bool TensorQuantParamsInited(const schema::TensorT &tensor);
int MixedBitQuantFilter(const AnfNodePtr &node, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
QuantType quant_type, WeightQuantType weight_quant_type, TypeId quant_data_type,
double init_scale, int index);
int MixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetry);
int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first);
int GetPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims);
int GetPreferredDim(const CNodePtr &cnode, const PrimitivePtr &primitive, int input_index,
const std::vector<int> &dims);
std::vector<int> ConvertShapeVectorToInt32(const ShapeVector &dims);
@ -121,10 +122,10 @@ int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<lit
}
template <typename T>
int FixedBitQuantFilter(const AnfNodePtr &parameter, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
QuantType quant_type, int quant_max, int quant_min, size_t bit_num,
WeightQuantType weight_quant_type, TypeId quant_data_type, int index, bool symmetry = false,
bool narrow_range = false, bool k_means = false) {
int FixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
const PrimitivePtr &primitive, QuantType quant_type, int quant_max, int quant_min,
size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type, int index,
int preferred_dim, bool symmetry = false, bool narrow_range = false, bool k_means = false) {
MS_ASSERT(weight != nullptr);
MS_ASSERT(primitive != nullptr);
auto dims = weight->shape();
@ -146,7 +147,6 @@ int FixedBitQuantFilter(const AnfNodePtr &parameter, const tensor::TensorPtr &we
std::vector<T> quant_data(elem_count);
int ret = RET_OK;
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
int preferred_dim = GetPreferredDim(primitive, index, ConvertShapeVectorToInt32(dims));
ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(),
static_cast<mindspore::schema::QuantType>(quant_type), &quant_params, quant_max,
quant_min, bit_num, &quant_data, ConvertShapeVectorToInt32(dims), preferred_dim,
@ -166,9 +166,10 @@ int FixedBitQuantFilter(const AnfNodePtr &parameter, const tensor::TensorPtr &we
}
} else {
MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
return RET_ERROR;
}
auto status =
UpdateTensorDataAndSize(parameter, weight, quant_data.data(), quant_data.size() * sizeof(T), quant_data_type);
UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(T), quant_data_type);
if (status != RET_OK) {
MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
return RET_ERROR;

View File

@ -100,11 +100,11 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
continue;
}
int preferred_dim = GetPreferredDim(primitive, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
flags_.commonQuantParam.min_quant_weight_channel,
flags_.commonQuantParam.skip_quant_node);
CHECK_NULL_RETURN(quant_strategy);
int preferred_dim = GetPreferredDim(cnode, primitive, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
if (!quant_strategy->CanTensorQuantized(cnode, input, preferred_dim)) {
MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable";
continue;
@ -120,14 +120,16 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
auto status = RET_ERROR;
if (is_mixed_bit_) {
status = MixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1);
WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1,
preferred_dim, symmetric);
} else if (type_id_ == kNumberTypeInt8) {
status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
} else if (type_id_ == kNumberTypeInt16) {
status =
FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max, q_min,
bit_num_, tmp_weight_quant_type, type_id_, idx - 1, preferred_dim, symmetric);
} else if (type_id_ == kNumberTypeInt16) {
status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
q_max, q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1,
preferred_dim, symmetric);
}
if (status == RET_NO_CHANGE) {
continue;