weight quant reconstruction && lstm/gather quant

This commit is contained in:
xutianchun 2021-01-13 10:10:07 +08:00
parent 2924552783
commit 5d613749ec
29 changed files with 414 additions and 522 deletions

View File

@ -39,6 +39,7 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc
)
if (SUPPORT_GPU)

View File

@ -14,9 +14,9 @@
* limitations under the License.
*/
#include <cmath>
#include "src/runtime/kernel/arm/base/dequant.h"
#include "src/dequant.h"
namespace mindspore::kernel {
namespace mindspore::lite {
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
MS_ASSERT(input_tensor != nullptr);
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
@ -35,6 +35,8 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
}
void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
MS_ASSERT(input_tensor != nullptr);
MS_ASSERT(unpack_int_data != nullptr);
auto quant_params = input_tensor->quantParams();
if (quant_params == nullptr) {
MS_LOG(ERROR) << "low bits quantparams is empty.";
@ -47,4 +49,41 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
}
}
} // namespace mindspore::kernel
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
TypeId data_type) {
std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
for (auto weight_tensor : in_tensors) {
MS_ASSERT(weight_tensor != nullptr);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
return tensor_origin_data;
}
weight_tensor->set_data(dequant_weight);
weight_tensor->set_data_type(kNumberTypeFloat32);
tensor_origin_data[weight_tensor] = {restore_type, restore_data};
}
}
}
return tensor_origin_data;
}
void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map) {
for (auto &kv : tensor_origin_data_map) {
auto *tensor = kv.first;
auto type_id = kv.second.first;
auto data = kv.second.second;
tensor->FreeData();
tensor->set_data_type(type_id);
tensor->set_data(data);
}
}
} // namespace mindspore::lite

View File

@ -17,6 +17,8 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
#include <map>
#include <utility>
#include <vector>
#include <queue>
#include <cmath>
@ -24,13 +26,18 @@
#include "src/common/utils.h"
#include "src/tensor.h"
namespace mindspore::kernel {
namespace mindspore::lite {
class DequantUtil {
public:
static float *DequantWeight(lite::Tensor *input_tensor);
static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
TypeId data_type);
static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
template <typename ST, typename DT = float>
static DT *DequantData(lite::Tensor *input_tensor) {
const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
@ -108,7 +115,7 @@ class DequantUtil {
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,
size_t *count, bool is_last) {
T2 uint_result = 0;
T1 result = 0;
T1 result;
UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data);
while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) {
for (int k = 0; k < origin_bit; k++) {
@ -163,6 +170,6 @@ class DequantUtil {
}
}
};
} // namespace mindspore::kernel
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_

View File

@ -27,7 +27,7 @@
#include "src/common/graph_util.h"
#include "src/kernel_registry.h"
#include "src/lite_model.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "src/dequant.h"
#if SUPPORT_NPU
#include "src/runtime/agent/npu/npu_manager.h"
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
@ -120,7 +120,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
MS_LOG(ERROR) << "Malloc data for tensor failed ";
return RET_ERROR;
}
kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
copyed_tensor_idxes_.emplace_back(tensor_index);
} else {
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));

View File

