forked from mindspore-Ecosystem/mindspore
====weight quant======
This commit is contained in:
parent d27178bf2b
commit 0f2c78253e
@@ -103,8 +103,9 @@ int ModelImpl::BuildOps() {
     auto cNode = meta_graph_->nodes()->GetAs<schema::CNode>(i);
     auto name = cNode->name()->str();
     auto srcPrim = cNode->primitive();

-    this->ops_[name] = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
+    auto prim = PrimitiveC::UnPackFromSchemaPrimitive(const_cast<schema::Primitive *>(srcPrim));
+    prim->SetQuantType(cNode->quantType());
+    this->ops_[name] = prim;
   }
   return 0;
 }
@@ -688,6 +688,10 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi
   }
   return nullptr;
 }
+void PrimitiveC::SetQuantType(schema::QuantType quant_type) {
+  this->quant_type_ = quant_type;
+}
+schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; }
 #endif

 int PrimitiveC::Type() const {
@@ -145,6 +145,9 @@ class PrimitiveC {

   int Type() const;

+  void SetQuantType(schema::QuantType quant_type);
+  schema::QuantType GetQuantType() const;
+
  protected:
   template <typename T, typename = std::enable_if<std::is_base_of<PrimitiveC, T>::value>>
   static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) {
@@ -194,6 +197,7 @@ class PrimitiveC {
   const schema::Primitive *primitive_ = nullptr;
   char *primitive_buf_ = nullptr;
   bool infer_flag_ = true;
+  schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
 };
 #endif
 }  // namespace lite
@@ -331,4 +331,46 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {

   return RET_OK;
 }
+int ConvolutionBaseCPUKernel::RestoreFilter(lite::tensor::Tensor *input_tensor) {
+  MS_ASSERT(input_tensor != nullptr);
+  if (input_tensor->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "no quant param";
+    return RET_ERROR;
+  }
+  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
+  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
+  if (dequant_data == nullptr) {
+    MS_LOG(ERROR) << "malloc failed";
+    return RET_ERROR;
+  }
+
+  if (input_tensor->GetQuantParams().size() != kPerTensor) {
+    size_t channels = static_cast<size_t>(input_tensor->Batch());
+    if (input_tensor->GetQuantParams().size() != channels) {
+      MS_LOG(ERROR) << "quant param count " << input_tensor->GetQuantParams().size()
+                    << " does not match channel count " << channels;
+      return RET_ERROR;
+    }
+    size_t per_channel_size = input_tensor->DataSize() / channels;
+    auto quant_param = input_tensor->GetQuantParams();
+    for (size_t i = 0; i < channels; i++) {
+      auto param = quant_param.at(i);
+      auto scale = param.scale;
+      auto zero_point = param.zeroPoint;
+      for (size_t j = 0; j < per_channel_size; j++) {
+        dequant_data[per_channel_size * i + j] =
+          static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
+      }
+    }
+  } else {
+    auto quant_param = input_tensor->GetQuantParams();
+    auto param = quant_param.front();
+    auto scale = param.scale;
+    auto zero_point = param.zeroPoint;
+    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
+      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
+    }
+  }
+  input_tensor->SetData(dequant_data);
+  return RET_OK;
+}
 }  // namespace mindspore::kernel
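RestoreFilter above is plain affine dequantization: each stored value q maps back to float as (q - zero_point) * scale, per channel when the tensor carries one quant param per output channel, otherwise per tensor. A minimal self-contained sketch of the per-tensor branch (the helper name and plain-vector interface are illustrative, not the lite::tensor API):

#include <cstdint>
#include <vector>

// Mirrors the per-tensor branch of RestoreFilter: q -> (q - zero_point) * scale.
std::vector<float> DequantPerTensor(const std::vector<uint8_t> &quant, float scale, int zero_point) {
  std::vector<float> out(quant.size());
  for (size_t i = 0; i < quant.size(); ++i) {
    out[i] = static_cast<float>((static_cast<int>(quant[i]) - zero_point) * scale);
  }
  return out;
}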
@@ -60,6 +60,7 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
   int SetQuantMultiplier();
   int CheckResizeValid();
   void FreeQuantParam();
+  static int RestoreFilter(lite::tensor::Tensor *input_tensor);

  protected:
   int tile_num_;
@@ -239,6 +239,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
     CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
   }

+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->Data();
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
+  }
+
   kernel::LiteKernel *kernel;
   if (kernel_h == 1 && kernel_w == 1) {
     kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(op_parameter, inputs, outputs, ctx, primitive);
@@ -263,6 +269,12 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
     return nullptr;
   }
+
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
+
   return kernel;
 }

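Both conv creator hunks above (and the depthwise ones below) follow the same save/dequantize/init/restore pattern: stash the quantized weight pointer, let RestoreFilter install a temporary float buffer for kernel initialization (the kernel packs its own copy of the weights), then free the temporary buffer and put the original quantized data back. A condensed sketch of that control flow, using stand-in types rather than the real lite::tensor::Tensor and LiteKernel classes:

// Stand-in tensor; only the three calls the pattern needs.
struct FakeTensor {
  void *data = nullptr;
  void *Data() { return data; }
  void SetData(void *d) { data = d; }
  void FreeData() { /* the real class frees its current buffer */ }
};

template <typename InitKernelFn>
int CreateWithWeightQuant(FakeTensor *weight, bool is_weight_quant, InitKernelFn init_kernel) {
  void *restore_data = weight->Data();  // 1. stash the quantized weights
  if (is_weight_quant) {
    // 2. dequantize into a temporary float buffer (RestoreFilter in the real code)
  }
  int ret = init_kernel();              // 3. kernel init packs the float weights
  if (is_weight_quant) {
    weight->FreeData();                 // 4. drop the temporary float buffer
    weight->SetData(restore_data);      // 5. restore the quantized weights
  }
  return ret;
}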
@@ -131,6 +131,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
                                                const mindspore::lite::PrimitiveC *primitive) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
+
+  auto *weight_tensor = inputs.at(kWeightIndex);
+  auto *restore_data = weight_tensor->Data();
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    ConvolutionBaseCPUKernel::RestoreFilter(inputs.at(kWeightIndex));
+  }
+
   auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
   kernel::LiteKernel *kernel;
   if (conv_param->input_channel_ < 32) {
@@ -149,6 +156,12 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::tensor::T
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     return nullptr;
   }
+
+  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+    weight_tensor->FreeData();
+    weight_tensor->SetData(restore_data);
+  }
+
   return kernel;
 }

@@ -64,7 +64,8 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
   MS_ASSERT(dst_node != nullptr);
   // add quant param
   dst_node->quantType = primitive->GetQuantType();
-  if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining) {
+  if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining
+      || dst_node->quantType == schema::QuantType_WeightQuant) {
     MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam";
     // activation
     auto input_quant_params = primitive->GetInputQuantParams();
@@ -103,7 +104,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &me
       }
     } else {
       for (auto output_quant_param : output_quant_params[0]) {
-        if (tensor_output->quantParams.empty()) {
+        if (tensor_output->quantParams.empty() && dst_node->quantType != schema::QuantType_WeightQuant) {
           std::unique_ptr<schema::QuantParamT> output_quant_param_ptr =
             std::make_unique<schema::QuantParamT>(output_quant_param);
           MS_LOG(DEBUG) << "[output]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale
@@ -26,6 +26,7 @@
 #include "tools/optimizer/fusion/constant_folding_fusion.h"
 #include "tools/converter/quantizer/post_training_quantizer.h"
 #include "tools/converter/quantizer/quant_cast.h"
+#include "tools/converter/quantizer/weight_quantizer.h"

 using std::string;
 namespace mindspore {
@@ -57,11 +58,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
   FuncGraphPtr new_graph = optimizer->Optimize(old_graph);

   // quant
-  if (config != nullptr && config->quantType == schema::QuantType_PostTraining) {
-    this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
-    if (mQuantizer == nullptr) {
-      MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
-      return nullptr;
+  if (config != nullptr) {
+    if (config->quantType == schema::QuantType_PostTraining) {
+      this->mQuantizer = std::make_unique<quant::PostTrainingQuantizer>(new_graph, config->configFile, 8);
+      if (mQuantizer == nullptr) {
+        MS_LOG(ERROR) << "New PostTrainingQuantizer failed";
+        return nullptr;
+      }
+    } else if (config->quantType == schema::QuantType_WeightQuant) {
+      this->mQuantizer = std::make_unique<quant::WeightQuantizer>(new_graph, config->quantSize,
+                                                                  config->convWeightQuantChannelThreshold, config->bitNum);
+      if (mQuantizer == nullptr) {
+        MS_LOG(ERROR) << "New WeightQuantizer failed";
+        return nullptr;
+      }
     }
   }
   if (mQuantizer != nullptr) {
@@ -71,12 +81,14 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
       MS_LOG(ERROR) << "Quant failed " << status;
       return nullptr;
     }
-    quant::QuantCast quant_cast;
-    quant_cast.SetInputDataDType(kNumberTypeFloat32);
-    status = quant_cast.Run(new_graph);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "add QuantCast error";
-      return nullptr;
+    if (config->quantType == schema::QuantType_PostTraining) {
+      quant::QuantCast quant_cast;
+      quant_cast.SetInputDataDType(kNumberTypeFloat32);
+      status = quant_cast.Run(new_graph);
+      if (status != RET_OK) {
+        MS_LOG(ERROR) << "add QuantCast error";
+        return nullptr;
+      }
     }
   }

@@ -36,6 +36,8 @@ Flags::Flags() {
   AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
   AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
   AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
+  AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold",
+          "Channel threshold for convolution weight quantization", "16");
   AddFlag(&Flags::configFile, "config_file", "Configuration for post-training.", "");
   AddFlag(&Flags::formatTrans, "formatTrans", "whether transform format. true | false", "true");
 }
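For orientation, these flags drive weight quantization from the converter command line; a hypothetical invocation (the converter_lite binary name and the quantType/bitNum flags are assumptions — the transform code reads config->quantType and config->bitNum, but this hunk does not show those flags being registered):

./converter_lite --fmk=TFLITE --modelFile=model.tflite --outputFile=model \
    --quantType=WeightQuant --quantSize=500 --convWeightQuantChannelThreshold=16 --bitNum=8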
@@ -191,6 +191,7 @@ STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr<CNodeT> &n
   switch (this->quantType) {
     case QuantType_AwareTraining:
     case QuantType_PostTraining:
+    case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
       if (opType == schema::PrimitiveType_Conv2D) {
         weightTensor->format = schema::Format_KHWC;
@@ -31,7 +31,7 @@ void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat =

 STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
-  if (this->quantType == QuantType_AwareTraining) {
+  if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) {
     auto status = QuantDataFormatTrans(graph);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
@@ -11,6 +11,7 @@ add_library(quantizer_mid OBJECT
         ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
        ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
         )

 if(ENABLE_ASAN)
@@ -530,7 +530,8 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
     return RET_ERROR;
   }
   auto status =
-    QuantFilter(paramValue, primitive_c, QuantType_PostTraining, quant_max, quant_min, bit_num, perchanel, depthwise);
+    QuantFilter<int8_t>(paramValue, primitive_c, QuantType_PostTraining, quant_max,
+                        quant_min, bit_num, perchanel, depthwise);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
@@ -279,171 +279,6 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
   return RET_OK;
 }

-STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
-                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
-  auto dims = weight->tensor_shape();
-  if (per_channel) {
-    if (dims.size() != 4) {
-      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << " Back to per layer.";
-      per_channel = false;
-    } else {
-      uint32_t channels = dims[0];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is 0";
-        return RET_ERROR;
-      }
-    }
-  }
-
-  vector<schema::QuantParamT> quant_params;
-  size_t elem_count = weight->tensor_shape_size();
-  auto *raw_datas = static_cast<float *>(weight->tensor_addr());
-  if (raw_datas == nullptr) {
-    MS_LOG(ERROR) << "rawDatas is nullptr";
-    return RET_ERROR;
-  }
-  vector<int8_t> quant_datas(elem_count);
-
-  if (per_channel) {
-    // notice:
-    // at now for tflite model, Conv2D's weight format is KHWC, so is DepthwiseConv2D
-    // if TransWeightFormat is done before PostTraingingQuantization, the DepthwiseCon2D's weight is CHWK
-    if (depth_wise) {
-      // channel at last
-      auto channels = dims[3];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is zero";
-        return RET_ERROR;
-      }
-      size_t one_filter_size = elem_count / channels;
-
-      for (int i = 0; i < channels; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-        // find min and max
-        for (size_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          min = std::min(min, raw_datas[index]);
-          max = std::max(max, raw_datas[index]);
-        }
-        schema::QuantParamT quant_param;
-        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-        if (status != RET_OK) {
-          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-          return status;
-        }
-        quant_params.emplace_back(quant_param);
-        // do quantization
-        for (uint32_t j = 0; j < one_filter_size; j++) {
-          auto index = i + j * channels;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          float raw_data = raw_datas[index];
-          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
-          quant_datas[index] = quant_data;
-        }
-      }
-      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
-                          elem_count * sizeof(int8_t));
-      if (ret != EOK) {
-        MS_LOG(ERROR) << "memcpy error: " << ret;
-        return RET_ERROR;
-      }
-      weight->set_tensor_size(elem_count * sizeof(int8_t));
-    } else {
-      // channel at first
-      auto channels = dims[0];
-      if (channels == 0) {
-        MS_LOG(ERROR) << "channels is zero";
-        return RET_ERROR;
-      }
-      size_t one_filter_size = elem_count / channels;
-
-      for (int i = 0; i < channels; i++) {
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-        // find min and max
-        for (size_t j = 0; j < one_filter_size; j++) {
-          auto index = j + i * one_filter_size;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          min = std::min(min, raw_datas[index]);
-          max = std::max(max, raw_datas[index]);
-        }
-        schema::QuantParamT quant_param;
-        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-        if (status != RET_OK) {
-          MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-          return status;
-        }
-        quant_params.emplace_back(quant_param);
-        // do quantization
-        for (uint32_t j = 0; j < one_filter_size; j++) {
-          auto index = j + i * one_filter_size;
-          if (index >= elem_count) {
-            MS_LOG(ERROR) << "over flow!";
-            return RET_ERROR;
-          }
-          float raw_data = raw_datas[index];
-          auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
-          quant_datas[index] = quant_data;
-        }
-      }
-      auto ret =
-        memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
-      if (ret != EOK) {
-        MS_LOG(ERROR) << "memcpy error: " << ret;
-        return RET_ERROR;
-      }
-      weight->set_tensor_size(elem_count * sizeof(int8_t));
-    }
-
-  } else {
-    // per layer
-    float min = FLT_MAX;
-    float max = -FLT_MIN;
-    for (uint32_t i = 0; i < elem_count; i++) {
-      // find max min
-      min = std::min(min, raw_datas[i]);
-      max = std::max(max, raw_datas[i]);
-    }
-
-    schema::QuantParamT quant_param;
-    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
-      return status;
-    }
-    quant_params.emplace_back(quant_param);
-    // update data and datatype
-    for (uint32_t i = 0; i < elem_count; i++) {
-      float raw_data = raw_datas[i];
-      auto quant_data = QuantizeData<int8_t>(raw_data, quant_param, quant_max, quant_min);
-      quant_datas[i] = quant_data;
-    }
-    auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(int8_t));
-    if (ret != EOK) {
-      MS_LOG(ERROR) << "memcpy error: " << ret;
-      return RET_ERROR;
-    }
-    weight->set_tensor_size(elem_count * sizeof(int8_t));
-  }
-  if (quant_params.empty()) {
-    MS_LOG(ERROR) << "quant_params empty";
-    return RET_ERROR;
-  }
-  primitive_c->AddInputQuantParam(quant_params);
-  return RET_OK;
-}
-
 STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) {
   auto *rawDatas = reinterpret_cast<uint8_t *>(weight);
   vector<uint8_t> qDatas(rawDatas, rawDatas + shapeSize);
@@ -21,6 +21,8 @@
 #include <string>
 #include <cmath>
 #include <array>
+#include <vector>
+#include <algorithm>
 #include "tools/converter/quantizer/quantizer.h"
 #include "src/ops/primitive_c.h"
 #include "include/errorcode.h"
@@ -117,10 +119,171 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
     return static_cast<T>(quant_data);
   }();
 }

+template <typename T>
 STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
-                   int quant_max, int quant_min, size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false,
-                   bool depth_wise = false);
+                   int quant_max, int quant_min, size_t bitNum, bool per_channel, bool depth_wise) {
+  auto dims = weight->tensor_shape();
+  if (per_channel) {
+    if (dims.size() != 4) {
+      MS_LOG(ERROR) << "weight dims size error: " << dims.size() << ", falling back to per-layer.";
+      per_channel = false;
+    } else {
+      uint32_t channels = dims[0];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is 0";
+        return RET_ERROR;
+      }
+    }
+  }
+
+  std::vector<schema::QuantParamT> quant_params;
+  size_t elem_count = weight->tensor_shape_size();
+  auto *raw_datas = static_cast<float *>(weight->tensor_addr());
+  if (raw_datas == nullptr) {
+    MS_LOG(ERROR) << "raw_datas is nullptr";
+    return RET_ERROR;
+  }
+  std::vector<T> quant_datas(elem_count);
+
+  if (per_channel) {
+    // Note: for tflite models, Conv2D's weight format is currently KHWC, as is DepthwiseConv2D's.
+    // If TransWeightFormat runs before post-training quantization, the DepthwiseConv2D weight is CHWK.
+    if (depth_wise) {
+      // channel at last
+      auto channels = dims[3];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is zero";
+        return RET_ERROR;
+      }
+      size_t one_filter_size = elem_count / channels;
+
+      for (int i = 0; i < channels; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+        // find min and max
+        for (size_t j = 0; j < one_filter_size; j++) {
+          auto index = i + j * channels;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "overflow!";
+            return RET_ERROR;
+          }
+          min = std::min(min, raw_datas[index]);
+          max = std::max(max, raw_datas[index]);
+        }
+        schema::QuantParamT quant_param;
+        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "CalQuantizationParams failed: " << status;
+          return status;
+        }
+        quant_params.emplace_back(quant_param);
+        // do quantization
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = i + j * channels;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "overflow!";
+            return RET_ERROR;
+          }
+          float raw_data = raw_datas[index];
+          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
+          quant_datas[index] = quant_data;
+        }
+      }
+      auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(),
+                          elem_count * sizeof(T));
+      if (ret != EOK) {
+        MS_LOG(ERROR) << "memcpy error: " << ret;
+        return RET_ERROR;
+      }
+      weight->set_tensor_size(elem_count * sizeof(T));
+    } else {
+      // channel at first
+      auto channels = dims[0];
+      if (channels == 0) {
+        MS_LOG(ERROR) << "channels is zero";
+        return RET_ERROR;
+      }
+      size_t one_filter_size = elem_count / channels;
+
+      for (int i = 0; i < channels; i++) {
+        float min = FLT_MAX;
+        float max = -FLT_MAX;
+        // find min and max
+        for (size_t j = 0; j < one_filter_size; j++) {
+          auto index = j + i * one_filter_size;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "overflow!";
+            return RET_ERROR;
+          }
+          min = std::min(min, raw_datas[index]);
+          max = std::max(max, raw_datas[index]);
+        }
+        schema::QuantParamT quant_param;
+        STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+        if (status != RET_OK) {
+          MS_LOG(ERROR) << "CalQuantizationParams failed: " << status;
+          return status;
+        }
+        quant_params.emplace_back(quant_param);
+        // do quantization
+        for (uint32_t j = 0; j < one_filter_size; j++) {
+          auto index = j + i * one_filter_size;
+          if (index >= elem_count) {
+            MS_LOG(ERROR) << "overflow!";
+            return RET_ERROR;
+          }
+          float raw_data = raw_datas[index];
+          auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
+          quant_datas[index] = quant_data;
+        }
+      }
+      auto ret =
+        memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
+      if (ret != EOK) {
+        MS_LOG(ERROR) << "memcpy error: " << ret;
+        return RET_ERROR;
+      }
+      weight->set_tensor_size(elem_count * sizeof(T));
+    }
+
+  } else {
+    // per layer
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+    for (uint32_t i = 0; i < elem_count; i++) {
+      // find max min
+      min = std::min(min, raw_datas[i]);
+      max = std::max(max, raw_datas[i]);
+    }
+
+    schema::QuantParamT quant_param;
+    STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "CalQuantizationParams failed: " << status;
+      return status;
+    }
+    quant_params.emplace_back(quant_param);
+    // update data and datatype
+    for (uint32_t i = 0; i < elem_count; i++) {
+      float raw_data = raw_datas[i];
+      auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
+      quant_datas[i] = quant_data;
+    }
+    auto ret = memcpy_s(raw_datas, weight->tensor_size(), quant_datas.data(), elem_count * sizeof(T));
+    if (ret != EOK) {
+      MS_LOG(ERROR) << "memcpy error: " << ret;
+      return RET_ERROR;
+    }
+    weight->set_tensor_size(elem_count * sizeof(T));
+  }
+  if (quant_params.empty()) {
+    MS_LOG(ERROR) << "quant_params empty";
+    return RET_ERROR;
+  }
+  primitive_c->AddInputQuantParam(quant_params);
+  return RET_OK;
+}

 STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);
 }  // namespace quant
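QuantFilter above delegates the actual math to CalQuantizationParams and QuantizeData<T>, neither of which is shown in this diff. For orientation, a sketch of the usual asymmetric affine scheme they appear to implement (assumed formulas, simplified; the real helpers also deal with narrow ranges and a degenerate min == max):

#include <algorithm>
#include <cmath>

struct Affine { double scale; int zero_point; };

// scale maps the real range [min, max] onto the integer range [quant_min, quant_max];
// zero_point is the integer that represents real 0. Assumes max > min.
Affine CalcAffine(float min, float max, int quant_min, int quant_max) {
  double scale = (static_cast<double>(max) - min) / (quant_max - quant_min);
  int zero_point = static_cast<int>(std::round(quant_min - min / scale));
  return {scale, zero_point};
}

int Quantize(float x, const Affine &p, int quant_min, int quant_max) {
  int q = static_cast<int>(std::round(x / p.scale)) + p.zero_point;
  return std::max(quant_min, std::min(quant_max, q));  // clamp to the integer range
}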
@@ -0,0 +1,148 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tools/converter/quantizer/weight_quantizer.h"
+#include <list>
+#include <string>
+#include <vector>
+#include "src/common/common.h"
+#include "ir/dtype/type_id.h"
+
+using std::string;
+using std::vector;
+
+namespace mindspore {
+namespace lite {
+namespace quant {
+
+WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize,
+                                 const std::string &convWeightChannelThreshold, const std::string &bitNum)
+    : Quantizer(graph) {
+  auto quantSize = static_cast<size_t>(std::stoull(weightSize));
+  this->bitNum = static_cast<size_t>(std::stoull(bitNum));
+  auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
+  mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold));
+}
+
+STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
+  for (auto &cnode : nodes) {
+    if (!mStrategy->CanConvOpQuantized(cnode)) {
+      continue;
+    }
+
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
+      return RET_ERROR;
+    }
+
+    auto inputNode = cnode->input(2);
+    if (!inputNode->isa<Parameter>()) {
+      return RET_ERROR;
+    }
+
+    auto paramNode = inputNode->cast<ParameterPtr>();
+    if (!paramNode->has_default()) {
+      return RET_ERROR;
+    }
+
+    std::vector<schema::QuantParamT> quant_params;
+    primitive_c->AddInputQuantParam(quant_params);
+
+    auto op_type = (schema::PrimitiveType)primitive_c->Type();
+    bool depthwise = (op_type == schema::PrimitiveType_DepthwiseConv2D);
+
+    ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
+    auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0,
+                                       bitNum, true, depthwise);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "QuantFilter failed: " << status;
+      return status;
+    }
+    param_value->set_tensor_type(kNumberTypeUInt8);
+    primitive_c->SetQuantType(schema::QuantType_WeightQuant);
+  }
+
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
+  for (auto &node : nodes) {
+    if (!mStrategy->CanMulOpQuantized(node)) {
+      continue;
+    }
+
+    ParamValueLitePtr param_value = nullptr;
+    for (size_t i = 1; i < node->size(); i++) {
+      auto inputNode = node->input(i);
+      if (inputNode->isa<Parameter>()) {
+        auto paramNode = inputNode->cast<ParameterPtr>();
+        if ((paramNode != nullptr) && paramNode->has_default()) {
+          param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
+          if ((param_value == nullptr) || (param_value->tensor_size() == 0)
+              || (param_value->tensor_shape().size() != 4)
+              || (param_value->tensor_addr() == nullptr)
+              || (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
+            param_value = nullptr;
+            continue;
+          } else {
+            break;
+          }
+        }
+      }
+    }
+    if (param_value == nullptr) {
+      MS_LOG(ERROR) << "no valid input param node!";
+      return RET_ERROR;
+    }
+
+    auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
+    if (primitive_c == nullptr) {
+      MS_LOG(ERROR) << "primitive_c is nullptr";
+      return RET_ERROR;
+    }
+
+    auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "QuantFilter failed: " << status;
+      return status;
+    }
+    param_value->set_tensor_type(kNumberTypeUInt8);
+    primitive_c->SetQuantType(schema::QuantType_WeightQuant);
+  }
+
+  return RET_OK;
+}
+
+STATUS WeightQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
+  auto ret = RET_OK;
+  auto cnodes = funcGraph->GetOrderedCnodes();
+  ret = DoConvQuantize(cnodes);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "DoConvQuantize failed: " << ret;
+    return ret;
+  }
+  ret = DoMulQuantize(cnodes);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "DoMulQuantize failed: " << ret;
+    return ret;
+  }
+  return ret;
+}
+}  // namespace quant
+}  // namespace lite
+}  // namespace mindspore
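As wired up in AnfTransform earlier in this commit, the quantizer receives its thresholds as strings straight from the parsed converter flags. A minimal hypothetical standalone driver (the example values "0" for the weight size threshold, "16" for the conv channel threshold and "8" for the bit width are placeholders):

#include <memory>
#include "tools/converter/quantizer/weight_quantizer.h"

// Runs weight quantization over an already-built func_graph; this mirrors the
// AnfTransform wiring above rather than adding any new entry point.
mindspore::lite::STATUS RunWeightQuant(const mindspore::FuncGraphPtr &func_graph) {
  auto quantizer = std::make_unique<mindspore::lite::quant::WeightQuantizer>(func_graph, "0", "16", "8");
  return quantizer->DoQuantize(func_graph);
}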
@@ -0,0 +1,53 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef WEIGHT_QUANTIZER_H
+#define WEIGHT_QUANTIZER_H
+
+#include <memory>
+#include <list>
+#include <string>
+#include "tools/converter/quantizer/quantizer.h"
+#include "tools/converter/quantizer/quantize_util.h"
+#include "ir/func_graph.h"
+#include "ir/anf.h"
+#include "include/model.h"
+#include "base/base.h"
+#include "abstract/dshape.h"
+
+namespace mindspore {
+namespace lite {
+namespace quant {
+class WeightQuantizer : public Quantizer {
+ public:
+  WeightQuantizer(FuncGraphPtr graph, const std::string &weightSize,
+                  const std::string &convWeightChannelThreshold, const std::string &bitNum);
+
+  ~WeightQuantizer() = default;
+
+  STATUS DoQuantize(FuncGraphPtr funcGraph) override;
+  STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
+  STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
+
+ private:
+  std::unique_ptr<QuantStrategy> mStrategy;
+  size_t bitNum;
+};
+}  // namespace quant
+}  // namespace lite
+}  // namespace mindspore
+#endif