optimize full quant strategy
This commit is contained in:
parent
0342934468
commit
442ea81872
|
@ -899,7 +899,7 @@ int Scheduler::FindCpuKernel(const std::vector<Tensor *> &in_tensors, const std:
|
|||
}
|
||||
int ret;
|
||||
#ifndef WEIGHT_DECODE_CLIP
|
||||
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type);
|
||||
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kernel_data_type, src_model_->version_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
|
||||
return RET_NOT_SUPPORT;
|
||||
|
@ -949,7 +949,7 @@ int Scheduler::FindGpuKernel(const std::vector<Tensor *> &in_tensors, const std:
|
|||
int ret;
|
||||
#ifndef WEIGHT_DECODE_CLIP
|
||||
// weight dequant
|
||||
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32);
|
||||
ret = WeightDecoder::DequantNode(op_parameter, in_tensors, kNumberTypeFloat32, src_model_->version_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(DEBUG) << "Dequant input tensors failed: " << ret;
|
||||
return RET_NOT_SUPPORT;
|
||||
|
|
|
@ -353,8 +353,8 @@ int TrainExport::ExportTensor(const Model *model, const std::vector<mindspore::l
|
|||
size_t pid = id - static_cast<size_t>(offset);
|
||||
mindspore::lite::Tensor *tensor = tensors.at(pid);
|
||||
schema::Tensor *scTensor = model->all_tensors_.at(pid);
|
||||
auto preferred_dim =
|
||||
WeightDecoder::GetPreferredDim(index.second.op_parameter, index.second.input_index, tensor->shape());
|
||||
auto preferred_dim = WeightDecoder::GetPreferredDim(index.second.op_parameter, index.second.input_index,
|
||||
tensor->shape(), model->version_);
|
||||
auto tensorT = CreateTensor(tensor, scTensor, preferred_dim);
|
||||
if (tensorT == nullptr) {
|
||||
MS_LOG(ERROR) << "error in tensor creation";
|
||||
|
|
|
@ -355,15 +355,15 @@ int WeightDecoder::UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *d
|
|||
return ret;
|
||||
}
|
||||
|
||||
int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors,
|
||||
TypeId dst_data_type) {
|
||||
int WeightDecoder::DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type,
|
||||
const std::string &model_version) {
|
||||
if (op_parameter->quant_type_ != schema::QuantType_QUANT_WEIGHT) {
|
||||
return RET_OK;
|
||||
}
|
||||
int index = 0;
|
||||
for (auto &tensor : in_tensors) {
|
||||
MS_CHECK_TRUE_RET(tensor != nullptr, RET_ERROR);
|
||||
auto preferred_dim = GetPreferredDim(op_parameter, index++, tensor->shape());
|
||||
auto preferred_dim = GetPreferredDim(op_parameter, index++, tensor->shape(), model_version);
|
||||
auto ret = WeightDecoder::DequantTensor(tensor, preferred_dim, dst_data_type);
|
||||
if (ret != RET_OK && ret != RET_NO_CHANGE) {
|
||||
MS_LOG(DEBUG) << "Dequant tensor failed";
|
||||
|
@ -416,7 +416,26 @@ int WeightDecoder::GetMatMulPreferredDim(OpParameter *op_parameter, int input_in
|
|||
return 0;
|
||||
}
|
||||
|
||||
int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims) {
|
||||
bool IsChannelFirst(int index, const OpParameter *op_parameter) {
|
||||
MS_ASSERT(op_parameter != nullptr);
|
||||
if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
|
||||
const auto *param = reinterpret_cast<const MatMulParameter *>(op_parameter);
|
||||
if (index == 0) {
|
||||
return !(param->a_transpose_);
|
||||
} else if (index == 1) {
|
||||
return param->b_transpose_;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int WeightDecoder::GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
|
||||
const std::string &model_version) {
|
||||
const int first_version_offset = 5;
|
||||
if (model_version.empty() ||
|
||||
model_version.substr(model_version.size() - first_version_offset, model_version.size()) < "1.6.0") {
|
||||
return IsChannelFirst(index, op_parameter) ? 0 : 1;
|
||||
}
|
||||
if (op_parameter->type_ == schema::PrimitiveType_MatMulFusion) {
|
||||
return GetMatMulPreferredDim(op_parameter, index, dims);
|
||||
}
|
||||
|
|
|
@ -132,11 +132,13 @@ int GetDataIndex(const std::vector<int> &dims, int preferred_dim, int bucket_ind
|
|||
|
||||
class WeightDecoder {
|
||||
public:
|
||||
static int DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type);
|
||||
static int DequantNode(OpParameter *op_parameter, const std::vector<Tensor *> &in_tensors, TypeId dst_data_type,
|
||||
const std::string &model_version);
|
||||
|
||||
static int UnPack(const SchemaTensorWrapper &src_tensor, lite::Tensor *dst_tensor);
|
||||
|
||||
static int GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims);
|
||||
static int GetPreferredDim(OpParameter *op_parameter, int index, const std::vector<int> &dims,
|
||||
const std::string &model_version);
|
||||
|
||||
template <typename ST, typename DT = float>
|
||||
static DT *DequantData(const lite::Tensor *input_tensor, int preferred_dim) {
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
ml_face_mnet 58.2 827064
|
||||
ml_face_mnet 85 832368
|
||||
ml_face_landmark_2 0.8 472136
|
||||
mobilenet.tflite 0.4 26040
|
||||
transformer_20200831_encoder_fp32.tflite;36 13.4 54319144
|
||||
transformer_20200831_decoder_fp32.tflite;11 10.0 12970744
|
||||
ml_face_mnet_image 47.9 827072
|
||||
transformer_20200831_encoder_fp32.tflite;36 20 54319144
|
||||
transformer_20200831_decoder_fp32.tflite;11 17 15425680
|
||||
ml_face_mnet_image 61 832360
|
||||
resnet.tflite 0.4 69272
|
||||
0916_ct_ddd_culane_dlav0_withSigmoid_noMerge.onnx 46.6 22487224
|
||||
v3plus512_512_op11.onnx 43.1 6027648
|
||||
0916_ct_ddd_culane_dlav0_withSigmoid_noMerge.onnx 47 22487224
|
||||
v3plus512_512_op11.onnx 43.1 6028728
|
||||
resnet_image.mindir 7.0 38911216
|
||||
|
|
|
@ -87,17 +87,19 @@ int PreprocessParser::ParsePreprocess(const DataPreProcessString &data_pre_proce
|
|||
preprocess::ConvertColorConversionCodes(data_pre_process->image_pre_process.image_to_format);
|
||||
}
|
||||
}
|
||||
ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "image preprocess parse failed.";
|
||||
return ret;
|
||||
}
|
||||
if (!data_pre_process_str.calibrate_path.empty() && !data_pre_process_str.calibrate_size.empty()) {
|
||||
ret = ParseImagePreProcess(data_pre_process_str, &data_pre_process->image_pre_process);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "image preprocess parse failed.";
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size,
|
||||
&data_pre_process->calibrate_path_vector);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "collect calibrate inputs failed.";
|
||||
return ret;
|
||||
ret = CollectCalibInputs(data_pre_process->calibrate_path, data_pre_process->calibrate_size,
|
||||
&data_pre_process->calibrate_path_vector);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "collect calibrate inputs failed.";
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
@ -243,7 +245,7 @@ int PreprocessParser::ParseImageResize(const DataPreProcessString &data_pre_proc
|
|||
}
|
||||
if (!data_pre_process_str.resize_height.empty()) {
|
||||
if (!ConvertIntNum(data_pre_process_str.resize_height, &image_pre_process->resize_height)) {
|
||||
MS_LOG(ERROR) << "resize_width should be a valid number.";
|
||||
MS_LOG(ERROR) << "resize_height should be a valid number.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
if (image_pre_process->resize_height <= kMinSize || image_pre_process->resize_height > kMaxSize) {
|
||||
|
|
|
@ -195,7 +195,8 @@ int DebugInfoManager::SetOriginStaticInfo(QuantDebugInfo *quant_debug_info, cons
|
|||
|
||||
int DebugInfoManager::SetQuantStaticInfo(OpParameter *op_parameter, int tensor_index, QuantDebugInfo *quant_debug_info,
|
||||
const mindspore::lite::Tensor &tensor) {
|
||||
auto preferred_dim = mindspore::lite::WeightDecoder::GetPreferredDim(op_parameter, tensor_index, tensor.shape());
|
||||
auto preferred_dim =
|
||||
mindspore::lite::WeightDecoder::GetPreferredDim(op_parameter, tensor_index, tensor.shape(), Version());
|
||||
float *quant_data;
|
||||
if (tensor.data_type() == kNumberTypeInt8) {
|
||||
quant_data = mindspore::lite::WeightDecoder::DequantData<int8_t, float>(&tensor, preferred_dim);
|
||||
|
|
|
@ -389,13 +389,6 @@ void FullQuantQuantizer::InitCpuConfig() {
|
|||
prim::kPrimTranspose,
|
||||
prim::kPrimShape,
|
||||
prim::kPrimUnsqueeze,
|
||||
prim::kPrimSplit,
|
||||
prim::kPrimTupleGetItem,
|
||||
prim::kPrimConcat,
|
||||
prim::kPrimCrop,
|
||||
prim::kPrimGather,
|
||||
prim::kPrimReduceFusion,
|
||||
prim::kPrimAffine,
|
||||
};
|
||||
skip_check_dtype_ops_ = {prim::kPrimTupleGetItem, prim::kPrimShape};
|
||||
per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, prim::kPrimMatMulFusion,
|
||||
|
|
|
@ -20,9 +20,9 @@
|
|||
#include "nnacl/op_base.h"
|
||||
namespace mindspore::lite {
|
||||
bool CarryDataQuantTypeDeterminer::DetermineQuantAll(const schema::MetaGraphT &graph, schema::CNodeT *node) {
|
||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
||||
MS_ASSERT(node->inputIndex.size() >= 1);
|
||||
MS_ASSERT(node->outputIndex.size() >= 1);
|
||||
MS_CHECK_TRUE_MSG(node != nullptr, false, "node is nullptr.");
|
||||
|
||||
// check first in tensor
|
||||
MS_ASSERT(graph.allTensors.size() > node->inputIndex.at(0));
|
||||
|
|
|
@ -251,6 +251,7 @@ SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const conv
|
|||
auto status = fb_transform.Transform(flags);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "FBTransform model failed";
|
||||
delete meta_graph;
|
||||
return sm;
|
||||
}
|
||||
meta_graph->version = Version();
|
||||
|
@ -263,11 +264,13 @@ SessionModel CreateSessionByFuncGraph(const FuncGraphPtr &func_graph, const conv
|
|||
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|
||||
if (content == nullptr) {
|
||||
MS_LOG(ERROR) << "GetBufferPointer return null";
|
||||
delete meta_graph;
|
||||
return sm;
|
||||
}
|
||||
auto model = lite::Model::Import(content, *size);
|
||||
if (model == nullptr) {
|
||||
MS_LOG(ERROR) << "Import model failed";
|
||||
delete meta_graph;
|
||||
return sm;
|
||||
}
|
||||
Context ctx;
|
||||
|
|
Loading…
Reference in New Issue