====== weight quant ======

This commit is contained in:
kai00 2020-08-31 20:33:56 +08:00
parent d27178bf2b
commit 0f2c78253e
18 changed files with 479 additions and 185 deletions

View File

@ -103,8 +103,9 @@ int ModelImpl::BuildOps() {
auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
auto name = cNode->name()->str();
auto srcPrim = cNode->primitive();
this->ops_[name] = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
auto prim = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
prim->SetQuantType(cNode->quantType());
this->ops_[name] = prim;
}
return 0;
}

View File

@ -688,6 +688,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi
}
return nullptr;
}
void PrimitiveC::SetQuantType(schema::QuantType quant_type) {
this->quant_type_ = quant_type;
}
schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; }
#endif
int PrimitiveC::Type() const {

View File

@ -145,6 +145,9 @@ class PrimitiveC {
int Type() const;
void SetQuantType(schema::QuantType quant_type);
schema::QuantType GetQuantType() const;
protected:
template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) {
@ -194,6 +197,7 @@ class PrimitiveC {
const schema::Primitive *primitive_ = nullptr;
char *primitive_buf_ = nullptr;
bool infer_flag_ = true;
schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
};
#endif
} // namespace lite

View File

@ -331,4 +331,46 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
return RET_OK;
}
int ConvolutionBaseCPUKernel::RestoreFilter(lite::tensor::Tensor *input_tensor) {
MS_ASSERT(input_tensor != nullptr);
if (input_tensor->GetQuantParams().empty()) {
MS_LOG(ERROR) << "no quant param";
return RET_ERROR;
}
const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
if (dequant_data == nullptr) {
MS_LOG(ERROR) << "malloc failed";
return RET_ERROR;
}
if (input_tensor->GetQuantParams().size() != kPerTensor) {
size_t channels = static_cast<size_t>(input_tensor->Batch());
if (input_tensor->GetQuantParams().size() != channels) {
MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->GetQuantParams().size() << channels;
return RET_ERROR;
}
size_t per_channel_size = input_tensor->DataSize() / channels;
auto quant_param = input_tensor->GetQuantParams();
for (size_t i = 0; i < channels; i++) {
auto param = quant_param.at(i);
auto scale = param.scale;
auto zero_point = param.zeroPoint;
for (size_t j = 0; j < per_channel_size; j++) {
dequant_data[per_channel_size * i + j] = static_cast<float>(
(quant_data[per_channel_size * i + j] - zero_point) * scale);
}
}
} else {
auto quant_param = input_tensor->GetQuantParams();
auto param = quant_param.front();
auto scale = param.scale;
auto zero_point = param.zeroPoint;
for (size_t j = 0; j < input_tensor->DataSize(); j++) {
dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
}
}
input_tensor->SetData(dequant_data);
return RET_OK;
}
} // namespace mindspore::kernel
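Note: RestoreFilter implements the standard affine dequantization real = (q - zero_point) * scale, per tensor or per output channel. A minimal standalone sketch of the per-channel case; the helper name and signature are illustrative, not part of the MindSpore Lite API:

  #include <cstddef>
  #include <cstdint>

  // Hedged sketch: dequantize a channel-first weight buffer where each output
  // channel k carries its own scale/zero_point, mirroring RestoreFilter above.
  void DequantPerChannel(const uint8_t *quant_data, float *dequant_data,
                         size_t channels, size_t per_channel_size,
                         const double *scales, const int32_t *zero_points) {
    for (size_t k = 0; k < channels; ++k) {
      for (size_t j = 0; j < per_channel_size; ++j) {
        size_t index = k * per_channel_size + j;
        dequant_data[index] =
            static_cast<float>((quant_data[index] - zero_points[k]) * scales[k]);
      }
    }
  }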

View File

@ -60,6 +60,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
int SetQuantMultiplier();
int CheckResizeValid();
void FreeQuantParam();
static int RestoreFilter(lite::tensor::Tensor *input_tensor);
protected:
int tile_num_;

View File

@ -239,6 +239,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
}
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->Data();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
}
kernel::LiteKernel *kernel;
if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
@ -263,6 +269,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
return nullptr;
}
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return kernel;
}
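The creator above (and the depthwise creator below) only needs the fp32 weights while the kernel is being built: the kernel copies or packs them into its own buffers, so the temporary float buffer is freed afterwards and the original quantized buffer is put back, keeping the model's weight footprint small. Condensed, the lifecycle is (an illustrative recap of the code above, not new API):

  auto *weight_tensor = inputs.at(kWeightIndex);
  auto *restore_data = weight_tensor->Data();               // 1. remember the quantized buffer
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    ConvolutionBaseCPUKernel::RestoreFilter(weight_tensor); // 2. tensor now holds fp32 data
  }
  // 3. create the kernel, which consumes the fp32 weights
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    weight_tensor->FreeData();                              // 4. free the temporary fp32 buffer
    weight_tensor->SetData(restore_data);                   // 5. restore the quantized buffer
  }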

View File

@ -131,6 +131,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->Data();
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
}
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel;
if (conv_param->input_channel_ < 32) {
@ -149,6 +156,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
weight_tensor->FreeData();
weight_tensor->SetData(restore_data);
}
return kernel;
}

View File

@ -64,7 +64,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
MS_ASSERT(dst_node != nullptr);
// add quant param
dst_node->quantType = primitive->GetQuantType();
if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining) {
if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
|| dst_node->quantType == schema::QuantType_WeightQuant) {
MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
// activation
auto input_quant_params = primitive->GetInputQuantParams();
@ -103,7 +104,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
}
} else {
for (auto output_quant_param : output_quant_params[0]) {
if (tensor_output->quantParams.empty()) {
if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
std::make_unique<schema::QuantParamT>(output_quant_param);
MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale

View File

@ -26,6 +26,7 @@
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
using std::string;
namespace mindspore {
@ -57,11 +58,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
// quant
if (config != nullptr && config->quantType == schema::QuantType_PostTraining) {
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
return nullptr;
if (config != nullptr) {
if (config->quantType == schema::QuantType_PostTraining) {
this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
return nullptr;
}
} else if (config->quantType == schema::QuantType_WeightQuant) {
this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantSize,
config->convWeightQuantChannelThreshold, config->bitNum);
if (mQuantizer == nullptr) {
MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
return nullptr;
}
}
}
if (mQuantizer != nullptr) {
@ -71,12 +81,14 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
MS_LOG(ERROR) << "Quant failed " << status;
return nullptr;
}
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
return nullptr;
if (config->quantType == schema::QuantType_PostTraining) {
quant::QuantCast quant_cast;
quant_cast.SetInputDataDType(kNumberTypeFloat32);
status = quant_cast.Run(new_graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "add QuantCast error";
return nullptr;
}
}
}

View File

@ -36,6 +36,8 @@ Flags::Flags() {
AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold",
"convWeightQuantChannelThreshold", "16");
AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
}

View File

@ -191,6 +191,7 @@ STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr<CNodeT> &n
switch (this->quantType) {
case QuantType_AwareTraining:
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
if (opType == schema::PrimitiveType_Conv2D) {
weightTensor->format = schema::Format_KHWC;

View File

@ -31,7 +31,7 @@ void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat =
STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
if (this->quantType == QuantType_AwareTraining) {
if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) {
auto status = QuantDataFormatTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;

View File

@ -11,6 +11,7 @@ add_library(quantizer_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
)
if(ENABLE_ASAN)

View File

@ -530,7 +530,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
return RET_ERROR;
}
auto status =
QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel, depthwise);
QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max,
quant_min, bit_num, perchanel, depthwise);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;

View File

@ -279,171 +279,6 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
return RET_OK;
}
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
auto dims = weight->tensor_shape();
if (per_channel) {
if (dims.size() != 4) {
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
per_channel = false;
} else {
uint32_t channels = dims[0];
if (channels == 0) {
MS_LOG(ERROR) << "channels is 0";
return RET_ERROR;
}
}
}
vector<schema::QuantParamT> quant_params;
size_t elem_count = weight->tensor_shape_size();
auto *raw_datas = static_cast<float *>(weight->tensor_addr());
if (raw_datas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR;
}
vector<int8_t> quant_datas(elem_count);
if (per_channel) {
// notice:
// at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
// if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
if (depth_wise) {
// channel at last
auto channels = dims[3];
if (channels == 0) {
MS_LOG(ERROR) << "channels is zero";
return RET_ERROR;
}
size_t one_filter_size = elem_count / channels;
for (int i = 0; i < channels; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min and max
for (size_t j = 0; j < one_filter_size; j++) {
auto index = i + j * channels;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
quant_params.emplace_back(quant_param);
// do quantization
for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = i + j * channels;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
float raw_data = raw_datas[index];
auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
quant_datas[index] = quant_data;
}
}
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
elem_count * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
weight->set_tensor_size(elem_count * sizeof(int8_t));
} else {
// channel at first
auto channels = dims[0];
if (channels == 0) {
MS_LOG(ERROR) << "channels is zero";
return RET_ERROR;
}
size_t one_filter_size = elem_count / channels;
for (int i = 0; i < channels; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min and max
for (size_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
quant_params.emplace_back(quant_param);
// do quantization
for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
float raw_data = raw_datas[index];
auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
quant_datas[index] = quant_data;
}
}
auto ret =
memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
weight->set_tensor_size(elem_count * sizeof(int8_t));
}
} else {
// per layer
float min = FLT_MAX;
float max = -FLT_MIN;
for (uint32_t i = 0; i < elem_count; i++) {
// find max min
min = std::min(min, raw_datas[i]);
max = std::max(max, raw_datas[i]);
}
schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
quant_params.emplace_back(quant_param);
// update data and datatype
for (uint32_t i = 0; i < elem_count; i++) {
float raw_data = raw_datas[i];
auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
quant_datas[i] = quant_data;
}
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
weight->set_tensor_size(elem_count * sizeof(int8_t));
}
if (quant_params.empty()) {
MS_LOG(ERROR) << "quant_params empty";
return RET_ERROR;
}
primitive_c->AddInputQuantParam(quant_params);
return RET_OK;
}
STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);

View File

@ -21,6 +21,8 @@
#include <string>
#include <cmath>
#include <array>
#include <vector>
#include <algorithm>
#include "tools/converter/quantizer/quantizer.h"
#include "src/ops/primitive_c.h"
#include "include/errorcode.h"
@ -117,10 +119,171 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
return static_cast<T>(quant_data);
}();
}
template <typename T>
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false,
bool depth_wise = false);
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
auto dims = weight->tensor_shape();
if (per_channel) {
if (dims.size() != 4) {
MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
per_channel = false;
} else {
uint32_t channels = dims[0];
if (channels == 0) {
MS_LOG(ERROR) << "channels is 0";
return RET_ERROR;
}
}
}
std::vector<schema::QuantParamT> quant_params;
size_t elem_count = weight->tensor_shape_size();
auto *raw_datas = static_cast<float *>(weight->tensor_addr());
if (raw_datas == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR;
}
std::vector<T> quant_datas(elem_count);
if (per_channel) {
// notice:
// for now, in tflite models, Conv2D's weight format is KHWC, and so is DepthwiseConv2D's
// if TransWeightFormat is done before PostTrainingQuantization, the DepthwiseConv2D's weight is CHWK
if (depth_wise) {
// channel at last
auto channels = dims[3];
if (channels == 0) {
MS_LOG(ERROR) << "channels is zero";
return RET_ERROR;
}
size_t one_filter_size = elem_count / channels;
for (int i = 0; i < channels; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min and max
for (size_t j = 0; j < one_filter_size; j++) {
auto index = i + j * channels;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
quant_params.emplace_back(quant_param);
// do quantization
for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = i + j * channels;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
float raw_data = raw_datas[index];
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
quant_datas[index] = quant_data;
}
}
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
elem_count * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
weight->set_tensor_size(elem_count * sizeof(T));
} else {
// channel at first
auto channels = dims[0];
if (channels == 0) {
MS_LOG(ERROR) << "channels is zero";
return RET_ERROR;
}
size_t one_filter_size = elem_count / channels;
for (int i = 0; i < channels; i++) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min and max
for (size_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
quant_params.emplace_back(quant_param);
// do quantization
for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
float raw_data = raw_datas[index];
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
quant_datas[index] = quant_data;
}
}
auto ret =
memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
weight->set_tensor_size(elem_count * sizeof(T));
}
} else {
// per layer
float min = FLT_MAX;
float max = -FLT_MAX;
for (uint32_t i = 0; i < elem_count; i++) {
// find max min
min = std::min(min, raw_datas[i]);
max = std::max(max, raw_datas[i]);
}
schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
}
quant_params.emplace_back(quant_param);
// update data and datatype
for (uint32_t i = 0; i < elem_count; i++) {
float raw_data = raw_datas[i];
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
quant_datas[i] = quant_data;
}
auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;
return RET_ERROR;
}
weight->set_tensor_size(elem_count * sizeof(T));
}
if (quant_params.empty()) {
MS_LOG(ERROR) << "quant_params empty";
return RET_ERROR;
}
primitive_c->AddInputQuantParam(quant_params);
return RET_OK;
}
STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
} // namespace quant
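QuantFilter delegates the scale/zero-point computation to CalQuantizationParams (defined in quantize_util.cc) and the rounding to QuantizeData<T> above. Assuming the standard asymmetric affine scheme, the arithmetic is roughly the following sketch; the names are illustrative and the real implementation may adjust min/max and edge cases differently:

  #include <algorithm>
  #include <cmath>
  #include <cstdint>

  struct AffineParam {
    double scale;
    int32_t zero_point;
  };

  // Hedged sketch of asymmetric affine quantization parameters.
  AffineParam CalcAffineParam(float min, float max, int quant_min, int quant_max) {
    min = std::min(min, 0.0f);  // keep zero exactly representable
    max = std::max(max, 0.0f);
    double scale = (static_cast<double>(max) - min) / (quant_max - quant_min);
    if (scale == 0) {
      scale = 1.0;  // degenerate all-zero tensor
    }
    auto zero_point = static_cast<int32_t>(quant_min - std::round(min / scale));
    return {scale, zero_point};
  }

  // Quantization then rounds and clamps, matching QuantizeData<T> above:
  //   q = clamp(round(real / scale) + zero_point, quant_min, quant_max)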

View File

@ -0,0 +1,148 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/quantizer/weight_quantizer.h"
#include <list>
#include <string>
#include <vector>
#include "src/common/common.h"
#include "ir/dtype/type_id.h"
using std::string;
using std::vector;
namespace mindspore {
namespace lite {
namespace quant {
WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
const std::string &convWeightChannelThreshold, const std::string &bitNum)
: Quantizer(graph) {
auto quantSize = static_cast<size_t>(std::stoull(weightSize));
this->bitNum = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold));
}
STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
for (auto &cnode : nodes) {
if (!mStrategy->CanConvOpQuantized(cnode)) {
continue;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto inputNode = cnode->input(2);
if (!inputNode->isa<Parameter>()) {
MS_LOG(ERROR) << "conv weight input is not a Parameter";
return RET_ERROR;
}
auto paramNode = inputNode->cast<ParameterPtr>();
if (!paramNode->has_default()) {
MS_LOG(ERROR) << "conv weight Parameter has no default value";
return RET_ERROR;
}
std::vector<schema::QuantParamT> quant_params;
primitive_c->AddInputQuantParam(quant_params);
auto op_type = static_cast<schema::PrimitiveType>(primitive_c->Type());
bool depthwise = (op_type == schema::PrimitiveType_DepthwiseConv2D);
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0,
bitNum, true, depthwise);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
param_value->set_tensor_type(kNumberTypeUInt8);
primitive_c->SetQuantType(schema::QuantType_WeightQuant);
}
return RET_OK;
}
STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
for (auto &node : nodes) {
if (!mStrategy->CanMulOpQuantized(node)) {
continue;
}
ParamValueLitePtr param_value = nullptr;
for (size_t i = 1; i < node->size(); i++) {
auto inputNode = node->input(i);
if (inputNode->isa<Parameter>()) {
auto paramNode = inputNode->cast<ParameterPtr>();
if ((paramNode != nullptr) && paramNode->has_default()) {
param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
if ((param_value == nullptr) || (param_value->tensor_size() == 0)
|| (param_value->tensor_shape().size() != 4)
|| (param_value->tensor_addr() == nullptr)
|| (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
param_value = nullptr;
continue;
} else {
break;
}
}
}
}
if (param_value == nullptr) {
MS_LOG(ERROR) << "No valid input param node !";
return RET_ERROR;;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
param_value->set_tensor_type(kNumberTypeUInt8);
primitive_c->SetQuantType(schema::QuantType_WeightQuant);
}
return RET_OK;
}
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
auto ret = RET_OK;
auto cnodes = funcGraph->GetOrderedCnodes();
ret = DoConvQuantize(cnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
return ret;
}
ret = DoMulQuantize(cnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
return ret;
}
return ret;
}
} // namespace quant
} // namespace lite
} // namespace mindspore
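AnfTransform (above) drives this class from the converter's string-valued flags. A hedged usage sketch, using the flag defaults shown earlier ("0" for quantSize, "16" for convWeightQuantChannelThreshold); "8" for bitNum is an assumption here, not a value taken from this commit:

  // func_graph is an existing FuncGraphPtr; <memory> provides std::make_unique.
  auto quantizer = std::make_unique<quant::WeightQuantizer>(
      func_graph, /*weightSize=*/"0", /*convWeightChannelThreshold=*/"16", /*bitNum=*/"8");
  if (quantizer->DoQuantize(func_graph) != RET_OK) {
    MS_LOG(ERROR) << "weight quantization failed";
  }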

View File

@ -0,0 +1,53 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef WEIGHT_QUANTIZER_H
#define WEIGHT_QUANTIZER_H
#include <memory>
#include <list>
#include <string>
#include "tools/converter/quantizer/quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "include/model.h"
#include "base/base.h"
#include "abstract/dshape.h"
namespace mindspore {
namespace lite {
namespace quant {
class WeightQuantizer : public Quantizer {
public:
WeightQuantizer(FuncGraphPtr graph, const std::string& weightSize,
const std::string& convWeightChannelThreshold, const std::string& bitNum);
~WeightQuantizer() = default;
STATUS DoQuantize(FuncGraphPtr funcGraph) override;
STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
private:
std::unique_ptr<QuantStrategy> mStrategy;
size_t bitNum;
};
} // namespace quant
} // namespace lite
} // namespace mindspore
#endif