optimize full quant strategy

This commit is contained in:
yeyunpeng2020 2022-01-27 13:02:53 +08:00
parent 0342934468
commit 442ea81872
10 changed files with 56 additions and 36 deletions

View File

@ -899,7 +899,7 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std:
} }
int ret; int ret;
#ifndef WEIGHT_DECODE_CLIP #ifndef WEIGHT_DECODE_CLIP
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type); ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->version_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret; MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return RET_NOT_SUPPORT; return RET_NOT_SUPPORT;
@ -949,7 +949,7 @@ int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std:
int ret; int ret;
#ifndef WEIGHT_DECODE_CLIP #ifndef WEIGHT_DECODE_CLIP
// weight dequant // weight dequant
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32); ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32, src_model_->version_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret; MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
return RET_NOT_SUPPORT; return RET_NOT_SUPPORT;

View File

@ -353,8 +353,8 @@ int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::l
size_t pid = id - static_cast<size_t>(offset); size_t pid = id - static_cast<size_t>(offset);
mindspore::lite::Tensor *tensor = tensors.at(pid); mindspore::lite::Tensor *tensor = tensors.at(pid);
schema::Tensor *scTensor = model->all_tensors_.at(pid); schema::Tensor *scTensor = model->all_tensors_.at(pid);
auto preferred_dim = auto preferred_dim = WeightDecoder::GetPreferredDim(index.second.op_parameter, index.second.input_index,
WeightDecoder::GetPreferredDim(index.second.op_parameter, index.second.input_index, tensor->shape()); tensor->shape(), model->version_);
auto tensorT = CreateTensor(tensor, scTensor, preferred_dim); auto tensorT = CreateTensor(tensor, scTensor, preferred_dim);
if (tensorT == nullptr) { if (tensorT == nullptr) {
MS_LOG(ERROR) << "error in tensor creation"; MS_LOG(ERROR) << "error in tensor creation";

View File

@ -355,15 +355,15 @@ int WeightDecoder::UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *d
return ret; return ret;
} }
int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type,
TypeId dst_data_type) { const std::string &model_version) {
if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) { if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
return RET_OK; return RET_OK;
} }
int index = 0; int index = 0;
for (auto &tensor : in_tensors) { for (auto &tensor : in_tensors) {
MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR); MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
auto preferred_dim = GetPreferredDim(op_parameter, index++, tensor->shape()); auto preferred_dim = GetPreferredDim(op_parameter, index++, tensor->shape(), model_version);
auto ret = WeightDecoder::DequantTensor(tensor, preferred_dim, dst_data_type); auto ret = WeightDecoder::DequantTensor(tensor, preferred_dim, dst_data_type);
if (ret != RET_OK && ret != RET_NO_CHANGE) { if (ret != RET_OK && ret != RET_NO_CHANGE) {
MS_LOG(DEBUG) << "Dequant tensor failed"; MS_LOG(DEBUG) << "Dequant tensor failed";
@ -416,7 +416,26 @@ int WeightDecoder::GetMatMulPreferredDim(OpParameter *op_parameter, int input_in
return 0; return 0;
} }
int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims) { bool IsChannelFirst(int index, const OpParameter *op_parameter) {
MS_ASSERT(op_parameter != nullptr);
if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
if (index == 0) {
return !(param->a_transpose_);
} else if (index == 1) {
return param->b_transpose_;
}
}
return true;
}
int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
const std::string &model_version) {
const int first_version_offset = 5;
if (model_version.empty() ||
model_version.substr(model_version.size() - first_version_offset, model_version.size()) < "1.6.0") {
return IsChannelFirst(index, op_parameter) ? 0 : 1;
}
if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) { if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
return GetMatMulPreferredDim(op_parameter, index, dims); return GetMatMulPreferredDim(op_parameter, index, dims);
} }

View File