@ -25,7 +25,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -359,22 +358,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel = nullptr;
if (conv_param->group_ == 1) {
@ -385,11 +368,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
if (kernel == nullptr) {
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -398,20 +376,9 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
if (ret != RET_OK) {
MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator)

View File

@ -22,7 +22,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -138,22 +137,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel;
if (conv_param->input_channel_ < 32) {
@ -164,11 +147,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -176,19 +154,9 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -20,7 +20,6 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "nnacl/fp16/conv_fp16.h"
#include "nnacl/fp16/matmul_fp16.h"
#include "nnacl/fp16/cast_fp16.h"

View File

@ -20,7 +20,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -212,30 +211,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
auto dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -243,19 +221,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -17,7 +17,6 @@
#include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h"
#include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -220,22 +219,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
auto dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
kernel::LiteKernel *kernel;
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
@ -247,11 +230,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -259,19 +237,9 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator)

View File

@ -234,30 +234,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
OpParameter *opParameter, const lite::InnerContext *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -265,19 +244,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -24,7 +24,6 @@
#include "nnacl/fp16/matmul_fp16.h"
#include "nnacl/fp16/cast_fp16.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/dequant.h"
namespace mindspore::kernel {
class FullconnectionFP16CPUKernel : public LiteKernel {

View File

@ -20,7 +20,6 @@
#include "src/runtime/runtime_api.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
@ -330,29 +329,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data_type(kNumberTypeFloat32);
weight_tensor->set_data(dequant_weight);
}
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -361,18 +340,8 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -22,7 +22,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -356,22 +355,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
MS_ASSERT(desc.data_type == kNumberTypeFloat32);
// if get quantized weight, dequantize it to float32 type data.
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(op_parameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
kernel::LiteKernel *kernel = nullptr;
if (conv_param->group_ == 1) {
@ -382,11 +365,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(op_parameter);
return nullptr;
}
@ -395,20 +373,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
if (ret != RET_OK && ret != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -21,7 +21,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -126,19 +125,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
kernel::LiteKernel *kernel = nullptr;
if (primitive != nullptr && primitive->infer_flag()) {
@ -162,11 +148,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -174,21 +155,10 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK && ret != RET_INFER_INVALID) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -19,7 +19,6 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "nnacl/fp32/conv_fp32.h"
#include "nnacl/fp32/matmul_fp32.h"

View File

@ -19,7 +19,6 @@
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -202,29 +201,10 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
auto kernel =
new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -232,19 +212,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -17,7 +17,6 @@
#include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h"
#include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h"
#include "src/runtime/runtime_api.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
@ -240,20 +239,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
kernel::LiteKernel *kernel;
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
@ -266,11 +251,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -278,21 +258,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -228,28 +228,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection);
auto *weight_tensor = inputs.at(kWeightIndex);
// data of second tensor of fc may be nullptr
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (!kernel) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -257,19 +238,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -22,7 +22,6 @@
#include "include/errorcode.h"
#include "nnacl/fp32/matmul_fp32.h"
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::lite::InnerContext;
namespace mindspore::kernel {

View File

@ -19,7 +19,6 @@
#include "nnacl/fp32/matmul_fp32.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/base/dequant.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INPUT_TENSOR_ERROR;
@ -417,30 +416,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
auto *weight_tensor = inputs.at(kWeightIndex);
auto *restore_data = weight_tensor->data_c();
auto restore_type = weight_tensor->data_type();
bool dequant_flag =
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
if (dequant_flag) {
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
if (dequant_weight == nullptr) {
MS_LOG(ERROR) << "dequant data is nullptr.";
free(opParameter);
return nullptr;
}
weight_tensor->set_data(dequant_weight);
}
auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
free(opParameter);
return nullptr;
}
@ -448,21 +426,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
delete kernel;
return nullptr;
}
if (dequant_flag) {
weight_tensor->FreeData();
weight_tensor->set_data(restore_data);
weight_tensor->set_data_type(restore_type);
}
return kernel;
}

View File

@ -15,7 +15,7 @@
*/
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "mindspore/lite/src/dequant.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
@ -263,10 +263,10 @@ int OpenCLKernel::DequantWeight() {
if (is_fp16) {
#ifdef ENABLE_ARM64
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat16);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat16);
} else {
set_flag = false;
@ -276,10 +276,10 @@ int OpenCLKernel::DequantWeight() {
#endif
} else {
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat32);
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor);
dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor);
weight_tensor->set_data_type(kNumberTypeFloat32);
} else {
set_flag = false;

View File

@ -25,7 +25,7 @@
#include "src/lite_kernel.h"
#include "include/errorcode.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/arm/base/dequant.h"
#include "mindspore/lite/src/dequant.h"
#include "src/runtime/kernel/opencl/utils.h"
using mindspore::lite::RET_ERROR;

View File

@ -27,6 +27,7 @@
#include "src/common/utils.h"
#include "src/kernel_registry.h"
#include "src/sub_graph_kernel.h"
#include "src/dequant.h"
#if SUPPORT_GPU
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
#include "src/runtime/opencl/opencl_runtime.h"
@ -213,8 +214,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
if (mindspore::lite::IsSupportFloat16() &&
((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type);
auto *kernel =
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map);
if (kernel != nullptr) {
MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " "
<< node->name_;
@ -225,7 +228,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
desc.data_type = kNumberTypeFloat32;
}
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type);
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
DequantUtil::RestoreTensorData(tensor_origin_data_map);
if (kernel != nullptr) {
return kernel;
}

View File

@ -126,6 +126,7 @@ set(TEST_LITE_SRC
${LITE_DIR}/src/kernel_registry.cc
${LITE_DIR}/src/lite_kernel.cc
${LITE_DIR}/src/lite_session.cc
${LITE_DIR}/src/dequant.cc
${LITE_DIR}/src/sub_graph_kernel.cc
${LITE_DIR}/src/lite_model.cc
${LITE_DIR}/src/scheduler.cc

View File

@ -95,6 +95,7 @@ set(LITE_SRC
${SRC_DIR}/executor.cc
${SRC_DIR}/lite_model.cc
${SRC_DIR}/errorcode.cc
${SRC_DIR}/dequant.cc
)
if (SUPPORT_TRAIN)
set(LITE_SRC

View File

@ -782,4 +782,27 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) {
return new_func_graph;
}
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value) {
MS_ASSERT(node != nullptr);
MS_ASSERT(param_node != nullptr);
MS_ASSERT(param_value != nullptr);
auto op_name = node->fullname_with_scope();
*param_node = node->cast<ParameterPtr>();
if (*param_node == nullptr) {
MS_LOG(INFO) << op_name << " can not cast to ParameterPtr";
return;
}
if (!(*param_node)->has_default()) {
MS_LOG(INFO) << op_name << " not has_default";
return;
}
*param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
if (*param_value == nullptr) {
MS_LOG(INFO) << "default_param can not cast to ParamValueLite";
return;
}
}
} // namespace mindspore::lite::quant

