forked from OSSInnovation/mindspore
!6730 [MSLITE]fp16 weight dequant
Merge pull request !6730 from wangchangkai/master
This commit is contained in:
commit
b993ea0288
|
@ -142,6 +142,18 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
||||||
const mindspore::lite::PrimitiveC *primitive) {
|
const mindspore::lite::PrimitiveC *primitive) {
|
||||||
MS_ASSERT(opParameter != nullptr);
|
MS_ASSERT(opParameter != nullptr);
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
|
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
|
||||||
|
|
||||||
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
|
auto *restore_data = weight_tensor->MutableData();
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
||||||
|
if (dequant_weight == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
weight_tensor->SetData(dequant_weight);
|
||||||
|
}
|
||||||
|
|
||||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||||
kernel::LiteKernel *kernel;
|
kernel::LiteKernel *kernel;
|
||||||
if (conv_param->input_channel_ < 32) {
|
if (conv_param->input_channel_ < 32) {
|
||||||
|
@ -152,6 +164,10 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
||||||
}
|
}
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
|
@ -159,8 +175,16 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
||||||
delete kernel;
|
delete kernel;
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -218,6 +218,18 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
||||||
const mindspore::lite::PrimitiveC *primitive) {
|
const mindspore::lite::PrimitiveC *primitive) {
|
||||||
MS_ASSERT(opParameter != nullptr);
|
MS_ASSERT(opParameter != nullptr);
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
|
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
|
||||||
|
|
||||||
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
|
auto *restore_data = weight_tensor->MutableData();
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
||||||
|
if (dequant_weight == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
weight_tensor->SetData(dequant_weight);
|
||||||
|
}
|
||||||
|
|
||||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||||
int kernel_h = conv_param->kernel_h_;
|
int kernel_h = conv_param->kernel_h_;
|
||||||
int kernel_w = conv_param->kernel_w_;
|
int kernel_w = conv_param->kernel_w_;
|
||||||
|
@ -249,6 +261,10 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
||||||
}
|
}
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
|
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
|
@ -256,8 +272,16 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
||||||
delete kernel;
|
delete kernel;
|
||||||
MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_
|
MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_
|
||||||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator)
|
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator)
|
||||||
|
|
|
@ -201,9 +201,25 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
|
||||||
const mindspore::lite::PrimitiveC *primitive) {
|
const mindspore::lite::PrimitiveC *primitive) {
|
||||||
MS_ASSERT(opParameter != nullptr);
|
MS_ASSERT(opParameter != nullptr);
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
|
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
|
||||||
|
|
||||||
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
|
auto *restore_data = weight_tensor->MutableData();
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
||||||
|
if (dequant_weight == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
weight_tensor->SetData(dequant_weight);
|
||||||
|
}
|
||||||
|
|
||||||
auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
|
@ -211,8 +227,16 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
|
||||||
delete kernel;
|
delete kernel;
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -208,9 +208,25 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
||||||
const mindspore::lite::PrimitiveC *primitive) {
|
const mindspore::lite::PrimitiveC *primitive) {
|
||||||
MS_ASSERT(opParameter != nullptr);
|
MS_ASSERT(opParameter != nullptr);
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
|
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
|
||||||
|
|
||||||
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
|
auto *restore_data = weight_tensor->MutableData();
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
||||||
|
if (dequant_weight == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
weight_tensor->SetData(dequant_weight);
|
||||||
|
}
|
||||||
|
|
||||||
auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
auto kernel = new (std::nothrow) DeConvolutionFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
|
@ -218,8 +234,16 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
||||||
delete kernel;
|
delete kernel;
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (weight_tensor->data_type() == kNumberTypeInt8 || primitive->GetQuantType() == schema::QuantType_WeightQuant) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -170,9 +170,24 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
|
||||||
OpParameter *opParameter, const lite::InnerContext *ctx,
|
OpParameter *opParameter, const lite::InnerContext *ctx,
|
||||||
const kernel::KernelKey &desc,
|
const kernel::KernelKey &desc,
|
||||||
const mindspore::lite::PrimitiveC *primitive) {
|
const mindspore::lite::PrimitiveC *primitive) {
|
||||||
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
|
// data of second tensor of fc may be nullptr
|
||||||
|
auto *restore_data = weight_tensor->data_c();
|
||||||
|
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||||
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
||||||
|
if (dequant_weight == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
weight_tensor->SetData(dequant_weight);
|
||||||
|
}
|
||||||
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||||
|
if (!weight_tensor->GetQuantParams().empty()) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
|
@ -180,8 +195,16 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
delete kernel;
|
delete kernel;
|
||||||
|
if (!weight_tensor->GetQuantParams().empty()) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (!weight_tensor->GetQuantParams().empty()) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -253,9 +253,23 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
|
||||||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||||
const mindspore::lite::PrimitiveC *primitive) {
|
const mindspore::lite::PrimitiveC *primitive) {
|
||||||
|
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||||
|
auto *restore_data = weight_tensor->data_c();
|
||||||
|
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||||
|
auto *dequant_weight = kernel::LiteKernelUtil::DequantWeight(weight_tensor);
|
||||||
|
if (dequant_weight == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
weight_tensor->SetData(dequant_weight);
|
||||||
|
}
|
||||||
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||||
|
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
|
@ -263,8 +277,16 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||||
delete kernel;
|
delete kernel;
|
||||||
|
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (!weight_tensor->GetQuantParams().empty() && restore_data != nullptr) {
|
||||||
|
weight_tensor->FreeData();
|
||||||
|
weight_tensor->SetData(restore_data);
|
||||||
|
}
|
||||||
return kernel;
|
return kernel;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue