forked from mindspore-Ecosystem/mindspore
fix matmul quantizationt
This commit is contained in:
parent
6b39c89da7
commit
72daa10df6
|
@ -18,9 +18,10 @@
|
|||
#include <memory>
|
||||
#include "src/dequant.h"
|
||||
#include "src/huffman_decode.h"
|
||||
#include "src/ops/matmul.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
|
||||
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first) {
|
||||
MS_ASSERT(input_tensor != nullptr);
|
||||
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
|
||||
MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
|
||||
|
@ -31,9 +32,9 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
return nullptr;
|
||||
}
|
||||
if (input_tensor->data_type() == kNumberTypeInt16) {
|
||||
return DequantData<int16_t>(input_tensor);
|
||||
return DequantData<int16_t>(input_tensor, channel_first);
|
||||
} else {
|
||||
return DequantData<int8_t>(input_tensor);
|
||||
return DequantData<int8_t>(input_tensor, channel_first);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,19 +66,35 @@ int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_in
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
|
||||
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const mindspore::lite::PrimitiveC *primitive,
|
||||
const std::vector<Tensor *> &in_tensors,
|
||||
TypeId data_type, bool need_restore) {
|
||||
std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
|
||||
if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
|
||||
auto input_i = 0;
|
||||
for (auto weight_tensor : in_tensors) {
|
||||
MS_ASSERT(weight_tensor != nullptr);
|
||||
input_i++;
|
||||
auto channel_first = true;
|
||||
if ((schema::PrimitiveType)primitive->Type() == schema::PrimitiveType_MatMul &&
|
||||
weight_tensor->shape().size() == 2) {
|
||||
auto param = reinterpret_cast<mindspore::lite::MatMul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
if (input_i == 1) {
|
||||
channel_first = !param->GetTransposeA();
|
||||
} else if (input_i == 2) {
|
||||
channel_first = param->GetTransposeB();
|
||||
} else {
|
||||
MS_LOG(WARNING) << "unexpected input_i";
|
||||
}
|
||||
}
|
||||
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
|
||||
restore_data != nullptr &&
|
||||
(restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
|
||||
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor, channel_first);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return tensor_origin_data;
|
||||
|
|
|
@ -29,17 +29,18 @@
|
|||
namespace mindspore::lite {
|
||||
class DequantUtil {
|
||||
public:
|
||||
static float *DequantWeight(lite::Tensor *input_tensor);
|
||||
static float *DequantWeight(lite::Tensor *input_tensor, bool);
|
||||
|
||||
static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
|
||||
|
||||
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
|
||||
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const mindspore::lite::PrimitiveC *primitive,
|
||||
const std::vector<Tensor *> &in_tensors,
|
||||
TypeId data_type, bool need_restore = true);
|
||||
|
||||
static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
|
||||
|
||||
template <typename ST, typename DT = float>
|
||||
static DT *DequantData(lite::Tensor *input_tensor) {
|
||||
static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
|
||||
const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
|
||||
if (quant_datas == nullptr) {
|
||||
MS_LOG(ERROR) << "Get quant tensor failed.";
|
||||
|
@ -65,6 +66,13 @@ class DequantUtil {
|
|||
}
|
||||
} else if (input_tensor->quant_params().size() != kPerTensor) {
|
||||
auto channels = static_cast<size_t>(input_tensor->Batch());
|
||||
if (!channel_first) {
|
||||
if (input_tensor->shape().size() != 2) {
|
||||
MS_LOG(ERROR) << "unexpected shape size: " << input_tensor->shape().size();
|
||||
return nullptr;
|
||||
}
|
||||
channels = input_tensor->shape()[1];
|
||||
}
|
||||
if (input_tensor->quant_params().size() != channels) {
|
||||
MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels;
|
||||
free(dequant_datas);
|
||||
|
@ -83,8 +91,12 @@ class DequantUtil {
|
|||
var_corr = 1;
|
||||
}
|
||||
for (size_t j = 0; j < per_channel_size; j++) {
|
||||
auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale;
|
||||
dequant_datas[per_channel_size * i + j] = static_cast<DT>(dequant_data * var_corr + mean_corr);
|
||||
auto index = per_channel_size * i + j;
|
||||
if (!channel_first) {
|
||||
index = channels * j + i;
|
||||
}
|
||||
auto dequant_data = (quant_datas[index] - zero_point) * scale;
|
||||
dequant_datas[index] = static_cast<DT>(dequant_data * var_corr + mean_corr);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -223,7 +223,8 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
if (mindspore::lite::IsSupportFloat16() &&
|
||||
((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
|
||||
kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
|
||||
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore);
|
||||
auto tensor_origin_data_map =
|
||||
DequantUtil::DequantTensor(primitive, in_tensors, fp16_cpu_desc.data_type, need_restore);
|
||||
auto *kernel =
|
||||
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
|
||||
DequantUtil::RestoreTensorData(tensor_origin_data_map);
|
||||
|
@ -237,7 +238,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
|
||||
desc.data_type = kNumberTypeFloat32;
|
||||
}
|
||||
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore);
|
||||
auto tensor_origin_data_map = DequantUtil::DequantTensor(primitive, in_tensors, desc.data_type, need_restore);
|
||||
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
|
||||
DequantUtil::RestoreTensorData(tensor_origin_data_map);
|
||||
if (kernel != nullptr) {
|
||||
|
|
|
@ -358,15 +358,15 @@ static bool SearchUpperBound(const std::vector<float> &data, const size_t &index
|
|||
return true;
|
||||
}
|
||||
|
||||
static float CalPercentile(const std::vector<float> &datas, const int &outlier_percent) {
|
||||
const int size = datas.size();
|
||||
static float CalPercentile(const std::vector<float> &data, const int &outlier_percent) {
|
||||
const int size = data.size();
|
||||
float val = outlier_percent / 100.0 * size;
|
||||
int index = std::ceil(val);
|
||||
float result;
|
||||
if (index - val > 0) {
|
||||
result = datas.at(index - 1);
|
||||
result = data.at(index - 1);
|
||||
} else {
|
||||
result = (datas.at(index - 1) + datas.at(index)) / 2;
|
||||
result = (data.at(index - 1) + data.at(index)) / 2;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -522,11 +522,78 @@ std::vector<std::vector<int>> DataToVectors(const string &str) {
|
|||
return result;
|
||||
}
|
||||
|
||||
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
|
||||
if (post_quant_config == nullptr) {
|
||||
MS_LOG(ERROR) << "post_quant_config is null.";
|
||||
return RET_PARAM_INVALID;
|
||||
void ParseInputShape(PostQuantConfig *post_quant_config, std::string raw_shape) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
auto ind = raw_shape.find('/');
|
||||
while (ind != std::string::npos) {
|
||||
auto shape = raw_shape.substr(0, ind);
|
||||
Trim(&shape);
|
||||
post_quant_config->input_shapes.push_back(DataToVectors(shape));
|
||||
raw_shape = raw_shape.substr(ind + 1);
|
||||
Trim(&raw_shape);
|
||||
ind = raw_shape.find('/');
|
||||
}
|
||||
if (!raw_shape.empty()) {
|
||||
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
|
||||
}
|
||||
}
|
||||
|
||||
void ParseImagePath(PostQuantConfig *post_quant_config, std::string raw_image_paths) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
auto ind = raw_image_paths.find(',');
|
||||
while (ind != std::string::npos) {
|
||||
auto image_path = raw_image_paths.substr(0, ind);
|
||||
Trim(&image_path);
|
||||
post_quant_config->image_paths.push_back(image_path);
|
||||
raw_image_paths = raw_image_paths.substr(ind + 1);
|
||||
Trim(&raw_image_paths);
|
||||
ind = raw_image_paths.find(',');
|
||||
}
|
||||
post_quant_config->image_paths.push_back(raw_image_paths);
|
||||
}
|
||||
|
||||
void ParseBatchCount(PostQuantConfig *post_quant_config, std::string value) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
post_quant_config->batch_count = std::stoul(value);
|
||||
}
|
||||
|
||||
void ParseThreadNum(PostQuantConfig *post_quant_config, std::string value) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
post_quant_config->thread_num = std::stoul(value);
|
||||
}
|
||||
|
||||
void ParseMethodX(PostQuantConfig *post_quant_config, const std::string &value) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
|
||||
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
|
||||
} else {
|
||||
post_quant_config->method_x = value;
|
||||
}
|
||||
}
|
||||
|
||||
void ParseMixed(PostQuantConfig *post_quant_config, std::string value) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
std::for_each(value.begin(), value.end(), ::tolower);
|
||||
if (value == "true") {
|
||||
post_quant_config->mixed = true;
|
||||
}
|
||||
}
|
||||
|
||||
void ParseMeanErrorThreshold(PostQuantConfig *post_quant_config, std::string value) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
post_quant_config->mean_error_threshold = std::stof(value);
|
||||
}
|
||||
|
||||
void ParseBiasCorrection(PostQuantConfig *post_quant_config, std::string value) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
std::for_each(value.begin(), value.end(), ::tolower);
|
||||
if (value == "true") {
|
||||
post_quant_config->bias_correction = true;
|
||||
}
|
||||
}
|
||||
|
||||
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
|
||||
MS_ASSERT(post_quant_config != nullptr);
|
||||
|
||||
if (config_file.empty() || config_file.length() > PATH_MAX) {
|
||||
MS_LOG(ERROR) << "invalid config path!";
|
||||
|
@ -552,6 +619,26 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
|
|||
MS_LOG(ERROR) << "config file open failed: " << config_file;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
|
||||
std::string INPUT_SHAPES = "input_shapes";
|
||||
std::string IMAGE_PATH = "image_path";
|
||||
std::string BATCH_COUNT = "batch_count";
|
||||
std::string THREAD_NUM = "thread_num";
|
||||
std::string METHOD_X = "method_x";
|
||||
std::string MIXED = "mixed";
|
||||
std::string MEAN_ERROR_THRESHOLD = "mean_error_threshold";
|
||||
std::string BIAS_CORRECTION = "bias_correction";
|
||||
|
||||
std::map<std::string, std::function<void(PostQuantConfig *, std::string)>> value_parser;
|
||||
value_parser[INPUT_SHAPES] = ParseInputShape;
|
||||
value_parser[IMAGE_PATH] = ParseImagePath;
|
||||
value_parser[BATCH_COUNT] = ParseBatchCount;
|
||||
value_parser[THREAD_NUM] = ParseThreadNum;
|
||||
value_parser[METHOD_X] = ParseMethodX;
|
||||
value_parser[MIXED] = ParseMixed;
|
||||
value_parser[MEAN_ERROR_THRESHOLD] = ParseMeanErrorThreshold;
|
||||
value_parser[BIAS_CORRECTION] = ParseBiasCorrection;
|
||||
|
||||
std::string line;
|
||||
while (std::getline(fs, line)) {
|
||||
Trim(&line);
|
||||
|
@ -567,54 +654,9 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
|
|||
auto value = line.substr(index + 1);
|
||||
Trim(&key);
|
||||
Trim(&value);
|
||||
if (key == "image_path") {
|
||||
auto &raw_image_paths = value;
|
||||
auto ind = raw_image_paths.find(',');
|
||||
while (ind != std::string::npos) {
|
||||
auto image_path = raw_image_paths.substr(0, ind);
|
||||
Trim(&image_path);
|
||||
post_quant_config->image_paths.push_back(image_path);
|
||||
raw_image_paths = raw_image_paths.substr(ind + 1);
|
||||
Trim(&raw_image_paths);
|
||||
ind = raw_image_paths.find(',');
|
||||
}
|
||||
post_quant_config->image_paths.push_back(raw_image_paths);
|
||||
} else if (key == "batch_count") {
|
||||
post_quant_config->batch_count = std::stoul(value);
|
||||
} else if (key == "thread_num") {
|
||||
post_quant_config->thread_num = std::stoul(value);
|
||||
} else if (key == "method_x") {
|
||||
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
|
||||
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
|
||||
} else {
|
||||
post_quant_config->method_x = value;
|
||||
}
|
||||
} else if (key == "bias_correction") {
|
||||
std::for_each(value.begin(), value.end(), ::tolower);
|
||||
if (value == "true") {
|
||||
post_quant_config->bias_correction = true;
|
||||
}
|
||||
} else if (key == "mixed") {
|
||||
std::for_each(value.begin(), value.end(), ::tolower);
|
||||
if (value == "true") {
|
||||
post_quant_config->mixed = true;
|
||||
}
|
||||
} else if (key == "mean_error_threshold") {
|
||||
post_quant_config->mean_error_threshold = std::stof(value);
|
||||
} else if (key == "input_shapes") {
|
||||
auto &raw_shape = value;
|
||||
auto ind = raw_shape.find('/');
|
||||
while (ind != std::string::npos) {
|
||||
auto shape = raw_shape.substr(0, ind);
|
||||
Trim(&shape);
|
||||
post_quant_config->input_shapes.push_back(DataToVectors(shape));
|
||||
raw_shape = raw_shape.substr(ind + 1);
|
||||
Trim(&raw_shape);
|
||||
ind = raw_shape.find('/');
|
||||
}
|
||||
if (!raw_shape.empty()) {
|
||||
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
|
||||
}
|
||||
auto it = value_parser.find(key);
|
||||
if (it != value_parser.end()) {
|
||||
it->second(post_quant_config, value);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "unsupported parameter: " << key;
|
||||
}
|
||||
|
@ -881,4 +923,24 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
|
||||
bool channel_at_first, float *desired_max, float *desired_min) {
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
// find min and max
|
||||
for (int j = 0; j < one_filter_size; j++) {
|
||||
auto index = j + i * one_filter_size;
|
||||
if (!channel_at_first) {
|
||||
index = j * channels + i;
|
||||
}
|
||||
if (index >= elem_count) {
|
||||
MS_LOG(ERROR) << "over flow!";
|
||||
}
|
||||
min = std::min(min, raw_datas[index]);
|
||||
max = std::max(max, raw_datas[index]);
|
||||
}
|
||||
*desired_max = max;
|
||||
*desired_min = min;
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -107,6 +107,9 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
|
|||
|
||||
STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size);
|
||||
|
||||
void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
|
||||
bool channel_at_first, float *desired_max, float *desired_min);
|
||||
|
||||
template <typename T>
|
||||
T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
|
||||
MS_ASSERT(quantParam != nullptr);
|
||||
|
@ -163,11 +166,19 @@ template <typename T>
|
|||
STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type,
|
||||
std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
|
||||
const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas,
|
||||
std::vector<float> *dequant_datas) {
|
||||
std::vector<float> *dequant_datas, bool channel_at_first = true) {
|
||||
auto dims = weight->tensor_shape();
|
||||
size_t elem_count = weight->tensor_shape_size();
|
||||
auto *raw_datas = static_cast<float *>(weight->tensor_addr());
|
||||
auto channels = dims[0];
|
||||
if (!channel_at_first) {
|
||||
if (dims.size() != 2) {
|
||||
MS_LOG(ERROR) << "unexpected dims size: " << dims.size();
|
||||
channel_at_first = true;
|
||||
} else {
|
||||
channels = dims[1];
|
||||
}
|
||||
}
|
||||
if (channels == 0) {
|
||||
MS_LOG(ERROR) << "channels is zero";
|
||||
return RET_ERROR;
|
||||
|
@ -181,16 +192,7 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
|
|||
for (int i = 0; i < channels; i++) {
|
||||
float min = FLT_MAX;
|
||||
float max = -FLT_MAX;
|
||||
// find min and max
|
||||
for (size_t j = 0; j < one_filter_size; j++) {
|
||||
auto index = j + i * one_filter_size;
|
||||
if (index >= elem_count) {
|
||||
MS_LOG(ERROR) << "over flow!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
min = std::min(min, raw_datas[index]);
|
||||
max = std::max(max, raw_datas[index]);
|
||||
}
|
||||
GetMaxMinPerchannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min);
|
||||
schema::QuantParamT quant_param;
|
||||
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
|
||||
if (status != RET_OK) {
|
||||
|
@ -202,10 +204,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
|
|||
double average_raw = 0;
|
||||
for (uint32_t j = 0; j < one_filter_size; j++) {
|
||||
auto index = j + i * one_filter_size;
|
||||
if (index >= elem_count) {
|
||||
MS_LOG(ERROR) << "over flow!";
|
||||
return RET_ERROR;
|
||||
if (!channel_at_first) {
|
||||
index = j * channels + i;
|
||||
}
|
||||
MS_ASSERT(index < elem_count);
|
||||
float raw_data = raw_datas[index];
|
||||
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
|
||||
(*quant_datas)[index] = quant_data;
|
||||
|
@ -226,10 +228,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
|
|||
double variance_raw = 0;
|
||||
for (uint32_t j = 0; j < one_filter_size; j++) {
|
||||
auto index = j + i * one_filter_size;
|
||||
if (index >= elem_count) {
|
||||
MS_LOG(ERROR) << "over flow!";
|
||||
return RET_ERROR;
|
||||
if (!channel_at_first) {
|
||||
index = j * channels + i;
|
||||
}
|
||||
MS_ASSERT(index < elem_count);
|
||||
variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2);
|
||||
variance_raw += std::pow(raw_datas[index] - average_raw, 2);
|
||||
}
|
||||
|
@ -339,20 +341,26 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
|
|||
|
||||
std::vector<schema::QuantParamT> quant_params;
|
||||
size_t elem_count = weight->tensor_shape_size();
|
||||
auto *raw_datas = static_cast<float *>(weight->tensor_addr());
|
||||
if (raw_datas == nullptr) {
|
||||
auto *raw_data = static_cast<float *>(weight->tensor_addr());
|
||||
if (raw_data == nullptr) {
|
||||
MS_LOG(ERROR) << "rawDatas is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
std::vector<T> quant_datas(elem_count);
|
||||
std::vector<T> quant_data(elem_count);
|
||||
std::vector<float> dequant_datas(elem_count);
|
||||
int ret = RET_OK;
|
||||
if (per_channel) {
|
||||
// notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC
|
||||
bool channel_at_first = true;
|
||||
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
||||
if (op_type == schema::PrimitiveType_MatMul && weight->tensor_shape().size() == 2) {
|
||||
auto matmul_op = primitive_c->primitiveT()->value.AsMatMul();
|
||||
MS_ASSERT(matmul_op != nullptr);
|
||||
channel_at_first = !(index == 1 && !matmul_op->transposeB);
|
||||
}
|
||||
// channel at first
|
||||
ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas,
|
||||
&dequant_datas);
|
||||
ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data,
|
||||
&dequant_datas, channel_at_first);
|
||||
if (ret == RET_CONTINUE) {
|
||||
return ret;
|
||||
} else if (ret != RET_OK) {
|
||||
|
@ -360,7 +368,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
|
|||
return ret;
|
||||
}
|
||||
} else {
|
||||
ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas);
|
||||
ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do per layer quant failed.";
|
||||
return ret;
|
||||
|
@ -376,7 +384,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
|
|||
}
|
||||
#else
|
||||
// do bit pack
|
||||
ret = DoBitPack(weight, bit_num, quant_datas);
|
||||
ret = DoBitPack(weight, bit_num, quant_data);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Do bit pack failed.";
|
||||
return ret;
|
||||
|
|
|
@ -127,6 +127,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
|
|||
auto already_quant = false;
|
||||
ParamValueLitePtr param_value = nullptr;
|
||||
ParameterPtr param_node = nullptr;
|
||||
int index = 0;
|
||||
for (size_t i = 1; i < cnode->size(); i++) {
|
||||
auto inputNode = cnode->input(i);
|
||||
if (inputNode->isa<Parameter>()) {
|
||||
|
@ -146,6 +147,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
|
|||
param_value = nullptr;
|
||||
continue;
|
||||
} else {
|
||||
index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -169,11 +171,11 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
|
|||
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
true, index - 1);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
true, index - 1);
|
||||
}
|
||||
if (status == RET_CONTINUE) {
|
||||
return RET_OK;
|
||||
|
|
Loading…
Reference in New Issue