forked from mindspore-Ecosystem/mindspore
weight quant reconstruction && lstm/gather quant
This commit is contained in:
parent
2924552783
commit
5d613749ec
|
@ -39,6 +39,7 @@ set(LITE_SRC
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/dequant.cc
|
||||
)
|
||||
|
||||
if (SUPPORT_GPU)
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include <cmath>
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
#include "src/dequant.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace mindspore::lite {
|
||||
float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
|
||||
MS_ASSERT(input_tensor != nullptr);
|
||||
if (input_tensor->data_type() != kNumberTypeInt8 && input_tensor->data_type() != kNumberTypeInt16) {
|
||||
|
@ -35,6 +35,8 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) {
|
|||
}
|
||||
|
||||
void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) {
|
||||
MS_ASSERT(input_tensor != nullptr);
|
||||
MS_ASSERT(unpack_int_data != nullptr);
|
||||
auto quant_params = input_tensor->quantParams();
|
||||
if (quant_params == nullptr) {
|
||||
MS_LOG(ERROR) << "low bits quantparams is empty.";
|
||||
|
@ -47,4 +49,41 @@ void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_i
|
|||
UnPackUtil<int16_t, uint16_t>(input_tensor, origin_bit, unpack_int_data);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
std::map<Tensor *, std::pair<TypeId, void *>> DequantUtil::DequantTensor(const std::vector<Tensor *> &in_tensors,
|
||||
TypeId data_type) {
|
||||
std::map<Tensor *, std::pair<TypeId, void *>> tensor_origin_data;
|
||||
if (data_type == TypeId::kNumberTypeFloat32 || data_type == TypeId::kNumberTypeFloat16) {
|
||||
for (auto weight_tensor : in_tensors) {
|
||||
MS_ASSERT(weight_tensor != nullptr);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag = !weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited &&
|
||||
restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return tensor_origin_data;
|
||||
}
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
tensor_origin_data[weight_tensor] = {restore_type, restore_data};
|
||||
}
|
||||
}
|
||||
}
|
||||
return tensor_origin_data;
|
||||
}
|
||||
|
||||
void DequantUtil::RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map) {
|
||||
for (auto &kv : tensor_origin_data_map) {
|
||||
auto *tensor = kv.first;
|
||||
auto type_id = kv.second.first;
|
||||
auto data = kv.second.second;
|
||||
tensor->FreeData();
|
||||
tensor->set_data_type(type_id);
|
||||
tensor->set_data(data);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace mindspore::lite
|
|
@ -17,6 +17,8 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
|
||||
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <cmath>
|
||||
|
@ -24,13 +26,18 @@
|
|||
#include "src/common/utils.h"
|
||||
#include "src/tensor.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
namespace mindspore::lite {
|
||||
class DequantUtil {
|
||||
public:
|
||||
static float *DequantWeight(lite::Tensor *input_tensor);
|
||||
|
||||
static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data);
|
||||
|
||||
static std::map<Tensor *, std::pair<TypeId, void *>> DequantTensor(const std::vector<Tensor *> &in_tensors,
|
||||
TypeId data_type);
|
||||
|
||||
static void RestoreTensorData(const std::map<Tensor *, std::pair<TypeId, void *>> &tensor_origin_data_map);
|
||||
|
||||
template <typename ST, typename DT = float>
|
||||
static DT *DequantData(lite::Tensor *input_tensor) {
|
||||
const auto *quant_datas = static_cast<const ST *>(input_tensor->MutableData());
|
||||
|
@ -108,7 +115,7 @@ class DequantUtil {
|
|||
static void UnPackData(int origin_bit, const T2 &packed_data, std::queue<bool> *unpack_bit_data, void *unpack_int,
|
||||
size_t *count, bool is_last) {
|
||||
T2 uint_result = 0;
|
||||
T1 result = 0;
|
||||
T1 result;
|
||||
UnPackFromUintToOrigin<T2>(packed_data, unpack_bit_data);
|
||||
while (static_cast<int>(unpack_bit_data->size()) >= origin_bit) {
|
||||
for (int k = 0; k < origin_bit; k++) {
|
||||
|
@ -163,6 +170,6 @@ class DequantUtil {
|
|||
}
|
||||
}
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
} // namespace mindspore::lite
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_
|
|
@ -27,7 +27,7 @@
|
|||
#include "src/common/graph_util.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/lite_model.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
#include "src/dequant.h"
|
||||
#if SUPPORT_NPU
|
||||
#include "src/runtime/agent/npu/npu_manager.h"
|
||||
#include "src/runtime/agent/npu/optimizer/npu_pass_manager.h"
|
||||
|
@ -120,7 +120,7 @@ int LiteSession::ConvertTensorsData(const lite::Model *model, size_t tensor_inde
|
|||
MS_LOG(ERROR) << "Malloc data for tensor failed ";
|
||||
return RET_ERROR;
|
||||
}
|
||||
kernel::DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
|
||||
DequantUtil::UnPackToInt(src_tensor, dst_tensor->MutableData());
|
||||
copyed_tensor_idxes_.emplace_back(tensor_index);
|
||||
} else {
|
||||
dst_tensor->set_data(const_cast<unsigned char *>(src_tensor->data()->data()));
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -359,22 +358,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
|||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
|
||||
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
if (conv_param->group_ == 1) {
|
||||
|
@ -385,11 +368,6 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
|||
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -398,20 +376,9 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Conv2D, CpuConvFp16KernelCreator)
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -138,22 +137,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
|
||||
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
kernel::LiteKernel *kernel;
|
||||
if (conv_param->input_channel_ < 32) {
|
||||
|
@ -164,11 +147,6 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
}
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -176,19 +154,9 @@ kernel::LiteKernel *CpuConvDwFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
#include "nnacl/fp16/conv_fp16.h"
|
||||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -212,30 +211,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
|
|||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
|
||||
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
auto dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
auto kernel = new (std::nothrow) DeconvolutionDepthwiseFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -243,19 +221,9 @@ kernel::LiteKernel *CpuDeconvDwFp16KernelCreator(const std::vector<lite::Tensor
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include "src/runtime/kernel/arm/fp16/deconvolution_fp16.h"
|
||||
#include "src/runtime/kernel/arm/fp16/deconvolution_winograd_fp16.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -220,22 +219,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
|
||||
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
auto dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
kernel::LiteKernel *kernel;
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
|
||||
|
@ -247,11 +230,6 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -259,19 +237,9 @@ kernel::LiteKernel *CpuDeConvFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DeConv2D, CpuDeConvFp16KernelCreator)
|
||||
|
|
|
@ -234,30 +234,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
|
|||
OpParameter *opParameter, const lite::InnerContext *ctx,
|
||||
const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
// data of second tensor of fc may be nullptr
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
auto *kernel = new (std::nothrow) FullconnectionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -265,19 +244,9 @@ kernel::LiteKernel *CpuFullConnectionFp16KernelCreator(const std::vector<lite::T
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@
|
|||
#include "nnacl/fp16/matmul_fp16.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class FullconnectionFP16CPUKernel : public LiteKernel {
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "src/runtime/runtime_api.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
|
@ -330,29 +329,9 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
auto *kernel = new (std::nothrow) MatmulFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -361,18 +340,8 @@ kernel::LiteKernel *CpuMatmulFp16KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
delete kernel;
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -356,22 +355,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
|
|||
MS_ASSERT(desc.type == schema::PrimitiveType_Conv2D);
|
||||
MS_ASSERT(desc.data_type == kNumberTypeFloat32);
|
||||
|
||||
// if get quantized weight, dequantize it to float32 type data.
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(op_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(op_parameter);
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
if (conv_param->group_ == 1) {
|
||||
|
@ -382,11 +365,6 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
|
|||
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(op_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -395,20 +373,9 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &
|
|||
if (ret != RET_OK && ret != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << op_parameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -126,19 +125,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
|
||||
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
kernel::LiteKernel *kernel = nullptr;
|
||||
if (primitive != nullptr && primitive->infer_flag()) {
|
||||
|
@ -162,11 +148,6 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
}
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -174,21 +155,10 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
if (ret != RET_OK && ret != RET_INFER_INVALID) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (weight_tensor->data_type() == kNumberTypeInt8 || weight_tensor->data_type() == kNumberTypeInt16) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
#include "nnacl/fp32/conv_fp32.h"
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -202,29 +201,10 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
|
|||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_DeDepthwiseConv2D);
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
auto kernel =
|
||||
new (std::nothrow) kernel::DeconvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -232,19 +212,9 @@ kernel::LiteKernel *CpuDeconvDwFp32KernelCreator(const std::vector<lite::Tensor
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
#include "src/runtime/kernel/arm/fp32/deconvolution_fp32.h"
|
||||
#include "src/runtime/kernel/arm/fp32/deconvolution_winograd_fp32.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -240,20 +239,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_DeConv2D);
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
kernel::LiteKernel *kernel;
|
||||
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
|
||||
|
@ -266,11 +251,6 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -278,21 +258,9 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -228,28 +228,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
|
|||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection);
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
// data of second tensor of fc may be nullptr
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (!kernel) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -257,19 +238,9 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::T
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::lite::InnerContext;
|
||||
namespace mindspore::kernel {
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "nnacl/fp32/matmul_fp32.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_INPUT_TENSOR_ERROR;
|
||||
|
@ -417,30 +416,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);
|
||||
|
||||
auto *weight_tensor = inputs.at(kWeightIndex);
|
||||
auto *restore_data = weight_tensor->data_c();
|
||||
auto restore_type = weight_tensor->data_type();
|
||||
bool dequant_flag =
|
||||
!weight_tensor->quant_params().empty() && weight_tensor->quant_params().front().inited && restore_data != nullptr;
|
||||
if (dequant_flag) {
|
||||
auto *dequant_weight = kernel::DequantUtil::DequantWeight(weight_tensor);
|
||||
if (dequant_weight == nullptr) {
|
||||
MS_LOG(ERROR) << "dequant data is nullptr.";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
weight_tensor->set_data(dequant_weight);
|
||||
}
|
||||
|
||||
auto kernel = new (std::nothrow) MatmulCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "kernel is nullptr.";
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -448,21 +426,9 @@ kernel::LiteKernel *CpuMatmulFp32KernelCreator(const std::vector<lite::Tensor *>
|
|||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (dequant_flag) {
|
||||
weight_tensor->FreeData();
|
||||
weight_tensor->set_data(restore_data);
|
||||
weight_tensor->set_data_type(restore_type);
|
||||
}
|
||||
|
||||
return kernel;
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "src/runtime/kernel/opencl/opencl_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
#include "mindspore/lite/src/dequant.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
|
@ -263,10 +263,10 @@ int OpenCLKernel::DequantWeight() {
|
|||
if (is_fp16) {
|
||||
#ifdef ENABLE_ARM64
|
||||
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
|
||||
dequant_weight = kernel::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
|
||||
dequant_weight = lite::DequantUtil::DequantData<int8_t, float16_t>(weight_tensor);
|
||||
weight_tensor->set_data_type(kNumberTypeFloat16);
|
||||
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
|
||||
dequant_weight = kernel::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
|
||||
dequant_weight = lite::DequantUtil::DequantData<int16_t, float16_t>(weight_tensor);
|
||||
weight_tensor->set_data_type(kNumberTypeFloat16);
|
||||
} else {
|
||||
set_flag = false;
|
||||
|
@ -276,10 +276,10 @@ int OpenCLKernel::DequantWeight() {
|
|||
#endif
|
||||
} else {
|
||||
if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt8) {
|
||||
dequant_weight = kernel::DequantUtil::DequantData<int8_t, float>(weight_tensor);
|
||||
dequant_weight = lite::DequantUtil::DequantData<int8_t, float>(weight_tensor);
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
} else if (in_tensors_.at(kWeightIndex)->data_type() == kNumberTypeInt16) {
|
||||
dequant_weight = kernel::DequantUtil::DequantData<int16_t, float>(weight_tensor);
|
||||
dequant_weight = lite::DequantUtil::DequantData<int16_t, float>(weight_tensor);
|
||||
weight_tensor->set_data_type(kNumberTypeFloat32);
|
||||
} else {
|
||||
set_flag = false;
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "src/lite_kernel.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
#include "src/runtime/kernel/arm/base/dequant.h"
|
||||
#include "mindspore/lite/src/dequant.h"
|
||||
#include "src/runtime/kernel/opencl/utils.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "src/common/utils.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/sub_graph_kernel.h"
|
||||
#include "src/dequant.h"
|
||||
#if SUPPORT_GPU
|
||||
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
|
||||
#include "src/runtime/opencl/opencl_runtime.h"
|
||||
|
@ -213,8 +214,10 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
if (mindspore::lite::IsSupportFloat16() &&
|
||||
((context_->IsCpuFloat16Enabled() && data_type == kNumberTypeFloat32) || data_type == kNumberTypeFloat16)) {
|
||||
kernel::KernelKey fp16_cpu_desc{desc.arch, kNumberTypeFloat16, desc.type};
|
||||
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, fp16_cpu_desc.data_type);
|
||||
auto *kernel =
|
||||
KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, fp16_cpu_desc);
|
||||
DequantUtil::RestoreTensorData(tensor_origin_data_map);
|
||||
if (kernel != nullptr) {
|
||||
MS_LOG(DEBUG) << "Get fp16 op success: " << schema::EnumNamePrimitiveType(fp16_cpu_desc.type) << " "
|
||||
<< node->name_;
|
||||
|
@ -225,7 +228,9 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
|
|||
MS_LOG(DEBUG) << "Get fp16 op failed, back to fp32 op.";
|
||||
desc.data_type = kNumberTypeFloat32;
|
||||
}
|
||||
auto tensor_origin_data_map = DequantUtil::DequantTensor(in_tensors, desc.data_type);
|
||||
auto *kernel = KernelRegistry::GetInstance()->GetKernel(in_tensors, out_tensors, primitive, context_, desc);
|
||||
DequantUtil::RestoreTensorData(tensor_origin_data_map);
|
||||
if (kernel != nullptr) {
|
||||
return kernel;
|
||||
}
|
||||
|
|
|
@ -126,6 +126,7 @@ set(TEST_LITE_SRC
|
|||
${LITE_DIR}/src/kernel_registry.cc
|
||||
${LITE_DIR}/src/lite_kernel.cc
|
||||
${LITE_DIR}/src/lite_session.cc
|
||||
${LITE_DIR}/src/dequant.cc
|
||||
${LITE_DIR}/src/sub_graph_kernel.cc
|
||||
${LITE_DIR}/src/lite_model.cc
|
||||
${LITE_DIR}/src/scheduler.cc
|
||||
|
|
|
@ -95,6 +95,7 @@ set(LITE_SRC
|
|||
${SRC_DIR}/executor.cc
|
||||
${SRC_DIR}/lite_model.cc
|
||||
${SRC_DIR}/errorcode.cc
|
||||
${SRC_DIR}/dequant.cc
|
||||
)
|
||||
if (SUPPORT_TRAIN)
|
||||
set(LITE_SRC
|
||||
|
|
|
@ -782,4 +782,27 @@ FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &func_graph) {
|
|||
return new_func_graph;
|
||||
}
|
||||
|
||||
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value) {
|
||||
MS_ASSERT(node != nullptr);
|
||||
MS_ASSERT(param_node != nullptr);
|
||||
MS_ASSERT(param_value != nullptr);
|
||||
|
||||
auto op_name = node->fullname_with_scope();
|
||||
|
||||
*param_node = node->cast<ParameterPtr>();
|
||||
if (*param_node == nullptr) {
|
||||
MS_LOG(INFO) << op_name << " can not cast to ParameterPtr";
|
||||
return;
|
||||
}
|
||||
if (!(*param_node)->has_default()) {
|
||||
MS_LOG(INFO) << op_name << " not has_default";
|
||||
return;
|
||||
}
|
||||
|
||||
*param_value = std::static_pointer_cast<ParamValueLite>((*param_node)->default_param());
|
||||
if (*param_value == nullptr) {
|
||||
MS_LOG(INFO) << "default_param can not cast to ParamValueLite";
|
||||
return;
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -75,9 +75,10 @@ class QuantStrategy {
|
|||
bool CanMulOpQuantized(const CNodePtr &node) const;
|
||||
bool CanOpPostQuantized(AnfNodePtr &node) const;
|
||||
|
||||
private:
|
||||
size_t mWeightSize;
|
||||
size_t mConvWeightQuantChannelThreshold;
|
||||
|
||||
private:
|
||||
static const std::vector<schema::PrimitiveType> conv_types;
|
||||
static const std::vector<schema::PrimitiveType> mul_types;
|
||||
};
|
||||
|
@ -356,5 +357,8 @@ STATUS CopyInputDataToTensor(size_t input_index, size_t image_index,
|
|||
const std::vector<std::vector<std::string>> &images, mindspore::tensor::MSTensor *tensor);
|
||||
|
||||
FuncGraphPtr CopyFuncGraph(const FuncGraphPtr &);
|
||||
|
||||
void GetLiteParameter(const AnfNodePtr &node, ParameterPtr *param_node, ParamValueLitePtr *param_value);
|
||||
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "src/common/common.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
using std::string;
|
||||
using std::vector;
|
||||
|
@ -73,13 +72,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const std::string &config_f
|
|||
this->bit_num_ = static_cast<size_t>(std::stoull(bitNum));
|
||||
auto convQuantWeightChannelThreshold = static_cast<size_t>(std::stoull(convWeightChannelThreshold));
|
||||
quant_strategy_ = std::make_unique<QuantStrategy>(quantSize, convQuantWeightChannelThreshold);
|
||||
quant_max = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
|
||||
quant_min = -(1 << (unsigned int)(this->bit_num_ - 1));
|
||||
// parse type_id
|
||||
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
|
||||
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
|
||||
// parse type_id_
|
||||
if (this->bit_num_ > 0 && this->bit_num_ <= 8) {
|
||||
type_id = kNumberTypeInt8;
|
||||
type_id_ = kNumberTypeInt8;
|
||||
} else if (this->bit_num_ <= 16) {
|
||||
type_id = kNumberTypeInt16;
|
||||
type_id_ = kNumberTypeInt16;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "invalid input bits";
|
||||
}
|
||||
|
@ -90,7 +89,7 @@ WeightQuantizer::~WeightQuantizer() { delete fp32_session_; }
|
|||
STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node,
|
||||
std::shared_ptr<PrimitiveC> primitive_c) {
|
||||
// set dtype
|
||||
param_value->set_tensor_type(type_id);
|
||||
param_value->set_tensor_type(type_id_);
|
||||
auto abstract_base = param_node->abstract();
|
||||
if (abstract_base == nullptr) {
|
||||
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
|
||||
|
@ -101,49 +100,158 @@ STATUS WeightQuantizer::SetAbstract(ParamValueLitePtr param_value, ParameterPtr
|
|||
return RET_ERROR;
|
||||
}
|
||||
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
|
||||
abstract_tensor->element()->set_type(TypeIdToType(type_id));
|
||||
abstract_tensor->element()->set_type(TypeIdToType(type_id_));
|
||||
primitive_c->set_quant_type(schema::QuantType_WeightQuant);
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
|
||||
for (auto &cnode : nodes) {
|
||||
if (!quant_strategy_->CanConvOpQuantized(cnode)) {
|
||||
continue;
|
||||
}
|
||||
STATUS WeightQuantizer::DoConvQuantize(CNodePtr cnode) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input_node = cnode->input(2);
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto input_node = cnode->input(2);
|
||||
if (!input_node->isa<Parameter>()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
|
||||
auto param_node = input_node->cast<ParameterPtr>();
|
||||
if (!param_node->has_default()) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
GetLiteParameter(input_node, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
||||
if (param_value == nullptr) {
|
||||
if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoMulQuantize(CNodePtr cnode) {
|
||||
auto already_quant = false;
|
||||
ParamValueLitePtr param_value = nullptr;
|
||||
ParameterPtr param_node = nullptr;
|
||||
for (size_t i = 1; i < cnode->size(); i++) {
|
||||
auto inputNode = cnode->input(i);
|
||||
if (inputNode->isa<Parameter>()) {
|
||||
param_node = inputNode->cast<ParameterPtr>();
|
||||
if ((param_node != nullptr) && param_node->has_default()) {
|
||||
param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
||||
if ((param_value == nullptr) || (param_value->tensor_size() == 0) || (param_value->tensor_addr() == nullptr)) {
|
||||
param_value = nullptr;
|
||||
continue;
|
||||
} else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
|
||||
param_value->tensor_type() == mindspore::kNumberTypeInt16) {
|
||||
MS_LOG(INFO) << "the node: " << cnode->fullname_with_scope() << " input_i: " << i << "has been "
|
||||
<< " quantized";
|
||||
already_quant = true;
|
||||
break;
|
||||
} else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
|
||||
param_value = nullptr;
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (already_quant) {
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "No valid input param node !";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, true);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
|
||||
MS_ASSERT(cnode != nullptr);
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
MS_ASSERT(primitive_c != nullptr);
|
||||
|
||||
if (cnode->inputs().size() < 4) {
|
||||
MS_LOG(ERROR) << op_name << " inputs is " << cnode->inputs().size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
{
|
||||
auto weight_i = cnode->input(2);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(weight_i, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type();
|
||||
return RET_ERROR;
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(WARNING) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
return RET_OK;
|
||||
}
|
||||
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
|
||||
MS_LOG(INFO) << op_name << " weight_i cnt: " << param_value->tensor_size() / 4 << " < "
|
||||
<< quant_strategy_->mWeightSize;
|
||||
return RET_OK;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id == kNumberTypeInt8) {
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
} else if (type_id == kNumberTypeInt16) {
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
|
@ -155,77 +263,109 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
{
|
||||
auto weight_h = cnode->input(3);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(weight_h, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
{
|
||||
if (cnode->inputs().size() > 4) {
|
||||
auto bias = cnode->input(4);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(bias, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "GetLiteParameter error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(ERROR) << "param_value tensor type is: " << param_value->tensor_type() << " not quant";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
||||
false);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
|
||||
for (auto &node : nodes) {
|
||||
if (!quant_strategy_->CanMulOpQuantized(node)) {
|
||||
continue;
|
||||
}
|
||||
auto already_quant = false;
|
||||
ParamValueLitePtr param_value = nullptr;
|
||||
ParameterPtr param_node = nullptr;
|
||||
for (size_t i = 1; i < node->size(); i++) {
|
||||
auto inputNode = node->input(i);
|
||||
if (inputNode->isa<Parameter>()) {
|
||||
param_node = inputNode->cast<ParameterPtr>();
|
||||
if ((param_node != nullptr) && param_node->has_default()) {
|
||||
param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
||||
if ((param_value == nullptr) || (param_value->tensor_size() == 0) ||
|
||||
(param_value->tensor_addr() == nullptr)) {
|
||||
param_value = nullptr;
|
||||
continue;
|
||||
} else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 ||
|
||||
param_value->tensor_type() == mindspore::kNumberTypeInt16) {
|
||||
MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been "
|
||||
<< " quantized";
|
||||
already_quant = true;
|
||||
break;
|
||||
} else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) {
|
||||
param_value = nullptr;
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
MS_ASSERT(primitive_c != nullptr);
|
||||
|
||||
if (already_quant) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (param_value == nullptr) {
|
||||
MS_LOG(ERROR) << "No valid input param node !";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(node->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
auto status = RET_ERROR;
|
||||
if (type_id == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
} else if (type_id == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bit_num_, true);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto weight_h = cnode->input(1);
|
||||
ParameterPtr param_node;
|
||||
ParamValueLitePtr param_value;
|
||||
GetLiteParameter(weight_h, ¶m_node, ¶m_value);
|
||||
if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
||||
MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
if (param_value->tensor_size() / 4 < quant_strategy_->mWeightSize) {
|
||||
MS_LOG(INFO) << cnode->fullname_with_scope() << " param cnt: " << param_value->tensor_size() / 4 << " < "
|
||||
<< quant_strategy_->mWeightSize;
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
auto status = RET_ERROR;
|
||||
if (type_id_ == kNumberTypeInt8) {
|
||||
status =
|
||||
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
||||
} else if (type_id_ == kNumberTypeInt16) {
|
||||
status =
|
||||
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
||||
return status;
|
||||
}
|
||||
status = SetAbstract(param_value, param_node, primitive_c);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetAbstract failed : " << status;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -315,6 +455,23 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
|
|||
}
|
||||
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
for (auto &cnode : cnodes) {
|
||||
auto op_type = NodePrimitiveType(cnode);
|
||||
if (op_type == schema::PrimitiveType_Lstm) {
|
||||
status = DoLstmQuntize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoLstmQuntize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (op_type == schema::PrimitiveType_Gather) {
|
||||
status = DoGatherQuntize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoGatherQuntize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto iter = cnodes.end(); iter != cnodes.begin();) {
|
||||
auto cnode = *(--iter);
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
|
@ -357,18 +514,18 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
|
|||
}
|
||||
// 1. try quant
|
||||
for (int bit_num_t = 2; bit_num_t <= 8; bit_num_t++) {
|
||||
type_id = TypeId::kNumberTypeInt8;
|
||||
type_id_ = TypeId::kNumberTypeInt8;
|
||||
int quant_max_t = (1 << (unsigned int)(bit_num_t - 1)) - 1;
|
||||
int quant_min_t = -(1 << (unsigned int)(bit_num_t - 1));
|
||||
|
||||
if (type_id == TypeId::kNumberTypeInt8) {
|
||||
if (type_id_ == TypeId::kNumberTypeInt8) {
|
||||
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
|
||||
quant_min_t, bit_num_t, true);
|
||||
} else if (type_id == TypeId::kNumberTypeInt16) {
|
||||
} else if (type_id_ == TypeId::kNumberTypeInt16) {
|
||||
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType::QuantType_WeightQuant, quant_max_t,
|
||||
quant_min_t, bit_num_t, true);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unexpected type_id: " << type_id;
|
||||
MS_LOG(ERROR) << "unexpected type_id_: " << type_id_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
|
@ -456,13 +613,53 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "primitive_c is nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto op_name = cnode->fullname_with_scope();
|
||||
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
||||
|
||||
if (quant_strategy_->CanConvOpQuantized(cnode)) {
|
||||
auto status = DoConvQuantize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoConvQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (quant_strategy_->CanMulOpQuantized(cnode)) {
|
||||
auto status = DoMulQuantize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoMulQuantize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (op_type == schema::PrimitiveType_Lstm) {
|
||||
auto status = DoLstmQuntize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoLstmQuntize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else if (op_type == schema::PrimitiveType_Gather) {
|
||||
auto status = DoGatherQuntize(cnode);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoGatherQuntize error";
|
||||
return RET_ERROR;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << op_name << " of type: " << schema::EnumNamePrimitiveType(op_type) << " no need quant";
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
STATUS ret;
|
||||
auto cnodes = func_graph->GetOrderedCnodes();
|
||||
|
||||
if (!config_file_.empty()) {
|
||||
ret = ParseConfigFile(config_file_, &config_param_);
|
||||
auto ret = ParseConfigFile(config_file_, &config_param_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ReadConfig error.";
|
||||
return RET_ERROR;
|
||||
|
@ -470,20 +667,14 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|||
}
|
||||
|
||||
if (config_param_.mixed) {
|
||||
bit_num_ = 8;
|
||||
quant_max_ = (1 << (unsigned int)(this->bit_num_ - 1)) - 1;
|
||||
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
|
||||
type_id_ = kNumberTypeInt8;
|
||||
MS_LOG(INFO) << "Do mixed bit quantization";
|
||||
return DoMiexedQuant(func_graph);
|
||||
}
|
||||
|
||||
ret = DoConvQuantize(cnodes);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoConvQuantize failed :" << ret;
|
||||
return ret;
|
||||
}
|
||||
ret = DoMulQuantize(cnodes);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoMulQuantize failed :" << ret;
|
||||
return ret;
|
||||
}
|
||||
return ret;
|
||||
return DoFixedQuant(func_graph);
|
||||
}
|
||||
} // namespace mindspore::lite::quant
|
||||
|
|
|
@ -41,19 +41,21 @@ class WeightQuantizer : public Quantizer {
|
|||
~WeightQuantizer();
|
||||
|
||||
STATUS DoQuantize(FuncGraphPtr func_graph) override;
|
||||
STATUS DoConvQuantize(const std::list<CNodePtr> &nodes);
|
||||
STATUS DoMulQuantize(const std::list<CNodePtr> &nodes);
|
||||
STATUS DoConvQuantize(CNodePtr);
|
||||
STATUS DoMulQuantize(CNodePtr);
|
||||
STATUS DoLstmQuntize(CNodePtr cnode);
|
||||
STATUS DoGatherQuntize(CNodePtr cnode);
|
||||
static STATUS WeightQuantInputCheck(const converter::Flags *config);
|
||||
static bool IsPosNum(const std::string &str);
|
||||
|
||||
int quant_max;
|
||||
int quant_min;
|
||||
TypeId type_id{kTypeUnknown};
|
||||
int quant_max_{127};
|
||||
int quant_min_{-128};
|
||||
TypeId type_id_{kNumberTypeInt8};
|
||||
std::map<std::string, int> opname_bit_;
|
||||
|
||||
private:
|
||||
std::unique_ptr<QuantStrategy> quant_strategy_;
|
||||
size_t bit_num_;
|
||||
size_t bit_num_{8};
|
||||
std::string config_file_;
|
||||
PostQuantConfig config_param_;
|
||||
std::vector<std::vector<std::string>> images_; // multi_input, [[mode_input_0], [model_input_1]...]
|
||||
|
@ -61,6 +63,7 @@ class WeightQuantizer : public Quantizer {
|
|||
|
||||
STATUS DoMiexedQuant(FuncGraphPtr);
|
||||
STATUS SetAbstract(ParamValueLitePtr param_value, ParameterPtr param_node, std::shared_ptr<PrimitiveC> primitive_c);
|
||||
STATUS DoFixedQuant(FuncGraphPtr);
|
||||
};
|
||||
} // namespace mindspore::lite::quant
|
||||
#endif
|
||||
|
|
Loading…
Reference in New Issue