@ -132,11 +132,13 @@ int GetDataIndex(const std::vector<int> &dims, int preferred_dim, int bucket_ind
class WeightDecoder { class WeightDecoder {
public: public:
static int DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type); static int DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type,
const std::string &model_version);
static int UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor); static int UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor);
static int GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims); static int GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
const std::string &model_version);
template <typename ST, typename DT = float> template <typename ST, typename DT = float>
static DT *DequantData(const lite::Tensor *input_tensor, int preferred_dim) { static DT *DequantData(const lite::Tensor *input_tensor, int preferred_dim) {

View File

@ -1,10 +1,10 @@
ml_face_mnet 58.2 827064 ml_face_mnet 85 832368
ml_face_landmark_2 0.8 472136 ml_face_landmark_2 0.8 472136
mobilenet.tflite 0.4 26040 mobilenet.tflite 0.4 26040
transformer_20200831_encoder_fp32.tflite;36 13.4 54319144 transformer_20200831_encoder_fp32.tflite;36 20 54319144
transformer_20200831_decoder_fp32.tflite;11 10.0 12970744 transformer_20200831_decoder_fp32.tflite;11 17 15425680
ml_face_mnet_image 47.9 827072 ml_face_mnet_image 61 832360
resnet.tflite 0.4 69272 resnet.tflite 0.4 69272
0916_ct_ddd_culane_dlav0_withSigmoid_noMerge.onnx 46.6 22487224 0916_ct_ddd_culane_dlav0_withSigmoid_noMerge.onnx 47 22487224
v3plus512_512_op11.onnx 43.1 6027648 v3plus512_512_op11.onnx 43.1 6028728
resnet_image.mindir 7.0 38911216 resnet_image.mindir 7.0 38911216

View File

@ -87,17 +87,19 @@ int PreprocessParser::ParsePreprocess(const DataPreProcessString &data_pre_proce
preprocess::ConvertColorConversionCodes(data_pre_process->image_pre_process.image_to_format); preprocess::ConvertColorConversionCodes(data_pre_process->image_pre_process.image_to_format);
} }
} }
ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process); if (!data_pre_process_str.calibrate_path.empty() && !data_pre_process_str.calibrate_size.empty()) {
if (ret != RET_OK) { ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process);
MS_LOG(ERROR) << "image preprocess parse failed."; if (ret != RET_OK) {
return ret; MS_LOG(ERROR) << "image preprocess parse failed.";
} return ret;
}
ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size, ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size,
&data_pre_process->calibrate_path_vector); &data_pre_process->calibrate_path_vector);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "collect calibrate inputs failed."; MS_LOG(ERROR) << "collect calibrate inputs failed.";
return ret; return ret;
}
} }
return RET_OK; return RET_OK;
} }
@ -243,7 +245,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
} }
if (!data_pre_process_str.resize_height.empty()) { if (!data_pre_process_str.resize_height.empty()) {
if (!ConvertIntNum(data_pre_process_str.resize_height, &image_pre_process->resize_height)) { if (!ConvertIntNum(data_pre_process_str.resize_height, &image_pre_process->resize_height)) {
MS_LOG(ERROR) << "resize_width should be a valid number."; MS_LOG(ERROR) << "resize_height should be a valid number.";
return RET_INPUT_PARAM_INVALID; return RET_INPUT_PARAM_INVALID;
} }
if (image_pre_process->resize_height <= kMinSize || image_pre_process->resize_height > kMaxSize) { if (image_pre_process->resize_height <= kMinSize || image_pre_process->resize_height > kMaxSize) {

View File

@ -195,7 +195,8 @@ int DebugInfoManager::SetOriginStaticInfo(QuantDebugInfo *quant_debug_info, cons
int DebugInfoManager::SetQuantStaticInfo(OpParameter *op_parameter, int tensor_index, QuantDebugInfo *quant_debug_info, int DebugInfoManager::SetQuantStaticInfo(OpParameter *op_parameter, int tensor_index, QuantDebugInfo *quant_debug_info,
const mindspore::lite::Tensor &tensor) { const mindspore::lite::Tensor &tensor) {
auto preferred_dim = mindspore::lite::WeightDecoder::GetPreferredDim(op_parameter, tensor_index, tensor.shape()); auto preferred_dim =
mindspore::lite::WeightDecoder::GetPreferredDim(op_parameter, tensor_index, tensor.shape(), Version());
float *quant_data; float *quant_data;
if (tensor.data_type() == kNumberTypeInt8) { if (tensor.data_type() == kNumberTypeInt8) {
quant_data = mindspore::lite::WeightDecoder::DequantData<int8_t, float>(&tensor, preferred_dim); quant_data = mindspore::lite::WeightDecoder::DequantData<int8_t, float>(&tensor, preferred_dim);

View File

@ -389,13 +389,6 @@ void FullQuantQuantizer::InitCpuConfig() {
prim::kPrimTranspose, prim::kPrimTranspose,
prim::kPrimShape, prim::kPrimShape,
prim::kPrimUnsqueeze, prim::kPrimUnsqueeze,
prim::kPrimSplit,
prim::kPrimTupleGetItem,
prim::kPrimConcat,
prim::kPrimCrop,
prim::kPrimGather,
prim::kPrimReduceFusion,
prim::kPrimAffine,
}; };
skip_check_dtype_ops_ = {prim::kPrimTupleGetItem, prim::kPrimShape}; skip_check_dtype_ops_ = {prim::kPrimTupleGetItem, prim::kPrimShape};
per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, prim::kPrimMatMulFusion, per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, prim::kPrimMatMulFusion,

View File

@ -20,9 +20,9 @@
#include "nnacl/op_base.h" #include "nnacl/op_base.h"
namespace mindspore::lite { namespace mindspore::lite {
bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) { bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
MS_ASSERT(node->inputIndex.size() >= 1); MS_ASSERT(node->inputIndex.size() >= 1);
MS_ASSERT(node->outputIndex.size() >= 1); MS_ASSERT(node->outputIndex.size() >= 1);
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
// check first in tensor // check first in tensor
MS_ASSERT(graph.allTensors.size() > node->inputIndex.at(0)); MS_ASSERT(graph.allTensors.size() > node->inputIndex.at(0));

View File

@ -251,6 +251,7 @@ SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const conv
auto status = fb_transform.Transform(flags); auto status = fb_transform.Transform(flags);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed"; MS_LOG(ERROR) << "FBTransform model failed";
delete meta_graph;
return sm; return sm;
} }
meta_graph->version = Version(); meta_graph->version = Version();
@ -263,11 +264,13 @@ SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const conv
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
if (content == nullptr) { if (content == nullptr) {
MS_LOG(ERROR) << "GetBufferPointer return null"; MS_LOG(ERROR) << "GetBufferPointer return null";
delete meta_graph;
return sm; return sm;
} }
auto model = lite::Model::Import(content, *size); auto model = lite::Model::Import(content, *size);
if (model == nullptr) { if (model == nullptr) {
MS_LOG(ERROR) << "Import model failed"; MS_LOG(ERROR) << "Import model failed";
delete meta_graph;
return sm; return sm;
} }
Context ctx; Context ctx;