fix matmul quantizationt

This commit is contained in:
xutianchun 2021-02-01 19:33:39 +08:00
parent 6b39c89da7
commit 72daa10df6
6 changed files with 199 additions and 97 deletions

View File

@ -18,9 +18,10 @@
#include <memory> #include <memory>
#include "src/dequant.h" #include "src/dequant.h"
#include "src/huffman_decode.h" #include "src/huffman_decode.h"
#include "src/ops/matmul.h"
namespace mindspore::lite { namespace mindspore::lite {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { float *DequantUtil::DequantWeight(lite::Tensor *input_tensor, bool channel_first) {
MS_ASSERT(input_tensor != nullptr); MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) { if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type(); MS_LOG(ERROR) << "Conv weight input type error." << input_tensor->data_type();
@ -31,9 +32,9 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
return nullptr; return nullptr;
} }
if (input_tensor->data_type() == kNumberTypeInt16) { if (input_tensor->data_type() == kNumberTypeInt16) {
return DequantData<int16_t>(input_tensor); return DequantData<int16_t>(input_tensor, channel_first);
} else { } else {
return DequantData<int8_t>(input_tensor); return DequantData<int8_t>(input_tensor, channel_first);
} }
} }
@ -65,19 +66,35 @@ int DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_in
return RET_OK; return RET_OK;
} }
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors, std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const mindspore::lite::PrimitiveC *primitive,
const std::vector<Tensor *> &in_tensors,
TypeId data_type, bool need_restore) { TypeId data_type, bool need_restore) {
std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data; std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) { if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
auto input_i = 0;
for (auto weight_tensor : in_tensors) { for (auto weight_tensor : in_tensors) {
MS_ASSERT(weight_tensor != nullptr); MS_ASSERT(weight_tensor != nullptr);
input_i++;
auto channel_first = true;
if ((schema::PrimitiveType)primitive->Type() == schema::PrimitiveType_MatMul &&
weight_tensor->shape().size() == 2) {
auto param = reinterpret_cast<mindspore::lite::MatMul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
if (input_i == 1) {
channel_first = !param->GetTransposeA();
} else if (input_i == 2) {
channel_first = param->GetTransposeB();
} else {
MS_LOG(WARNING) << "unexpected input_i";
}
}
auto *restore_data = weight_tensor->data_c(); auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type(); auto restore_type = weight_tensor->data_type();
bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
restore_data != nullptr && restore_data != nullptr &&
(restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16); (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
if (dequant_flag) { if (dequant_flag) {
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor); auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor, channel_first);
if (dequant_weight == nullptr) { if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr."; MS_LOG(ERROR) << "dequant data is nullptr.";
return tensor_origin_data; return tensor_origin_data;

View File

@ -29,17 +29,18 @@
namespace mindspore::lite { namespace mindspore::lite {
class DequantUtil { class DequantUtil {
public: public:
static float *DequantWeight(lite::Tensor *input_tensor); static float *DequantWeight(lite::Tensor *input_tensor, bool);
static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); static int UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors, static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const mindspore::lite::PrimitiveC *primitive,
const std::vector<Tensor *> &in_tensors,
TypeId data_type, bool need_restore = true); TypeId data_type, bool need_restore = true);
static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map); static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
template <typename ST, typename DT = float> template <typename ST, typename DT = float>
static DT *DequantData(lite::Tensor *input_tensor) { static DT *DequantData(lite::Tensor *input_tensor, bool channel_first = true) {
const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData()); const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
if (quant_datas == nullptr) { if (quant_datas == nullptr) {
MS_LOG(ERROR) << "Get quant tensor failed."; MS_LOG(ERROR) << "Get quant tensor failed.";
@ -65,6 +66,13 @@ class DequantUtil {
} }
} else if (input_tensor->quant_params().size() != kPerTensor) { } else if (input_tensor->quant_params().size() != kPerTensor) {
auto channels = static_cast<size_t>(input_tensor->Batch()); auto channels = static_cast<size_t>(input_tensor->Batch());
if (!channel_first) {
if (input_tensor->shape().size() != 2) {
MS_LOG(ERROR) << "unexpected shape size: " << input_tensor->shape().size();
return nullptr;
}
channels = input_tensor->shape()[1];
}
if (input_tensor->quant_params().size() != channels) { if (input_tensor->quant_params().size() != channels) {
MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels; MS_LOG(ERROR) << "Quant param not equal channel num " << input_tensor->quant_params().size() << channels;
free(dequant_datas); free(dequant_datas);
@ -83,8 +91,12 @@ class DequantUtil {
var_corr = 1; var_corr = 1;
} }
for (size_t j = 0; j < per_channel_size; j++) { for (size_t j = 0; j < per_channel_size; j++) {
auto dequant_data = (quant_datas[per_channel_size * i + j] - zero_point) * scale; auto index = per_channel_size * i + j;
dequant_datas[per_channel_size * i + j] = static_cast<DT>(dequant_data * var_corr + mean_corr); if (!channel_first) {
index = channels * j + i;
}
auto dequant_data = (quant_datas[index] - zero_point) * scale;
dequant_datas[index] = static_cast<DT>(dequant_data * var_corr + mean_corr);
} }
} }
} else { } else {

View File

@ -223,7 +223,8 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
if (mindspore::lite::IsSupportFloat16() && if (mindspore::lite::IsSupportFloat16() &&
((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) { ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type}; kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore); auto tensor_origin_data_map =
DequantUtil::DequantTensor(primitive, in_tensors, fp16_cpu_desc.data_type, need_restore);
auto *kernel = auto *kernel =
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc); KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map); DequantUtil::RestoreTensorData(tensor_origin_data_map);
@ -237,7 +238,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op."; MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32; desc.data_type = kNumberTypeFloat32;
} }
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore); auto tensor_origin_data_map = DequantUtil::DequantTensor(primitive, in_tensors, desc.data_type, need_restore);
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc); auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map); DequantUtil::RestoreTensorData(tensor_origin_data_map);
if (kernel != nullptr) { if (kernel != nullptr) {

View File

@ -358,15 +358,15 @@ static bool SearchUpperBound(const std::vector<float> &data, const size_t &index
return true; return true;
} }
static float CalPercentile(const std::vector<float> &datas, const int &outlier_percent) { static float CalPercentile(const std::vector<float> &data, const int &outlier_percent) {
const int size = datas.size(); const int size = data.size();
float val = outlier_percent / 100.0 * size; float val = outlier_percent / 100.0 * size;
int index = std::ceil(val); int index = std::ceil(val);
float result; float result;
if (index - val > 0) { if (index - val > 0) {
result = datas.at(index - 1); result = data.at(index - 1);
} else { } else {
result = (datas.at(index - 1) + datas.at(index)) / 2; result = (data.at(index - 1) + data.at(index)) / 2;
} }
return result; return result;
} }
@ -522,11 +522,78 @@ std::vector<std::vector<int>> DataToVectors(const string &str) {
return result; return result;
} }
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) { void ParseInputShape(PostQuantConfig *post_quant_config, std::string raw_shape) {
if (post_quant_config == nullptr) { MS_ASSERT(post_quant_config != nullptr);
MS_LOG(ERROR) << "post_quant_config is null."; auto ind = raw_shape.find('/');
return RET_PARAM_INVALID; while (ind != std::string::npos) {
auto shape = raw_shape.substr(0, ind);
Trim(&shape);
post_quant_config->input_shapes.push_back(DataToVectors(shape));
raw_shape = raw_shape.substr(ind + 1);
Trim(&raw_shape);
ind = raw_shape.find('/');
} }
if (!raw_shape.empty()) {
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
}
}
void ParseImagePath(PostQuantConfig *post_quant_config, std::string raw_image_paths) {
MS_ASSERT(post_quant_config != nullptr);
auto ind = raw_image_paths.find(',');
while (ind != std::string::npos) {
auto image_path = raw_image_paths.substr(0, ind);
Trim(&image_path);
post_quant_config->image_paths.push_back(image_path);
raw_image_paths = raw_image_paths.substr(ind + 1);
Trim(&raw_image_paths);
ind = raw_image_paths.find(',');
}
post_quant_config->image_paths.push_back(raw_image_paths);
}
void ParseBatchCount(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
post_quant_config->batch_count = std::stoul(value);
}
void ParseThreadNum(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
post_quant_config->thread_num = std::stoul(value);
}
void ParseMethodX(PostQuantConfig *post_quant_config, const std::string &value) {
MS_ASSERT(post_quant_config != nullptr);
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
post_quant_config->method_x = value;
}
}
void ParseMixed(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->mixed = true;
}
}
void ParseMeanErrorThreshold(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
post_quant_config->mean_error_threshold = std::stof(value);
}
void ParseBiasCorrection(PostQuantConfig *post_quant_config, std::string value) {
MS_ASSERT(post_quant_config != nullptr);
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->bias_correction = true;
}
}
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
MS_ASSERT(post_quant_config != nullptr);
if (config_file.empty() || config_file.length() > PATH_MAX) { if (config_file.empty() || config_file.length() > PATH_MAX) {
MS_LOG(ERROR) << "invalid config path!"; MS_LOG(ERROR) << "invalid config path!";
@ -552,6 +619,26 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
MS_LOG(ERROR) << "config file open failed: " << config_file; MS_LOG(ERROR) << "config file open failed: " << config_file;
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
std::string INPUT_SHAPES = "input_shapes";
std::string IMAGE_PATH = "image_path";
std::string BATCH_COUNT = "batch_count";
std::string THREAD_NUM = "thread_num";
std::string METHOD_X = "method_x";
std::string MIXED = "mixed";
std::string MEAN_ERROR_THRESHOLD = "mean_error_threshold";
std::string BIAS_CORRECTION = "bias_correction";
std::map<std::string, std::function<void(PostQuantConfig *, std::string)>> value_parser;
value_parser[INPUT_SHAPES] = ParseInputShape;
value_parser[IMAGE_PATH] = ParseImagePath;
value_parser[BATCH_COUNT] = ParseBatchCount;
value_parser[THREAD_NUM] = ParseThreadNum;
value_parser[METHOD_X] = ParseMethodX;
value_parser[MIXED] = ParseMixed;
value_parser[MEAN_ERROR_THRESHOLD] = ParseMeanErrorThreshold;
value_parser[BIAS_CORRECTION] = ParseBiasCorrection;
std::string line; std::string line;
while (std::getline(fs, line)) { while (std::getline(fs, line)) {
Trim(&line); Trim(&line);
@ -567,54 +654,9 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
auto value = line.substr(index + 1); auto value = line.substr(index + 1);
Trim(&key); Trim(&key);
Trim(&value); Trim(&value);
if (key == "image_path") { auto it = value_parser.find(key);
auto &raw_image_paths = value; if (it != value_parser.end()) {
auto ind = raw_image_paths.find(','); it->second(post_quant_config, value);
while (ind != std::string::npos) {
auto image_path = raw_image_paths.substr(0, ind);
Trim(&image_path);
post_quant_config->image_paths.push_back(image_path);
raw_image_paths = raw_image_paths.substr(ind + 1);
Trim(&raw_image_paths);
ind = raw_image_paths.find(',');
}
post_quant_config->image_paths.push_back(raw_image_paths);
} else if (key == "batch_count") {
post_quant_config->batch_count = std::stoul(value);
} else if (key == "thread_num") {
post_quant_config->thread_num = std::stoul(value);
} else if (key == "method_x") {
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
post_quant_config->method_x = value;
}
} else if (key == "bias_correction") {
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->bias_correction = true;
}
} else if (key == "mixed") {
std::for_each(value.begin(), value.end(), ::tolower);
if (value == "true") {
post_quant_config->mixed = true;
}
} else if (key == "mean_error_threshold") {
post_quant_config->mean_error_threshold = std::stof(value);
} else if (key == "input_shapes") {
auto &raw_shape = value;
auto ind = raw_shape.find('/');
while (ind != std::string::npos) {
auto shape = raw_shape.substr(0, ind);
Trim(&shape);
post_quant_config->input_shapes.push_back(DataToVectors(shape));
raw_shape = raw_shape.substr(ind + 1);
Trim(&raw_shape);
ind = raw_shape.find('/');
}
if (!raw_shape.empty()) {
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
}
} else { } else {
MS_LOG(WARNING) << "unsupported parameter: " << key; MS_LOG(WARNING) << "unsupported parameter: " << key;
} }
@ -881,4 +923,24 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int
return RET_OK; return RET_OK;
} }
void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
bool channel_at_first, float *desired_max, float *desired_min) {
float min = FLT_MAX;
float max = -FLT_MAX;
// find min and max
for (int j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (!channel_at_first) {
index = j * channels + i;
}
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
*desired_max = max;
*desired_min = min;
}
} // namespace mindspore::lite::quant } // namespace mindspore::lite::quant

