!29945 fix gather preference axis bug
Merge pull request !29945 from yeyunpeng2020/dynamic_quant_success
This commit is contained in:
commit
33bc6978a0
|
@ -19,6 +19,7 @@
|
|||
#include "src/huffman_decode.h"
|
||||
#include "tools/converter/quantizer/fse_decoder.h"
|
||||
#include "nnacl/conv_parameter.h"
|
||||
#include "nnacl/gather_parameter.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
namespace {
|
||||
|
@ -430,6 +431,12 @@ int WeightDecoder::GetDeConvPreferredDim(const OpParameter *op_parameter, const
|
|||
}
|
||||
}
|
||||
|
||||
int WeightDecoder::GetGatherPreferredDim(const OpParameter *op_parameter) {
|
||||
MS_ASSERT(op_parameter != nullptr);
|
||||
const auto *param = reinterpret_cast<const GatherParameter *>(op_parameter);
|
||||
return param->axis_;
|
||||
}
|
||||
|
||||
bool IsChannelFirst(int index, const OpParameter *op_parameter) {
|
||||
MS_ASSERT(op_parameter != nullptr);
|
||||
if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
|
||||
|
@ -443,7 +450,7 @@ bool IsChannelFirst(int index, const OpParameter *op_parameter) {
|
|||
return true;
|
||||
}
|
||||
|
||||
int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
|
||||
int WeightDecoder::GetPreferredDim(const OpParameter *op_parameter, int index, const std::vector<int> &dims,
|
||||
const std::string &model_version) {
|
||||
const int first_version_offset = 5;
|
||||
if (model_version.empty() ||
|
||||
|
@ -454,6 +461,8 @@ int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const s
|
|||
return GetMatMulPreferredDim(op_parameter, index, dims);
|
||||
} else if (op_parameter->type_ == schema::PrimitiveType_Conv2dTransposeFusion) {
|
||||
return 0;
|
||||
} else if (op_parameter->type_ == schema::PrimitiveType_Gather) {
|
||||
return GetGatherPreferredDim(op_parameter);
|
||||
}
|
||||
// The first index.
|
||||
return 0;
|
||||
|
|
|
@ -137,7 +137,7 @@ class WeightDecoder {
|
|||
|
||||
static int UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor);
|
||||
|
||||
static int GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
|
||||
static int GetPreferredDim(const OpParameter *op_parameter, int index, const std::vector<int> &dims,
|
||||
const std::string &model_version);
|
||||
|
||||
template <typename ST, typename DT = float>
|
||||
|
@ -244,6 +244,7 @@ class WeightDecoder {
|
|||
|
||||
static int GetMatMulPreferredDim(const OpParameter *op_parameter, int input_index, const std::vector<int> &dims);
|
||||
static int GetDeConvPreferredDim(const OpParameter *op_parameter, const std::vector<int> &dims);
|
||||
static int GetGatherPreferredDim(const OpParameter *op_parameter);
|
||||
static int DequantWeight(lite::Tensor *input_tensor, int preferred_dim, TypeId dst_data_type = kNumberTypeFloat32);
|
||||
|
||||
template <typename T1, typename T2>
|
||||
|
|
|
@ -54,7 +54,7 @@ std::vector<schema::CNodeT *> GetGraphNodes(const schema::MetaGraphT &graph_defT
|
|||
int QuantTransform(const converter::Flags &ctx, schema::MetaGraphT *graph_defT) {
|
||||
MS_ASSERT(graph_defT != nullptr);
|
||||
// quantization
|
||||
if (ctx.commonQuantParam.quant_type != schema::QuantType_QUANT_ALL) {
|
||||
if (ctx.commonQuantParam.quant_type == schema::QuantType_QUANT_NONE) {
|
||||
{
|
||||
// quantization
|
||||
// init old node indices
|
||||
|
|
|
@ -25,7 +25,7 @@ int DynamicQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
flags_.commonQuantParam.min_quant_weight_size = 0;
|
||||
auto quantizer = WeightQuantizer(flags_);
|
||||
const std::set<PrimitivePtr> support_weight_quant_nodes = {prim::kPrimMatMulFusion, prim::kPrimGather};
|
||||
const std::set<PrimitivePtr> symmetric_nodes = {};
|
||||
const std::set<PrimitivePtr> symmetric_nodes = {prim::kPrimMatMulFusion};
|
||||
auto ret = quantizer.WeightQuant(func_graph, support_weight_quant_nodes, {}, symmetric_nodes);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Weight Quant failed.";
|
||||
|
|
|
@ -81,8 +81,9 @@ int FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const s
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const PrimitivePtr &primitive,
|
||||
bool per_channel, int input_index) const {
|
||||
int FullQuantQuantizer::DoParameterWeightQuant(const CNodePtr &cnode, const ParameterPtr &weight,
|
||||
const PrimitivePtr &primitive, bool per_channel, int input_index) const {
|
||||
CHECK_NULL_RETURN(cnode);
|
||||
CHECK_NULL_RETURN(weight);
|
||||
CHECK_NULL_RETURN(primitive);
|
||||
auto tensor_info = weight->default_param()->cast<tensor::TensorPtr>();
|
||||
|
@ -90,10 +91,12 @@ int FullQuantQuantizer::DoParameterWeightQuant(const ParameterPtr &weight, const
|
|||
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
int preferred_dim =
|
||||
GetPreferredDim(cnode, primitive, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
|
||||
auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
|
||||
auto status =
|
||||
FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_, weight_q_min_,
|
||||
bit_num_, weight_quant_type, kNumberTypeInt8, input_index - 1, weight_symmetry_, true);
|
||||
auto status = FixedBitQuantFilter<int8_t>(weight, tensor_info, primitive, QuantType_QUANT_ALL, weight_q_max_,
|
||||
weight_q_min_, bit_num_, weight_quant_type, kNumberTypeInt8,
|
||||
input_index - 1, preferred_dim, weight_symmetry_, true);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
||||
return status;
|
||||
|
@ -171,13 +174,13 @@ int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const Parame
|
|||
return ret;
|
||||
}
|
||||
} else if (CheckNodeInSet(cnode, per_channel_ops_)) {
|
||||
ret = DoParameterWeightQuant(input_node, primitive, true, input_index);
|
||||
ret = DoParameterWeightQuant(cnode, input_node, primitive, true, input_index);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << op_name << " Do bias quant failed.";
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
ret = DoParameterWeightQuant(input_node, primitive, false, input_index);
|
||||
ret = DoParameterWeightQuant(cnode, input_node, primitive, false, input_index);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << op_name << " Do bias quant failed.";
|
||||
return ret;
|
||||
|
|
|
@ -57,8 +57,8 @@ class FullQuantQuantizer : public Quantizer {
|
|||
int QuantNode(const FuncGraphPtr &func_graph);
|
||||
int SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info,
|
||||
const PrimitivePtr &primitive, bool is_input, size_t index) const;
|
||||
int DoParameterWeightQuant(const ParameterPtr &weight, const PrimitivePtr &primitive, bool per_channel,
|
||||
int input_index) const;
|
||||
int DoParameterWeightQuant(const CNodePtr &cnode, const ParameterPtr &weight, const PrimitivePtr &primitive,
|
||||
bool per_channel, int input_index) const;
|
||||
int DoValueNodeWeightQuant(const ValueNodePtr &weight, const PrimitivePtr &primitive, bool per_channel,
|
||||
int input_index) const;
|
||||
int DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPtr &input_node, size_t input_index);
|
||||
|
|
|
@ -53,7 +53,7 @@ class InsertQuantNodeManager {
|
|||
|
||||
private:
|
||||
TypeId dst_type_ = kNumberTypeInt8;
|
||||
bool symmetric_ = true;
|
||||
bool symmetric_ = false;
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_INSERT_QUANT_NODE_MANAGER_H
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "include/version.h"
|
||||
#include "ops/fusion/mat_mul_fusion.h"
|
||||
#include "ops/fusion/conv2d_transpose_fusion.h"
|
||||
#include "ops/gather.h"
|
||||
#include "tools/converter/ops/ops_def.h"
|
||||
#include "tools/anf_exporter/anf_exporter.h"
|
||||
#include "tools/converter/quantizer/bitpacking.h"
|
||||
|
@ -32,6 +33,7 @@
|
|||
#include "abstract/abstract_value.h"
|
||||
#include "securec/include/securec.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
@ -45,6 +47,7 @@ constexpr int kSingleDirBiasTensorSize = 4;
|
|||
constexpr int kLstmBiasShapeSize = 2;
|
||||
constexpr int kLstmBiasIndex = 3;
|
||||
constexpr size_t kBitNumPerByte = 8;
|
||||
constexpr size_t kGatherAxisIndex = 3;
|
||||
|
||||
int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
|
||||
const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
|
||||
|
@ -414,10 +417,30 @@ int GetDeConvPreferredDim(const PrimitivePtr &primitive, const std::vector<int>
|
|||
return 0;
|
||||
}
|
||||
|
||||
int GetGatherPreferredDim(const CNodePtr &cnode) {
|
||||
if (cnode->size() < kGatherAxisIndex + 1) {
|
||||
MS_LOG(WARNING) << "gather cnode size < 4.";
|
||||
return 0;
|
||||
}
|
||||
auto axis = cnode->input(kGatherAxisIndex);
|
||||
tensor::TensorPtr tensor_info;
|
||||
ParameterPtr parameter;
|
||||
GetLiteParameter(axis, ¶meter, &tensor_info);
|
||||
size_t elem_count = tensor_info->DataSize();
|
||||
if (elem_count != 1) {
|
||||
MS_LOG(WARNING) << "gather axis data elem_count" << elem_count << " != 1.";
|
||||
return 0;
|
||||
} else {
|
||||
auto *axis_data = static_cast<int *>(tensor_info->data_c());
|
||||
return axis_data[0];
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first) {
|
||||
auto channels = dims[0];
|
||||
if (!(*channel_at_first)) {
|
||||
if (dims.size() != 2) {
|
||||
if (dims.size() != DIMENSION_2D) {
|
||||
MS_LOG(WARNING) << "unexpected dims size: " << dims.size();
|
||||
*channel_at_first = true;
|
||||
} else {
|
||||
|
@ -429,11 +452,14 @@ int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_
|
|||
return channels;
|
||||
}
|
||||
|
||||
int GetPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims) {
|
||||
int GetPreferredDim(const CNodePtr &cnode, const PrimitivePtr &primitive, int input_index,
|
||||
const std::vector<int> &dims) {
|
||||
if (primitive->name() == ops::kNameMatMulFusion) {
|
||||
return GetMatMulPreferredDim(primitive, input_index, dims);
|
||||
} else if (primitive->name() == ops::kNameConv2dTransposeFusion) {
|
||||
return 0;
|
||||
} else if (primitive->name() == ops::kNameGather) {
|
||||
return GetGatherPreferredDim(cnode);
|
||||
}
|
||||
// The first index.
|
||||
return 0;
|
||||
|
@ -485,9 +511,9 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
|
|||
}
|
||||
}
|
||||
|
||||
int MixedBitQuantFilter(const AnfNodePtr &node, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
|
||||
QuantType quant_type, WeightQuantType weight_quant_type, TypeId quant_data_type,
|
||||
double init_scale, int index) {
|
||||
int MixedBitQuantFilter(const AnfNodePtr ¶meter_node, const tensor::TensorPtr &weight,
|
||||
const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
|
||||
TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetry) {
|
||||
MS_CHECK_TRUE_RET(primitive != nullptr, RET_NULL_PTR);
|
||||
MS_CHECK_TRUE_RET(weight != nullptr, RET_NULL_PTR);
|
||||
auto dims = weight->shape();
|
||||
|
@ -517,17 +543,17 @@ int MixedBitQuantFilter(const AnfNodePtr &node, const tensor::TensorPtr &weight,
|
|||
const int quant_min = QuantMin(k8Bit, false, false); // -128
|
||||
const int quant_max = QuantMax(k8Bit); // 127
|
||||
MS_LOG(WARNING)
|
||||
<< node->fullname_with_scope()
|
||||
<< parameter_node->fullname_with_scope()
|
||||
<< " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization.";
|
||||
return FixedBitQuantFilter<int8_t>(node, weight, primitive, QuantType_QUANT_WEIGHT, quant_max, quant_min, k8Bit,
|
||||
FIXED_BIT_PER_CHANNEL, kNumberTypeInt8, index);
|
||||
return FixedBitQuantFilter<int8_t>(parameter_node, weight, primitive, QuantType_QUANT_WEIGHT, quant_max, quant_min,
|
||||
k8Bit, FIXED_BIT_PER_CHANNEL, kNumberTypeInt8, index, preferred_dim, symmetry);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
auto status =
|
||||
UpdateTensorDataAndSize(node, weight, quant_data.data(), quant_data.size() * sizeof(int16_t), quant_data_type);
|
||||
auto status = UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(int16_t),
|
||||
quant_data_type);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -86,13 +86,14 @@ void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<in
|
|||
|
||||
bool TensorQuantParamsInited(const schema::TensorT &tensor);
|
||||
|
||||
int MixedBitQuantFilter(const AnfNodePtr &node, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
|
||||
QuantType quant_type, WeightQuantType weight_quant_type, TypeId quant_data_type,
|
||||
double init_scale, int index);
|
||||
// Mixed-bit weight quantization of `weight`; the quantized data is written back through `parameter_node`.
// `preferred_dim` and `symmetry` are forwarded to the 8-bit FixedBitQuantFilter fallback that the
// definition uses when the mixed-bit search fails.
int MixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
                        const PrimitivePtr &primitive, QuantType quant_type, WeightQuantType weight_quant_type,
                        TypeId quant_data_type, double init_scale, int index, int preferred_dim, bool symmetry);
|
||||
|
||||
int CalChannels(const std::vector<int> &dims, int channel_cnt, bool *channel_at_first);
|
||||
|
||||
int GetPreferredDim(const PrimitivePtr &primitive, int input_index, const std::vector<int> &dims);
|
||||
int GetPreferredDim(const CNodePtr &cnode, const PrimitivePtr &primitive, int input_index,
|
||||
const std::vector<int> &dims);
|
||||
|
||||
std::vector<int> ConvertShapeVectorToInt32(const ShapeVector &dims);
|
||||
|
||||
|
@ -121,10 +122,10 @@ int DeQuantData(const int8_t *tensor_data, int64_t elements_num, std::vector<lit
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
int FixedBitQuantFilter(const AnfNodePtr ¶meter, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
|
||||
QuantType quant_type, int quant_max, int quant_min, size_t bit_num,
|
||||
WeightQuantType weight_quant_type, TypeId quant_data_type, int index, bool symmetry = false,
|
||||
bool narrow_range = false, bool k_means = false) {
|
||||
int FixedBitQuantFilter(const AnfNodePtr ¶meter_node, const tensor::TensorPtr &weight,
|
||||
const PrimitivePtr &primitive, QuantType quant_type, int quant_max, int quant_min,
|
||||
size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type, int index,
|
||||
int preferred_dim, bool symmetry = false, bool narrow_range = false, bool k_means = false) {
|
||||
MS_ASSERT(weight != nullptr);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
auto dims = weight->shape();
|
||||
|
@ -146,7 +147,6 @@ int FixedBitQuantFilter(const AnfNodePtr &parameter, const tensor::TensorPtr &we
|
|||
std::vector<T> quant_data(elem_count);
|
||||
int ret = RET_OK;
|
||||
if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
|
||||
int preferred_dim = GetPreferredDim(primitive, index, ConvertShapeVectorToInt32(dims));
|
||||
ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(),
|
||||
static_cast<mindspore::schema::QuantType>(quant_type), &quant_params, quant_max,
|
||||
quant_min, bit_num, &quant_data, ConvertShapeVectorToInt32(dims), preferred_dim,
|
||||
|
@ -166,9 +166,10 @@ int FixedBitQuantFilter(const AnfNodePtr &parameter, const tensor::TensorPtr &we
|
|||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported weight quant type:" << weight_quant_type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status =
|
||||
UpdateTensorDataAndSize(parameter, weight, quant_data.data(), quant_data.size() * sizeof(T), quant_data_type);
|
||||
UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(T), quant_data_type);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -100,11 +100,11 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
|
|||
MS_LOG(INFO) << "This op " << cnode->fullname_with_scope() << " can not quant weight";
|
||||
continue;
|
||||
}
|
||||
int preferred_dim = GetPreferredDim(primitive, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
|
||||
auto quant_strategy = std::make_unique<QuantStrategy>(flags_.commonQuantParam.min_quant_weight_size,
|
||||
flags_.commonQuantParam.min_quant_weight_channel,
|
||||
flags_.commonQuantParam.skip_quant_node);
|
||||
CHECK_NULL_RETURN(quant_strategy);
|
||||
int preferred_dim = GetPreferredDim(cnode, primitive, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
|
||||
if (!quant_strategy->CanTensorQuantized(cnode, input, preferred_dim)) {
|
||||
MS_LOG(INFO) << input->fullname_with_scope() << " is not quantizable";
|
||||
continue;
|
||||
|
@ -120,14 +120,16 @@ int WeightQuantizer::DoCNodeWeightQuant(const FuncGraphPtr &func_graph, const CN
|
|||
auto status = RET_ERROR;
|
||||
if (is_mixed_bit_) {
|
||||
status = MixedBitQuantFilter(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
|
||||
WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1);
|
||||
WeightQuantType::MIXED_BIT_PER_LAYER, type_id_, mixed_bit_init_scale_, idx - 1,
|
||||
preferred_dim, symmetric);
|
||||
} else if (type_id_ == kNumberTypeInt8) {
|
||||
status = FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
|
||||
q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status =
|
||||
FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max,
|
||||
q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1, symmetric);
|
||||
FixedBitQuantFilter<int8_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type, q_max, q_min,
|
||||
bit_num_, tmp_weight_quant_type, type_id_, idx - 1, preferred_dim, symmetric);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = FixedBitQuantFilter<int16_t>(parameter, tensor_info, primitive, flags_.commonQuantParam.quant_type,
|
||||
q_max, q_min, bit_num_, tmp_weight_quant_type, type_id_, idx - 1,
|
||||
preferred_dim, symmetric);
|
||||
}
|
||||
if (status == RET_NO_CHANGE) {
|
||||
continue;
|
||||
|
|
Loading…
Reference in New Issue