View File

@ -75,9 +75,10 @@ class QuantStrategy {
bool CanMulOpQuantized(const CNodePtr &node) const;
bool CanOpPostQuantized(AnfNodePtr &node) const;
private:
size_t mWeightSize;
size_t mConvWeightQuantChannelThreshold;
private:
static const std::vector<schema::PrimitiveType> conv_types;
static const std::vector<schema::PrimitiveType> mul_types;
};
@ -356,5 +357,8 @@ STATUS CopyInputDataToTensor(size_t input_index, size_t image_index,
const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor);
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value);
} // namespace mindspore::lite::quant
#endif

View File

@ -20,7 +20,6 @@
#include <vector>
#include <unordered_map>
#include "src/common/common.h"
#include "ir/dtype/type_id.h"
using std::string;
using std::vector;
@ -73,13 +72,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_f
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min = -(1 << (unsigned int)(this->bit_num_ - 1));
// parse type_id
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
// parse type_id_
if (this->bit_num_ > 0 && this->bit_num_ <= 8) {
type_id = kNumberTypeInt8;
type_id_ = kNumberTypeInt8;
} else if (this->bit_num_ <= 16) {
type_id = kNumberTypeInt16;
type_id_ = kNumberTypeInt16;
} else {
MS_LOG(ERROR) << "invalid input bits";
}
@ -90,7 +89,7 @@ WeightQuantizer::~WeightQuantizer() { delete fp32_session_; }
STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
std::shared_ptr<PrimitiveC> primitive_c) {
// set dtype
param_value->set_tensor_type(type_id);
param_value->set_tensor_type(type_id_);
auto abstract_base = param_node->abstract();
if (abstract_base == nullptr) {
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
@ -101,49 +100,158 @@ STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr
return RET_ERROR;
}
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
abstract_tensor->element()->set_type(TypeIdToType(type_id));
abstract_tensor->element()->set_type(TypeIdToType(type_id_));
primitive_c->set_quant_type(schema::QuantType_WeightQuant);
return RET_OK;
}
STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
for (auto &cnode : nodes) {
if (!quant_strategy_->CanConvOpQuantized(cnode)) {
continue;
}
STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) {
return RET_ERROR;
}
auto input_node = cnode->input(2);
if (!input_node->isa<Parameter>()) {
return RET_ERROR;
}
ParameterPtr param_node;
ParamValueLitePtr param_value;
auto param_node = input_node->cast<ParameterPtr>();
if (!param_node->has_default()) {
return RET_ERROR;
}
GetLiteParameter(input_node, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if (param_value == nullptr) {
if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
} else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return RET_OK;
}
STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
auto already_quant = false;
ParamValueLitePtr param_value = nullptr;
ParameterPtr param_node = nullptr;
for (size_t i = 1; i < cnode->size(); i++) {
auto inputNode = cnode->input(i);
if (inputNode->isa<Parameter>()) {
param_node = inputNode->cast<ParameterPtr>();
if ((param_node != nullptr) && param_node->has_default()) {
param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) {
param_value = nullptr;
continue;
} else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
param_value->tensor_type() == mindspore::kNumberTypeInt16) {
MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been "
<< " quantized";
already_quant = true;
break;
} else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
param_value = nullptr;
continue;
} else {
break;
}
}
}
}
if (already_quant) {
return RET_OK;
}
if (param_value == nullptr) {
MS_LOG(ERROR) << "No valid input param node !";
return RET_ERROR;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
} else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return RET_OK;
}
STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
MS_ASSERT(cnode != nullptr);
auto op_name = cnode->fullname_with_scope();
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
MS_ASSERT(primitive_c != nullptr);
if (cnode->inputs().size() < 4) {
MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
return RET_ERROR;
}
{
auto weight_i = cnode->input(2);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_i, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
return RET_ERROR;
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_OK;
}
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
<< quant_strategy_->mWeightSize;
return RET_OK;
}
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
} else if (type_id == kNumberTypeInt16) {
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
} else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
@ -155,77 +263,109 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
return RET_ERROR;
}
}
{
auto weight_h = cnode->input(3);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_h, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
} else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
}
{
if (cnode->inputs().size() > 4) {
auto bias = cnode->input(4);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(bias, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr) {
MS_LOG(ERROR) << "GetLiteParameter error";
return RET_ERROR;
}
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
} else if (type_id_ == kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
false);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
}
}
return RET_OK;
}
STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
for (auto &node : nodes) {
if (!quant_strategy_->CanMulOpQuantized(node)) {
continue;
}
auto already_quant = false;
ParamValueLitePtr param_value = nullptr;
ParameterPtr param_node = nullptr;
for (size_t i = 1; i < node->size(); i++) {
auto inputNode = node->input(i);
if (inputNode->isa<Parameter>()) {
param_node = inputNode->cast<ParameterPtr>();
if ((param_node != nullptr) && param_node->has_default()) {
param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
if ((param_value == nullptr) || (param_value->tensor_size() == 0) ||
(param_value->tensor_addr() == nullptr)) {
param_value = nullptr;
continue;
} else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
param_value->tensor_type() == mindspore::kNumberTypeInt16) {
MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been "
<< " quantized";
already_quant = true;
break;
} else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
param_value = nullptr;
continue;
} else {
break;
}
}
}
}
STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
MS_ASSERT(primitive_c != nullptr);
if (already_quant) {
continue;
}
if (param_value == nullptr) {
MS_LOG(ERROR) << "No valid input param node !";
return RET_ERROR;
}
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto status = RET_ERROR;
if (type_id == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
} else if (type_id == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
auto weight_h = cnode->input(1);
ParameterPtr param_node;
ParamValueLitePtr param_value;
GetLiteParameter(weight_h, &param_node, &param_value);
if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
return RET_OK;
}
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < "
<< quant_strategy_->mWeightSize;
return RET_OK;
}
auto status = RET_ERROR;
if (type_id_ == kNumberTypeInt8) {
status =
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
} else if (type_id_ == kNumberTypeInt16) {
status =
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
}
status = SetAbstract(param_value, param_node, primitive_c);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetAbstract failed : " << status;
return RET_ERROR;
}
return RET_OK;
}
@ -315,6 +455,23 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
}
auto cnodes = func_graph->GetOrderedCnodes();
for (auto &cnode : cnodes) {
auto op_type = NodePrimitiveType(cnode);
if (op_type == schema::PrimitiveType_Lstm) {
status = DoLstmQuntize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error";
return RET_ERROR;
}
} else if (op_type == schema::PrimitiveType_Gather) {
status = DoGatherQuntize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error";
return RET_ERROR;
}
}
}
for (auto iter = cnodes.end(); iter != cnodes.begin();) {
auto cnode = *(--iter);
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
@ -357,18 +514,18 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
}
// 1. try quant
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
type_id = TypeId::kNumberTypeInt8;
type_id_ = TypeId::kNumberTypeInt8;
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
if (type_id == TypeId::kNumberTypeInt8) {
if (type_id_ == TypeId::kNumberTypeInt8) {
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else if (type_id == TypeId::kNumberTypeInt16) {
} else if (type_id_ == TypeId::kNumberTypeInt16) {
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
quant_min_t, bit_num_t, true);
} else {
MS_LOG(ERROR) << "unexpected type_id: " << type_id;
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
return RET_ERROR;
}
if (status != RET_OK) {
@ -456,13 +613,53 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
return RET_OK;
}
STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
MS_ASSERT(func_graph != nullptr);
for (auto &cnode : func_graph->GetOrderedCnodes()) {
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "primitive_c is nullptr";
return RET_ERROR;
}
auto op_name = cnode->fullname_with_scope();
auto op_type = (schema::PrimitiveType)primitive_c->Type();
if (quant_strategy_->CanConvOpQuantized(cnode)) {
auto status = DoConvQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoConvQuantize error";
return RET_ERROR;
}
} else if (quant_strategy_->CanMulOpQuantized(cnode)) {
auto status = DoMulQuantize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoMulQuantize error";
return RET_ERROR;
}
} else if (op_type == schema::PrimitiveType_Lstm) {
auto status = DoLstmQuntize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoLstmQuntize error";
return RET_ERROR;
}
} else if (op_type == schema::PrimitiveType_Gather) {
auto status = DoGatherQuntize(cnode);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoGatherQuntize error";
return RET_ERROR;
}
} else {
MS_LOG(DEBUG) << op_name << " of type: " << schema::EnumNamePrimitiveType(op_type) << " no need quant";
}
}
return RET_OK;
}
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
MS_ASSERT(func_graph != nullptr);
STATUS ret;
auto cnodes = func_graph->GetOrderedCnodes();
if (!config_file_.empty()) {
ret = ParseConfigFile(config_file_, &config_param_);
auto ret = ParseConfigFile(config_file_, &config_param_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ReadConfig error.";
return RET_ERROR;
@ -470,20 +667,14 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
}
if (config_param_.mixed) {
bit_num_ = 8;
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
type_id_ = kNumberTypeInt8;
MS_LOG(INFO) << "Do mixed bit quantization";
return DoMiexedQuant(func_graph);
}
ret = DoConvQuantize(cnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
return ret;
}
ret = DoMulQuantize(cnodes);
if (ret != RET_OK) {
MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
return ret;
}
return ret;
return DoFixedQuant(func_graph);
}
} // namespace mindspore::lite::quant

View File

@ -41,19 +41,21 @@ class WeightQuantizer : public Quantizer {
~WeightQuantizer();
STATUS DoQuantize(FuncGraphPtr func_graph) override;
STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
STATUS DoConvQuantize(CNodePtr);
STATUS DoMulQuantize(CNodePtr);
STATUS DoLstmQuntize(CNodePtr cnode);
STATUS DoGatherQuntize(CNodePtr cnode);
static STATUS WeightQuantInputCheck(const converter::Flags *config);
static bool IsPosNum(const std::string &str);
int quant_max;
int quant_min;
TypeId type_id{kTypeUnknown};
int quant_max_{127};
int quant_min_{-128};
TypeId type_id_{kNumberTypeInt8};
std::map<std::string, int> opname_bit_;
private:
std::unique_ptr<QuantStrategy> quant_strategy_;
size_t bit_num_;
size_t bit_num_{8};
std::string config_file_;
PostQuantConfig config_param_;
std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...]
@ -61,6 +63,7 @@ class WeightQuantizer : public Quantizer {
STATUS DoMiexedQuant(FuncGraphPtr);
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
STATUS DoFixedQuant(FuncGraphPtr);
};
} // namespace mindspore::lite::quant
#endif