From 33ab40af78adff4b16150ba5a79be1fdcc4a2357 Mon Sep 17 00:00:00 2001
From: cjh9368
Date: Mon, 12 Apr 2021 18:40:51 +0800
Subject: [PATCH] [MS][LITE] support perchannel mindir quant model

---
Review notes (this section is ignored when the patch is applied):
- Line breaks, indentation and template arguments were reconstructed from a
  whitespace-collapsed copy of this patch. Confirm the context lines against
  the target tree before applying.
- The added lines differ from commit 33ab40af: the int32 branch of
  ComputeQuantTensorPerChannel quantizes the element it writes, a quant param
  count check was added, a log typo was fixed and doc comments were added.
  Hunk sizes and the diffstat below were updated to match; the blob ids on the
  "index" lines still describe the original commit.

 mindspore/lite/test/models_mindspore.cfg      |  1 +
 mindspore/lite/tools/common/node_util.cc      | 12 +++
 mindspore/lite/tools/common/node_util.h       |  3 +-
 .../graph/tensor_quant_pass.cc                | 76 ++++++++++++++++++-
 .../concat_quant_param_propogator.cc          |  6 ++
 .../conv_quant_param_propogator.cc            | 31 +++++---
 .../converter/quantizer/quantize_util.cc      | 31 ++++++++
 .../tools/converter/quantizer/quantize_util.h |  3 +
 .../optimizer/graph/mindir_adjust_pass.cc     | 18 ++---
 9 files changed, 154 insertions(+), 27 deletions(-)

diff --git a/mindspore/lite/test/models_mindspore.cfg b/mindspore/lite/test/models_mindspore.cfg
index 6fe2d77fb68..c5040089d92 100644
--- a/mindspore/lite/test/models_mindspore.cfg
+++ b/mindspore/lite/test/models_mindspore.cfg
@@ -2,3 +2,4 @@ deeplabv3.r1.1.mindir 1.5
 mobilenetv2.r1.1.mindir 0.5
 ssd.r1.1.mindir 0.5
 ssd_ghostnet.r1.1.mindir 2.0
+lenet_quant.mindir 0.5
diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc
index 2c7772850a5..98daaae2aac 100644
--- a/mindspore/lite/tools/common/node_util.cc
+++ b/mindspore/lite/tools/common/node_util.cc
@@ -382,6 +382,18 @@ STATUS NodeInferShpae(const schema::CNodeT &node, const std::vector<Tensor *> &i
   return ret;
 }
 
