forked from mindspore-Ecosystem/mindspore

fix gather weight quant bug

commit 9428ffe860
parent f654167045
@@ -51,7 +51,7 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
   }
 }
 std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
-                                                                         TypeId data_type) {
+                                                                         TypeId data_type, bool need_restore) {
   std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
   if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
     for (auto weight_tensor : in_tensors) {
@@ -59,16 +59,21 @@ std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const s
       auto *restore_data = weight_tensor->data_c();
       auto restore_type = weight_tensor->data_type();
       bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
-                          restore_data != nullptr;
+                          restore_data != nullptr &&
+                          (restore_type == kNumberTypeInt8 || restore_type == kNumberTypeInt16);
       if (dequant_flag) {
         auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
         if (dequant_weight == nullptr) {
           MS_LOG(ERROR) << "dequant data is nullptr.";
           return tensor_origin_data;
         }
+        if (need_restore) {
+          tensor_origin_data[weight_tensor] = {restore_type, restore_data};
+        } else {
+          weight_tensor->FreeData();
+        }
         weight_tensor->set_data(dequant_weight);
         weight_tensor->set_data_type(kNumberTypeFloat32);
-        tensor_origin_data[weight_tensor] = {restore_type, restore_data};
       }
     }
   }

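The two hunks above are the core of the change: DequantTensor now lets the caller decide whether the original quantized buffer is kept for a later RestoreTensorData call (need_restore) or released as soon as the dequantized copy exists, and the new restore_type check guards tensors (such as the Gather weights this commit targets) that carry quant params but non-int8/int16 data. A minimal sketch of the resulting flow, using stub types rather than the real lite::Tensor (all names and members here are illustrative only):

```cpp
// Illustrative sketch only: stub types standing in for lite::Tensor et al.
#include <cstdlib>
#include <map>
#include <utility>
#include <vector>

enum TypeId { kNumberTypeInt8, kNumberTypeInt16, kNumberTypeFloat32 };

struct Tensor {  // stand-in for lite::Tensor
  TypeId type = kNumberTypeInt8;
  void *data = nullptr;
  bool quant_inited = true;
  void FreeData() { free(data); data = nullptr; }
};

std::map<Tensor *, std::pair<TypeId, void *>> DequantTensorSketch(const std::vector<Tensor *> &in_tensors,
                                                                  bool need_restore) {
  std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
  for (auto *t : in_tensors) {
    // Only initialized int8/int16 quantized weights qualify (the added check).
    bool dequant_flag = t->quant_inited && t->data != nullptr &&
                        (t->type == kNumberTypeInt8 || t->type == kNumberTypeInt16);
    if (!dequant_flag) continue;
    void *dequant_weight = malloc(1);  // placeholder for DequantUtil::DequantWeight(t)
    if (need_restore) {
      tensor_origin_data[t] = {t->type, t->data};  // remembered for RestoreTensorData
    } else {
      t->FreeData();  // the quantized original will never be needed again
    }
    t->data = dequant_weight;
    t->type = kNumberTypeFloat32;
  }
  return tensor_origin_data;
}
```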
@@ -34,7 +34,7 @@ class DequantUtil {
   static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
 
   static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
-                                                                     TypeId data_type);
+                                                                     TypeId data_type, bool need_restore = true);
 
   static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);

@@ -79,7 +79,7 @@ class DequantUtil {
     auto var_corr = param.var_corr;
     auto mean_corr = param.mean_corr;
     if (var_corr < 0 || var_corr > 10) {
-      MS_LOG(WARNING) << "unexpeted var_corr: " << var_corr;
+      MS_LOG(WARNING) << "unexpected var_corr: " << var_corr;
       var_corr = 1;
     }
     for (size_t j = 0; j < per_channel_size; j++) {

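For orientation: var_corr and mean_corr are per-channel correction factors recorded by the weight quantizer, and the expression that applies them sits outside this hunk's context lines. A common form, stated here purely as an assumption rather than taken from this diff, is:

```cpp
// Assumed form of correction-aware dequantization; the real expression is not
// shown in this diff's context lines.
float DequantWithCorrection(int quant_value, float scale, int zero_point,
                            float var_corr, float mean_corr) {
  return scale * static_cast<float>(quant_value - zero_point) * var_corr + mean_corr;
}
```

Under that reading, clamping an out-of-range var_corr back to 1 simply disables the variance correction instead of letting a corrupted parameter scale the whole channel.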
@@ -38,10 +38,6 @@
 
 namespace mindspore {
 namespace lite {
-static std::vector<schema::PrimitiveType> packed_op = {
-  schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D,
-  schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_MatMul};
-
 // this method will not check whether tensor_idx is a weight tensor index, caller should ensure this.
 static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor_idx) {
 #ifdef SUPPORT_TRAIN

@@ -92,8 +88,13 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
                                     lite::Tensor *dst_tensor) {
   MS_ASSERT(src_tensor != nullptr);
   MS_ASSERT(dst_tensor != nullptr);
-  auto src_category = TensorCategory(src_tensor);
+  auto NeedUnPack = [&src_tensor, &dst_tensor]() -> bool {
     auto data_type = src_tensor->dataType();
+    int pack_size = src_tensor->data()->size();
+    int org_size = dst_tensor->Size();
+    return (pack_size != org_size) && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16);
+  };
+  auto src_category = TensorCategory(src_tensor);
   if ((src_category == Tensor::Category::CONST_TENSOR || src_category == Tensor::Category::CONST_SCALAR) &&
       src_tensor->data() != nullptr && src_tensor->data()->size() > 0) {
     if (src_tensor->dataType() == kObjectTypeTensorType) {

@@ -112,18 +113,20 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
         MS_LOG(ERROR) << "Data from tensor is nullptr";
         return RET_NULL_PTR;
       }
-      memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
+      if (NeedUnPack()) {
+        DequantUtil::UnPackToInt(src_tensor, dst_data);
+      } else {
+        memcpy(dst_data, src_tensor->data()->data(), dst_tensor->Size());
+      }
       copyed_tensor_idxes_.emplace_back(tensor_index);
     } else {
-      int pack_size = src_tensor->data()->size();
-      int org_size = dst_tensor->Size();
-      if (pack_size != org_size && (data_type == kNumberTypeInt8 || data_type == kNumberTypeInt16)) {
-        auto ret = dst_tensor->MallocData();
-        if (ret != RET_OK) {
-          MS_LOG(ERROR) << "Malloc data for tensor failed ";
-          return RET_ERROR;
+      if (NeedUnPack()) {
+        auto dst_data = dst_tensor->MutableData();
+        if (dst_data == nullptr) {
+          MS_LOG(ERROR) << "Data from tensor is nullptr";
+          return RET_NULL_PTR;
         }
-        DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
+        DequantUtil::UnPackToInt(src_tensor, dst_data);
         copyed_tensor_idxes_.emplace_back(tensor_index);
       } else {
         dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));

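The NeedUnPack lambda introduced above centralizes a check that previously existed only on the non-copy path: a serialized weight blob needs unpacking, presumably because sub-8-bit quantized weights are bit-packed in the flatbuffer, exactly when its byte size differs from the runtime tensor's expected size. A standalone restatement of the predicate, with an illustrative size pair:

```cpp
#include <cstddef>

// Restates the NeedUnPack predicate from the diff: unpack only when the stored
// blob's size differs from the runtime size and the tensor is int8/int16.
bool NeedUnPackSketch(size_t pack_size, size_t org_size, bool is_int8_or_int16) {
  return pack_size != org_size && is_int8_or_int16;
}
// Example (illustrative): 100 weights quantized to 7 bits pack into 88 bytes on
// disk but occupy 100 bytes as int8 at runtime, so NeedUnPackSketch(88, 100, true)
// returns true and UnPackToInt widens them into the destination buffer.
```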
@@ -713,12 +716,12 @@ int LiteSession::InitGPURuntime() {
 session::LiteSession *session::LiteSession::CreateSession(const lite::Context *context) {
   auto session = new (std::nothrow) lite::LiteSession();
   if (session == nullptr) {
-    MS_LOG(ERROR) << "create sesssion failed";
+    MS_LOG(ERROR) << "create session failed";
     return nullptr;
   }
   auto ret = session->Init(context);
   if (ret != mindspore::lite::RET_OK) {
-    MS_LOG(ERROR) << "init sesssion failed";
+    MS_LOG(ERROR) << "init session failed";
     delete session;
     return nullptr;
   }

@@ -729,7 +732,7 @@ session::LiteSession *session::LiteSession::CreateSession(const char *model_buf,
                                                           const lite::Context *context) {
   auto *session = LiteSession::CreateSession(context);
   if (session == nullptr) {
-    MS_LOG(ERROR) << "Create sesssion failed";
+    MS_LOG(ERROR) << "Create session failed";
     return nullptr;
   }
   auto *model = lite::ImportFromBuffer(model_buf, size, true);

@@ -107,8 +107,10 @@ int LstmCPUKernel::InitWeightBias() {
   }
   memcpy(weight_h_ptr_, weight_h->MutableData(), weight_h->ElementsNum() * sizeof(float));
 
+  std::vector<int> w_shape = weight_i->shape();
+  auto hidden_size = w_shape.at(1) / 4;
   // init bias
-  int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * lstm_parm_->hidden_size_ : 4 * lstm_parm_->hidden_size_;
+  int bias_num = lstm_parm_->bidirectional_ ? 2 * 4 * hidden_size : 4 * hidden_size;
   bias_ptr_ = reinterpret_cast<float *>(malloc(bias_num * sizeof(float)));
   if (bias_ptr_ == nullptr) {
     MS_LOG(ERROR) << "LstmCPUKernel malloc bias_ptr_ error.";

@@ -116,13 +118,13 @@ int LstmCPUKernel::InitWeightBias() {
   }
 
   auto bias_data = reinterpret_cast<float *>(in_tensors_.at(3)->MutableData());
-  const int state_bias_offset = 4 * lstm_parm_->hidden_size_;
+  const int state_bias_offset = 4 * hidden_size;
   for (int i = 0; i < state_bias_offset; i++) {
     bias_ptr_[i] = bias_data[i] + bias_data[i + state_bias_offset];
   }
   if (lstm_parm_->bidirectional_) {
-    bias_data += 4 * lstm_parm_->hidden_size_ * 2;
-    auto backward_bias = bias_ptr_ + 4 * lstm_parm_->hidden_size_;
+    bias_data += 4 * hidden_size * 2;
+    auto backward_bias = bias_ptr_ + 4 * hidden_size;
     for (int i = 0; i < state_bias_offset; i++) {
       backward_bias[i] = bias_data[i] + bias_data[i + state_bias_offset];
     }

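These two LSTM hunks replace every use of lstm_parm_->hidden_size_ in InitWeightBias with a value derived from the weight tensor itself: the four LSTM gates are stacked along dimension 1 of weight_i, so that dimension equals 4 * hidden_size. Reading it off the shape keeps the bias setup correct even when InitWeightBias now runs from Init(), before InitParam() has populated lstm_parm_ (see the next two hunks). A small restatement of the derivation:

```cpp
#include <stdexcept>
#include <vector>

// The four LSTM gates (input, forget, cell, output) are stacked along dim 1 of
// the input weight, so hidden_size falls out of the weight shape directly.
int HiddenSizeFromWeightShape(const std::vector<int> &w_shape) {
  if (w_shape.size() < 2 || w_shape[1] % 4 != 0) {
    throw std::invalid_argument("unexpected LSTM weight shape");
  }
  return w_shape[1] / 4;
}
// Bias layout implied by the hunk: 4 * hidden_size forward entries, followed by
// another 4 * hidden_size backward entries when bidirectional_ is set.
```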
@@ -131,6 +133,14 @@ int LstmCPUKernel::InitWeightBias() {
 }
 
 int LstmCPUKernel::Init() {
+  FreeTmpBuffer();
+  auto ret = InitWeightBias();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
+    FreeTmpBuffer();
+    return RET_ERROR;
+  }
+
   if (!InferShapeDone()) {
     return RET_OK;
   }

@@ -138,20 +148,12 @@ int LstmCPUKernel::Init() {
 }
 
 int LstmCPUKernel::ReSize() {
-  FreeTmpBuffer();
   auto ret = InitParam();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LstmCPUKernel InitParam error.";
     return RET_ERROR;
   }
 
-  ret = InitWeightBias();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "LstmCPUKernel InitWeightBias error.";
-    FreeTmpBuffer();
-    return RET_ERROR;
-  }
-
   ret = InitBuffer();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "LstmCPUKernel InitBuffer error.";

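Net effect of these two hunks: weight and bias initialization moves out of ReSize() and into Init(), so weights are prepared once per kernel rather than on every shape change, and the resize path no longer frees and rebuilds them. A compressed sketch of the assumed lifecycle (the Init-calls-ReSize tail is the usual MindSpore Lite kernel pattern, not visible in this diff):

```cpp
// Hedged sketch of the new Init/ReSize contract; helpers are stubbed so the
// sketch is self-contained and compilable.
struct LstmLifecycleSketch {
  static constexpr int RET_OK = 0, RET_ERROR = 1;

  int Init() {  // runs once after kernel creation
    FreeTmpBuffer();
    if (InitWeightBias() != RET_OK) {
      FreeTmpBuffer();
      return RET_ERROR;
    }
    if (!InferShapeDone()) return RET_OK;
    return ReSize();
  }
  int ReSize() {  // runs on every shape change; weights are untouched now
    if (InitParam() != RET_OK) return RET_ERROR;
    return InitBuffer();
  }

  // Stubs standing in for the real kernel helpers:
  void FreeTmpBuffer() {}
  int InitWeightBias() { return RET_OK; }
  int InitParam() { return RET_OK; }
  int InitBuffer() { return RET_OK; }
  bool InferShapeDone() { return true; }
};
```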
@@ -184,6 +184,13 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
                                                  const Model::Node *node) {
   MS_ASSERT(primitive != nullptr);
   TypeId data_type = GetFirstFp32Fp16OrInt8Type(in_tensors);
+  bool need_restore = true;
+  if (primitive->quant_type() == schema::QuantType_WeightQuant) {
+    data_type = kNumberTypeFloat32;
+  }
+  if (!IsContain(packed_op, (schema::PrimitiveType)primitive->Type())) {
+    need_restore = false;
+  }
   kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};
 #if SUPPORT_GPU
   if (context_->IsGpuEnabled()) {

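Here need_restore is derived from the packed_op list (which this commit moves into scheduler.h, shown in a later hunk) and threaded into both DequantTensor call sites below. A hedged sketch of the decision, with IsContain assumed to be a simple linear-search helper (the real one lives in MindSpore's utilities):

```cpp
#include <algorithm>
#include <vector>

// Assumed semantics of IsContain; the real helper is defined elsewhere in MindSpore.
template <typename T>
bool IsContain(const std::vector<T> &vec, const T &element) {
  return std::find(vec.begin(), vec.end(), element) != vec.end();
}
// One plausible reading of the flag: packed ops (Conv2D, MatMul, Lstm, ...)
// repack weights inside kernel Init, so their tensors can safely be restored to
// the quantized originals afterwards. Other ops keep reading the tensor at Run
// time, so the float copy must stay and the quantized original can be freed
// immediately (need_restore = false), which is what DequantTensor now does.
```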
@@ -216,7 +223,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
   if (mindspore::lite::IsSupportFloat16() &&
       ((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
     kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
-    auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type);
+    auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type, need_restore);
     auto *kernel =
       KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
     DequantUtil::RestoreTensorData(tensor_origin_data_map);

@@ -230,7 +237,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
     MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
     desc.data_type = kNumberTypeFloat32;
   }
-  auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type);
+  auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type, need_restore);
   auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
   DequantUtil::RestoreTensorData(tensor_origin_data_map);
   if (kernel != nullptr) {

@@ -26,6 +26,12 @@
 #include "src/ops/primitive_c.h"
 
 namespace mindspore::lite {
+
+static std::vector<schema::PrimitiveType> packed_op = {
+  schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
+  schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
+  schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm};
+
 class Scheduler {
  public:
   Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)

@@ -253,11 +253,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
   }
   auto status = RET_ERROR;
   if (type_id_ == kNumberTypeInt8) {
-    status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                 false, 1);
   } else if (type_id_ == kNumberTypeInt16) {
-    status =
-      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                  false, 1);
   }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;

@@ -316,11 +316,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
   }
   auto status = RET_ERROR;
   if (type_id_ == kNumberTypeInt8) {
-    status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+    status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
+                                 false, 3);
   } else if (type_id_ == kNumberTypeInt16) {
     status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
-                                  false);
+                                  false, 3);
   }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;

@@ -340,10 +340,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
   auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
   MS_ASSERT(primitive_c != nullptr);
 
-  auto weight_h = cnode->input(1);
+  auto first_input = cnode->input(1);
   ParameterPtr param_node;
   ParamValueLitePtr param_value;
-  GetLiteParameter(weight_h, &param_node, &param_value);
+  GetLiteParameter(first_input, &param_node, &param_value);
   if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
     MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
     return RET_OK;

@@ -358,10 +358,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
   auto status = RET_ERROR;
   if (type_id_ == kNumberTypeInt8) {
     status =
-      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
   } else if (type_id_ == kNumberTypeInt16) {
     status =
-      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
+      QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
   }
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed : " << status;

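All six QuantFilter call sites in this commit gain a trailing integer argument (1 and 3 at the two DoLstmQuntize sites above, 0 here for Gather). QuantFilter's signature is outside the diff, but the values track which input of the node holds the weight being quantized, so an input/dimension index is a plausible reading; treat that as an assumption. The range fed alongside it is the usual symmetric scheme, and the quant_min_ expression is visible in the DoQuantize hunk below:

```cpp
#include <cstdio>

// Symmetric signed range for a given bit width. quant_min matches the
// expression shown in the DoQuantize hunk; quant_max is the usual complement.
int main() {
  unsigned int bit_num = 8;
  int quant_min = -(1 << (bit_num - 1));     // -128 for 8-bit
  int quant_max = (1 << (bit_num - 1)) - 1;  // +127 for 8-bit
  std::printf("int%u range: [%d, %d]\n", bit_num, quant_min, quant_max);
  return 0;
}
```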
@@ -510,7 +510,7 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
   return RET_OK;
 }
 
-STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
+STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
   // 0.2 Parse input calib files
   auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
   if (status != RET_OK) {

|
||||||
delete quant_sm.model;
|
delete quant_sm.model;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
// 3. compare betwen quant and fp32
|
// 3. compare between quant and fp32
|
||||||
auto quant_outputs = quant_session->GetOutputs();
|
auto quant_outputs = quant_session->GetOutputs();
|
||||||
mean_error += CompareOutputData<float>(fp32_output_tensors_[i], quant_outputs);
|
mean_error += CompareOutputData<float>(fp32_output_tensors_[i], quant_outputs);
|
||||||
} // end_for: calib data loop
|
} // end_for: calib data loop
|
||||||
|
@ -690,8 +690,8 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
|
||||||
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
||||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||||
if (primitive_c == nullptr) {
|
if (primitive_c == nullptr) {
|
||||||
MS_LOG(ERROR) << "primitive_c is nullptr";
|
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive_c is nullptr";
|
||||||
return RET_ERROR;
|
continue;
|
||||||
}
|
}
|
||||||
auto op_name = cnode->fullname_with_scope();
|
auto op_name = cnode->fullname_with_scope();
|
||||||
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
||||||
|
@@ -744,7 +744,7 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
     quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
     type_id_ = kNumberTypeInt8;
     MS_LOG(INFO) << "Do mixed bit quantization";
-    return DoMiexedQuant(func_graph);
+    return DoMixedQuant(func_graph);
   }
 
   return DoFixedQuant(func_graph);

@@ -62,7 +62,7 @@ class WeightQuantizer : public Quantizer {
   std::vector<std::vector<std::string>> images_;  // multi_input, [[mode_input_0], [model_input_1]...]
   std::vector<std::unordered_map<std::string, mindspore::tensor::MSTensor *>> fp32_output_tensors_;
 
-  STATUS DoMiexedQuant(FuncGraphPtr);
+  STATUS DoMixedQuant(FuncGraphPtr);
   STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
   STATUS DoFixedQuant(FuncGraphPtr);
   STATUS RunFp32Graph(FuncGraphPtr);