View File

@ -107,6 +107,9 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size); STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size);
void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
bool channel_at_first, float *desired_max, float *desired_min);
template <typename T> template <typename T>
T QuantizeData(const float originData, const schema::QuantParamT *quantParam) { T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam != nullptr);
@ -163,11 +166,19 @@ template <typename T>
STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type, STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant_type,
std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min, std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas, const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas,
std::vector<float> *dequant_datas) { std::vector<float> *dequant_datas, bool channel_at_first = true) {
auto dims = weight->tensor_shape(); auto dims = weight->tensor_shape();
size_t elem_count = weight->tensor_shape_size(); size_t elem_count = weight->tensor_shape_size();
auto *raw_datas = static_cast<float *>(weight->tensor_addr()); auto *raw_datas = static_cast<float *>(weight->tensor_addr());
auto channels = dims[0]; auto channels = dims[0];
if (!channel_at_first) {
if (dims.size() != 2) {
MS_LOG(ERROR) << "unexpected dims size: " << dims.size();
channel_at_first = true;
} else {
channels = dims[1];
}
}
if (channels == 0) { if (channels == 0) {
MS_LOG(ERROR) << "channels is zero"; MS_LOG(ERROR) << "channels is zero";
return RET_ERROR; return RET_ERROR;
@ -181,16 +192,7 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
for (int i = 0; i < channels; i++) { for (int i = 0; i < channels; i++) {
float min = FLT_MAX; float min = FLT_MAX;
float max = -FLT_MAX; float max = -FLT_MAX;
// find min and max GetMaxMinPerchannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min);
for (size_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size;
if (index >= elem_count) {
MS_LOG(ERROR) << "over flow!";
return RET_ERROR;
}
min = std::min(min, raw_datas[index]);
max = std::max(max, raw_datas[index]);
}
schema::QuantParamT quant_param; schema::QuantParamT quant_param;
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num); STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
if (status != RET_OK) { if (status != RET_OK) {
@ -202,10 +204,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
double average_raw = 0; double average_raw = 0;
for (uint32_t j = 0; j < one_filter_size; j++) { for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size; auto index = j + i * one_filter_size;
if (index >= elem_count) { if (!channel_at_first) {
MS_LOG(ERROR) << "over flow!"; index = j * channels + i;
return RET_ERROR;
} }
MS_ASSERT(index < elem_count);
float raw_data = raw_datas[index]; float raw_data = raw_datas[index];
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min); auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
(*quant_datas)[index] = quant_data; (*quant_datas)[index] = quant_data;
@ -226,10 +228,10 @@ STATUS DoPerChannelQuant(const ParamValueLitePtr &weight, const QuantType &quant
double variance_raw = 0; double variance_raw = 0;
for (uint32_t j = 0; j < one_filter_size; j++) { for (uint32_t j = 0; j < one_filter_size; j++) {
auto index = j + i * one_filter_size; auto index = j + i * one_filter_size;
if (index >= elem_count) { if (!channel_at_first) {
MS_LOG(ERROR) << "over flow!"; index = j * channels + i;
return RET_ERROR;
} }
MS_ASSERT(index < elem_count);
variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2); variance_dequant += std::pow((*dequant_datas)[index] - average_dequant, 2);
variance_raw += std::pow(raw_datas[index] - average_raw, 2); variance_raw += std::pow(raw_datas[index] - average_raw, 2);
} }
@ -339,20 +341,26 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
std::vector<schema::QuantParamT> quant_params; std::vector<schema::QuantParamT> quant_params;
size_t elem_count = weight->tensor_shape_size(); size_t elem_count = weight->tensor_shape_size();
auto *raw_datas = static_cast<float *>(weight->tensor_addr()); auto *raw_data = static_cast<float *>(weight->tensor_addr());
if (raw_datas == nullptr) { if (raw_data == nullptr) {
MS_LOG(ERROR) << "rawDatas is nullptr"; MS_LOG(ERROR) << "rawDatas is nullptr";
return RET_ERROR; return RET_ERROR;
} }
std::vector<T> quant_datas(elem_count); std::vector<T> quant_data(elem_count);
std::vector<float> dequant_datas(elem_count); std::vector<float> dequant_datas(elem_count);
int ret = RET_OK; int ret = RET_OK;
if (per_channel) { if (per_channel) {
// notice: assume Con2D\DepthwiseConv2D's weight format are same: KHWC bool channel_at_first = true;
auto op_type = (schema::PrimitiveType)primitive_c->Type();
if (op_type == schema::PrimitiveType_MatMul && weight->tensor_shape().size() == 2) {
auto matmul_op = primitive_c->primitiveT()->value.AsMatMul();
MS_ASSERT(matmul_op != nullptr);
channel_at_first = !(index == 1 && !matmul_op->transposeB);
}
// channel at first // channel at first
ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas, ret = DoPerChannelQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data,
&dequant_datas); &dequant_datas, channel_at_first);
if (ret == RET_CONTINUE) { if (ret == RET_CONTINUE) {
return ret; return ret;
} else if (ret != RET_OK) { } else if (ret != RET_OK) {
@ -360,7 +368,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
return ret; return ret;
} }
} else { } else {
ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_datas); ret = DoPerLayerQuant<T>(weight, quant_type, &quant_params, quant_max, quant_min, bit_num, k_means, &quant_data);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Do per layer quant failed."; MS_LOG(ERROR) << "Do per layer quant failed.";
return ret; return ret;
@ -376,7 +384,7 @@ STATUS QuantFilter(const ParamValueLitePtr &weight, const std::shared_ptr<Primit
} }
#else #else
// do bit pack // do bit pack
ret = DoBitPack(weight, bit_num, quant_datas); ret = DoBitPack(weight, bit_num, quant_data);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Do bit pack failed."; MS_LOG(ERROR) << "Do bit pack failed.";
return ret; return ret;

View File

@ -127,6 +127,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
auto already_quant = false; auto already_quant = false;
ParamValueLitePtr param_value = nullptr; ParamValueLitePtr param_value = nullptr;
ParameterPtr param_node = nullptr; ParameterPtr param_node = nullptr;
int index = 0;
for (size_t i = 1; i < cnode->size(); i++) { for (size_t i = 1; i < cnode->size(); i++) {
auto inputNode = cnode->input(i); auto inputNode = cnode->input(i);
if (inputNode->isa<Parameter>()) { if (inputNode->isa<Parameter>()) {
@ -146,6 +147,7 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
param_value = nullptr; param_value = nullptr;
continue; continue;
} else { } else {
index = i;
break; break;
} }
} }
@ -169,11 +171,11 @@ STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
auto status = RET_ERROR; auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) { if (type_id_ == kNumberTypeInt8) {
status = status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); true, index - 1);
} else if (type_id_ == kNumberTypeInt16) { } else if (type_id_ == kNumberTypeInt16) {
status = status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true); true, index - 1);
} }
if (status == RET_CONTINUE) { if (status == RET_CONTINUE) {
return RET_OK; return RET_OK;