+// Returns the position of tensor_index in cnode.inputIndex. If it occurs more than once the last position is
+// returned; if it does not occur at all the result is static_cast<size_t>(-1).
+size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode) {
+  size_t ret = static_cast<size_t>(-1);
+  for (size_t i = 0; i < cnode.inputIndex.size(); i++) {
+    if (cnode.inputIndex.at(i) == tensor_index) {
+      ret = i;
+    }
+  }
+  return ret;
+}
+
 STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
   if (tensor == nullptr) {
     MS_LOG(ERROR) << "tensor is null";
diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h
index dd430594d1d..e4bf078bb00 100644
--- a/mindspore/lite/tools/common/node_util.h
+++ b/mindspore/lite/tools/common/node_util.h
@@ -71,11 +71,12 @@ std::unordered_map<schema::PrimitiveType, std::vector<int>> GetExtNhwcIndexes();
 std::vector<schema::PrimitiveType> Getfp32FullOpList();
 
 std::vector<schema::PrimitiveType> GetUint8NhwcOpList();
-
 std::vector<schema::PrimitiveType> GetInt8OpList();
 
 const schema::Primitive *ConvertToPrimitive(schema::PrimitiveT *primitive_t, flatbuffers::FlatBufferBuilder *fbb);
 
+size_t GetTensorInputIndexInCNode(const uint32_t &tensor_index, const schema::CNodeT &cnode);
+
 class NodeUtils {
  public:
   static STATUS ConvertDims(schema::Format src_format, const std::vector<int32_t> &src_dims, schema::Format dst_format,
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
index ade075a6675..7ea6ee35a54 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_quant_pass.cc
@@ -20,6 +20,8 @@
 #include "tools/converter/converter_context.h"
 #include "tools/converter/quantizer/quantize_util.h"
 #include "tools/common/tensor_util.h"
+#include "tools/common/graph_util.h"
+#include "tools/common/node_util.h"
 
 namespace mindspore::lite {
 namespace {
@@ -112,6 +114,71 @@ STATUS ComputeDataToInt32(const std::unique_ptr<TensorT> &tensor) {
   }
   return RET_OK;
 }
+
+// Quantizes a float tensor in place with one quant param per channel. The channel layout is derived from the single
+// node consuming the tensor. Data becomes int32 when the first quant param's dstDtype is kNumberTypeInt32 and int8
+// otherwise. Returns RET_ERROR if the tensor does not have exactly one consumer, the channel count is zero, there
+// are fewer quant params than channels, or the final copy fails.
+STATUS ComputeQuantTensorPerChannel(TensorT *tensor, const int &tensor_index, const schema::MetaGraphT &graph) {
+  bool channel_at_first = true;
+  int channel_cnt = -1;
+  auto used_nodes_idx = GetLinkedPostIdx(graph, tensor_index);
+  if (used_nodes_idx.size() != 1) {
+    MS_LOG(ERROR) << "Tensor is used by nodes more than one";
+    return RET_ERROR;
+  }
+  auto &used_node = graph.nodes.at(used_nodes_idx.front());
+  auto &primitive = used_node->primitive;
+  int input_index = GetTensorInputIndexInCNode(tensor_index, *used_node);
+  quant::CalQuantAssitInfo(*primitive, tensor->dims, input_index, &channel_at_first, &channel_cnt);
+
+  auto *raw_datas = reinterpret_cast<float *>(tensor->data.data());
+  ShapeVector dims;
+  std::transform(tensor->dims.begin(), tensor->dims.end(), std::back_inserter(dims),
+                 [&](int32_t dim) { return static_cast<int64_t>(dim); });
+  auto channels = quant::CalChannels(dims, channel_cnt, &channel_at_first);
+  if (channels == 0) {
+    MS_LOG(ERROR) << "channels is zero";
+    return RET_ERROR;
+  }
+  if (tensor->quantParams.size() < static_cast<size_t>(channels)) {
+    MS_LOG(ERROR) << "quant param count is less than channel count";
+    return RET_ERROR;
+  }
+  int32_t dst_dtype = tensor->quantParams.front()->dstDtype == kNumberTypeInt32 ? kNumberTypeInt32 : kNumberTypeInt8;
+  size_t elem_count = tensor->data.size() / sizeof(float);
+  size_t data_size = dst_dtype == kNumberTypeInt32 ? elem_count * sizeof(int32_t) : elem_count * sizeof(int8_t);
+  std::vector<int8_t> dst_data(data_size);
+  size_t one_filter_size = elem_count / channels;
+  for (int i = 0; i < channels; i++) {
+    // do quantization
+    for (uint32_t j = 0; j < one_filter_size; j++) {
+      auto index = j + i * one_filter_size;
+      if (!channel_at_first) {
+        index = j * channels + i;
+      }
+      MS_ASSERT(index < elem_count);
+      float raw_data = raw_datas[index];
+      if (tensor->quantParams.at(i)->dstDtype == kNumberTypeInt32) {
+        // quantize the element being written; raw_datas[i] equals it only when a channel holds a single element
+        auto quant_data = static_cast<int32_t>(std::round(raw_data / tensor->quantParams.at(i)->scale));
+        auto *dst_data_int32 = reinterpret_cast<int32_t *>(dst_data.data());
+        dst_data_int32[index] = quant_data;
+      } else {
+        auto quant_data = quant::QuantizeData<int8_t>(raw_data, tensor->quantParams.at(i).get());
+        dst_data[index] = quant_data;
+      }
+    }
+  }
+  tensor->data.clear();
+  tensor->data.resize(data_size);
+  tensor->dataType = dst_dtype;
+  if (memcpy_s(tensor->data.data(), data_size, dst_data.data(), data_size) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
 }  // namespace
 
 STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
@@ -133,8 +200,13 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
       continue;
     }
     if (tensor->quantParams.size() != 1) {  // perchannel
-      MS_LOG(ERROR) << "perchannel do quant is not supported yet";
-      return RET_ERROR;
+      status = ComputeQuantTensorPerChannel(tensor.get(), index, *graph);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "compute tensor to int8 perchannel failed.";
+        return RET_ERROR;
+      }
+      index++;
+      continue;
     }
     // perlayer
     auto &quantParam = tensor->quantParams.front();
diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.cc
index ecf87296585..35264c8130d 100644
--- a/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.cc
+++ b/mindspore/lite/tools/converter/quantizer/quant_helper/concat_quant_param_propogator.cc
@@ -52,6 +52,12 @@ STATUS ConcatQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaG
         MS_ASSERT(narrow_range == quantParam->narrowRange);
         MS_ASSERT(num_bits == quantParam->numBits);
       }
+
+      if (in_quant_param->max < in_quant_param->min) {
+        MS_LOG(DEBUG) << "Input quant param is invalid for propogator";
+        return RET_ERROR;
+      }
+
       if (min_min > in_quant_param->min) {
         min_min = in_quant_param->min;
       }
diff --git a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc
index 94348068cbc..006871f0fb5 100644
--- a/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc
+++ b/mindspore/lite/tools/converter/quantizer/quant_helper/conv_quant_param_propogator.cc
@@ -35,23 +35,22 @@ STATUS ConvQuantParamPropogator::PropogateQuantParams(mindspore::schema::MetaGra
         return RET_OK;
       }
       auto &input_quant_param = input_tensor->quantParams.at(0);
-      auto &weight_quant_param = weight_tensor->quantParams.at(0);
-
-      if (bias_tensor->quantParams.empty()) {
-        auto tmp_quant_param = std::make_unique<schema::QuantParamT>();
-        bias_tensor->quantParams.emplace_back(std::move(tmp_quant_param));
+      std::vector<std::unique_ptr<schema::QuantParamT>> bias_quant_params;
+      for (auto &weight_quant_param : weight_tensor->quantParams) {
+        auto bias_quant_param = std::make_unique<schema::QuantParamT>();
+        bias_quant_param->min = 0.0;
+        bias_quant_param->max = 0.0;
+        bias_quant_param->dstDtype = kNumberTypeInt32;
+        bias_quant_param->inited = input_quant_param->inited && weight_quant_param->inited;
+        bias_quant_param->zeroPoint = 0;
+        if (bias_quant_param->inited) {
+          bias_quant_param->scale = input_quant_param->scale * weight_quant_param->scale;
+        }
+        bias_quant_param->roundType = 1;
+        bias_quant_param->multiplier = 1;
+        bias_quant_params.emplace_back(std::move(bias_quant_param));
       }
-      auto &bias_quant_param = bias_tensor->quantParams.front();
-      bias_quant_param->min = 0.0;
-      bias_quant_param->max = 0.0;
-      bias_quant_param->dstDtype = kNumberTypeInt32;
-      bias_quant_param->inited = input_quant_param->inited && weight_quant_param->inited;
-      bias_quant_param->zeroPoint = 0;
-      if (bias_quant_param->inited) {
-        bias_quant_param->scale = input_quant_param->scale * weight_quant_param->scale;
-      }
-      bias_quant_param->roundType = 1;
-      bias_quant_param->multiplier = 1;
+      bias_tensor->quantParams = std::move(bias_quant_params);
     }
     for (auto &quantParam : bias_tensor->quantParams) {
       quantParam->dstDtype = TypeId::kNumberTypeInt32;
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index a8482ad36e0..d91eb12437c 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -1042,4 +1042,35 @@ void CalQuantAssitInfo(const PrimitivePtr &primitive, const ShapeVector &shapes,
   }
 }
 
+// Fills per-channel layout hints for input `index` of a schema primitive. For a 2-D MatMul input, channel_at_first
+// becomes false only for input 1 without transpose_b. For LSTM, channel_cnt becomes shapes[0] * shapes[1] for 3-D
+// inputs 1 and 2, and 4 for a 2-D input 3 whose element count is a multiple of 4. Otherwise nothing is written.
+void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<int> &shapes, int index,
+                       bool *channel_at_first, int *channel_cnt) {
+  if (primitive.value.type == schema::PrimitiveType_MatMul && static_cast<int>(shapes.size()) == 2) {
+    auto matmul_prim = primitive.value.AsMatMul();
+    MS_ASSERT(matmul_prim != nullptr);
+    *channel_at_first = index != 1 || matmul_prim->transpose_b;
+  } else if (primitive.value.type == schema::PrimitiveType_LSTM) {
+    if (index == 1 || index == 2) {
+      if (shapes.size() != 3) {
+        MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
+      } else {
+        *channel_cnt = shapes[0] * shapes[1];
+      }
+    } else if (index == 3) {
+      if (shapes.size() != 2) {
+        MS_LOG(WARNING) << "unexpected lstm shape size: " << shapes.size();
+      } else {
+        auto tensor_elem_cnt = shapes[0] * shapes[1];
+        if (tensor_elem_cnt / 4 * 4 == tensor_elem_cnt) {
+          *channel_cnt = 4;
+        }
+      }
+    } else {
+      MS_LOG(WARNING) << "unexpected index of lstm: " << index;
+    }
+  }
+}
+
 }  // namespace mindspore::lite::quant
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h
index 8c12f78f66f..0324b0bd464 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.h
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h
@@ -120,6 +120,9 @@ int CalChannels(const ShapeVector &dims, int channel_cnt, bool *channel_at_first
 void CalQuantAssitInfo(const PrimitivePtr &primitive, const ShapeVector &shapes, int index, bool *channel_at_first,
                        int *channel_cnt);
 
+void CalQuantAssitInfo(const schema::PrimitiveT &primitive, const std::vector<int> &shapes, int index,
+                       bool *channel_at_first, int *channel_cnt);
+
 template <typename T>
 T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
   MS_ASSERT(quantParam != nullptr);
diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
index a4dacbe378d..210c686b35f 100644
--- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc
@@ -69,17 +69,19 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
     quant_param.min = FLT_MAX;
     quant_param.max = FLT_MIN;
     for (int i = 0; i < filterMinPtr->ElementsNum(); ++i) {
-      quant_param.min = (*(minBuf) < quant_param.min) ? (*minBuf) : quant_param.min;
-      quant_param.max = (*(maxBuf) > quant_param.max) ? (*maxBuf) : quant_param.max;
+      schema::QuantParamT tmp_quant_param;
+      tmp_quant_param.min = *minBuf;
+      tmp_quant_param.max = *maxBuf;
+      auto ret =
+        lite::quant::CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max, true, numbits);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "Can't calculate quant parameters";
+        return ret;
+      }
+      quants.emplace_back(tmp_quant_param);
       minBuf++;
       maxBuf++;
     }
-    auto ret = lite::quant::CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, true, numbits);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "Can't calculate quant parameters";
-      return ret;
-    }
-    quants.emplace_back(quant_param);
     quant_param_holder->set_input_quant_param(1, quants);
   }
   return lite::RET_OK;