!5697 mat mul weight quant

Merge pull request !5697 from wangchangkai/master

commit ed71f1134d
@@ -333,6 +333,10 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
}

int ConvolutionBaseCPUKernel::RestoreFilter(lite::tensor::Tensor *input_tensor) {
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->data_type() != kNumberTypeUInt8) {
    MS_LOG(ERROR) << "conv weight input type error: " << input_tensor->data_type();
    return RET_ERROR;
  }
  if (input_tensor->GetQuantParams().empty()) {
    MS_LOG(ERROR) << "no quant param";
    return RET_ERROR;
@@ -53,7 +53,52 @@ kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::t
  }
  return kernel;
}

int RestoreFullconnectWeight(lite::tensor::Tensor *input_tensor) {
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->data_type() != kNumberTypeUInt8) {
    MS_LOG(ERROR) << "fullconnect weight input type error: " << input_tensor->data_type();
    return RET_ERROR;
  }
  if (input_tensor->GetQuantParams().empty()) {
    MS_LOG(ERROR) << "no quant param";
    return RET_ERROR;
  }
  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
  if (dequant_data == nullptr) {
    MS_LOG(ERROR) << "malloc failed";
    return RET_ERROR;
  }

  if (input_tensor->GetQuantParams().size() != kPerTensor) {
    // Per-channel quantization: one quant param per output channel.
    size_t channels = static_cast<size_t>(input_tensor->Batch());
    if (input_tensor->GetQuantParams().size() != channels) {
      MS_LOG(ERROR) << "quant param count " << input_tensor->GetQuantParams().size()
                    << " not equal to channel num " << channels;
      free(dequant_data);
      return RET_ERROR;
    }
    size_t per_channel_size = input_tensor->DataSize() / channels;
    auto quant_param = input_tensor->GetQuantParams();
    for (size_t i = 0; i < channels; i++) {
      auto param = quant_param.at(i);
      auto scale = param.scale;
      auto zero_point = param.zeroPoint;
      for (size_t j = 0; j < per_channel_size; j++) {
        dequant_data[per_channel_size * i + j] =
          static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
      }
    }
  } else {
    // Per-tensor quantization: a single scale / zero point pair.
    auto quant_param = input_tensor->GetQuantParams();
    auto param = quant_param.front();
    auto scale = param.scale;
    auto zero_point = param.zeroPoint;
    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
    }
  }
  input_tensor->SetData(dequant_data);
  return RET_OK;
}

kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                                       const std::vector<lite::tensor::Tensor *> &outputs,
                                                       OpParameter *opParameter, const lite::Context *ctx,

@@ -61,6 +106,11 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::t
                                                       const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_FullConnection);
  auto *weight_tensor = inputs.at(kWeightIndex);
  auto *restore_data = weight_tensor->Data();
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    RestoreFullconnectWeight(inputs.at(kWeightIndex));
  }
  auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
  if (!kernel) {
    MS_LOG(ERROR) << "kernel is nullptr.";

@@ -73,6 +123,10 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::t
                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    weight_tensor->FreeData();
    weight_tensor->SetData(restore_data);
  }
  return kernel;
}
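For reference, the arithmetic that the Restore*Weight helpers in this commit implement is plain affine dequantization, float = (quantized - zeroPoint) * scale, applied either once for the whole tensor or once per channel block. A minimal standalone sketch of the same computation (QuantArg here is a simplified stand-in for the tensor's quant-param entries, not the lite API):

#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for one scale/zero-point entry.
struct QuantArg {
  float scale;
  int32_t zeroPoint;
};

// Dequantize `count` uint8 values laid out as params.size() equal channel
// blocks; a single entry in `params` is the per-tensor case.
std::vector<float> Dequantize(const uint8_t *quant_data, size_t count,
                              const std::vector<QuantArg> &params) {
  size_t channels = params.size();
  size_t per_channel = count / channels;
  std::vector<float> out(count);
  for (size_t c = 0; c < channels; ++c) {
    for (size_t j = 0; j < per_channel; ++j) {
      size_t idx = c * per_channel + j;
      out[idx] = (quant_data[idx] - params[c].zeroPoint) * params[c].scale;
    }
  }
  return out;
}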
@@ -23,6 +23,7 @@
#include "nnacl/matmul_parameter.h"

using mindspore::lite::Context;
static constexpr int kPerTensor = 1;

namespace mindspore::kernel {
class FullconnectionBaseCPUKernel : public LiteKernel {
@@ -26,12 +26,65 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_MatMul;

namespace mindspore::kernel {
int RestoreMatmulWeight(lite::tensor::Tensor *input_tensor) {
  // Same affine dequantization as RestoreFullconnectWeight above.
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->data_type() != kNumberTypeUInt8) {
    MS_LOG(ERROR) << "matmul weight input type error: " << input_tensor->data_type();
    return RET_ERROR;
  }
  if (input_tensor->GetQuantParams().empty()) {
    MS_LOG(ERROR) << "no quant param";
    return RET_ERROR;
  }
  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
  if (dequant_data == nullptr) {
    MS_LOG(ERROR) << "malloc failed";
    return RET_ERROR;
  }

  if (input_tensor->GetQuantParams().size() != kPerTensor) {
    size_t channels = static_cast<size_t>(input_tensor->Batch());
    if (input_tensor->GetQuantParams().size() != channels) {
      MS_LOG(ERROR) << "quant param count " << input_tensor->GetQuantParams().size()
                    << " not equal to channel num " << channels;
      free(dequant_data);
      return RET_ERROR;
    }
    size_t per_channel_size = input_tensor->DataSize() / channels;
    auto quant_param = input_tensor->GetQuantParams();
    for (size_t i = 0; i < channels; i++) {
      auto param = quant_param.at(i);
      auto scale = param.scale;
      auto zero_point = param.zeroPoint;
      for (size_t j = 0; j < per_channel_size; j++) {
        dequant_data[per_channel_size * i + j] =
          static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
      }
    }
  } else {
    auto quant_param = input_tensor->GetQuantParams();
    auto param = quant_param.front();
    auto scale = param.scale;
    auto zero_point = param.zeroPoint;
    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
    }
  }
  input_tensor->SetData(dequant_data);
  return RET_OK;
}

kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                           const std::vector<lite::tensor::Tensor *> &outputs, OpParameter *opParameter,
                                           const lite::Context *ctx, const kernel::KernelKey &desc,
                                           const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(opParameter != nullptr);
  MS_ASSERT(desc.type == schema::PrimitiveType_MatMul);

  auto *weight_tensor = inputs.at(kWeightIndex);
  auto *restore_data = weight_tensor->Data();
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    RestoreMatmulWeight(inputs.at(kWeightIndex));
  }

  auto input_tensor = inputs.at(kInputIndex);
  auto data_type = input_tensor->data_type();
  kernel::LiteKernel *kernel = nullptr;

@@ -51,6 +104,12 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
                << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
    return nullptr;
  }

  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    weight_tensor->FreeData();
    weight_tensor->SetData(restore_data);
  }

  return kernel;
}
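Each creator touched by this commit wraps kernel construction in the same save / dequantize / free / restore sequence, so the weight tensor ends up owning its original uint8 buffer again once the kernel has made its own use of the float weights. A condensed, self-contained sketch of that lifecycle (FakeTensor and the helper are illustrations, not part of the lite runtime):

#include <cstdlib>

// Stand-in exposing only the three Tensor calls the pattern relies on.
struct FakeTensor {
  void *data = nullptr;
  void *Data() const { return data; }
  void SetData(void *d) { data = d; }
  void FreeData() { std::free(data); data = nullptr; }
};

template <typename Dequantize, typename MakeKernel>
auto CreateWithTransientFloatWeight(FakeTensor *weight, bool weight_quant,
                                    Dequantize dequantize, MakeKernel make_kernel)
    -> decltype(make_kernel()) {
  void *restore_data = weight->Data();  // remember the quantized buffer
  if (weight_quant) {
    dequantize(weight);  // SetData() swaps in a malloc'd float buffer
  }
  auto kernel = make_kernel();  // kernel consumes the float weights here
  if (weight_quant) {
    weight->FreeData();             // release the transient float buffer
    weight->SetData(restore_data);  // hand the quantized buffer back
  }
  return kernel;
}

The restore step is what keeps the transient float buffer from leaking and keeps any later consumer that expects uint8 weight data working.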
@@ -23,6 +23,7 @@
#include "nnacl/matmul_parameter.h"

using mindspore::lite::Context;
static constexpr int kPerTensor = 1;

namespace mindspore::kernel {
class MatmulBaseCPUKernel : public LiteKernel {
@@ -69,8 +69,58 @@ int ArithmeticSelfCPUKernel::DoArithmeticSelf(int task_id) {
  }
  return RET_OK;
}

int RestoreMulWeight(lite::tensor::Tensor *input_tensor) {
  // Same affine dequantization as RestoreFullconnectWeight above.
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->data_type() != kNumberTypeUInt8) {
    MS_LOG(ERROR) << "mul weight input type error: " << input_tensor->data_type();
    return RET_ERROR;
  }
  if (input_tensor->GetQuantParams().empty()) {
    MS_LOG(ERROR) << "no quant param";
    return RET_ERROR;
  }
  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
  if (dequant_data == nullptr) {
    MS_LOG(ERROR) << "malloc failed";
    return RET_ERROR;
  }

  if (input_tensor->GetQuantParams().size() != kPerTensor) {
    size_t channels = static_cast<size_t>(input_tensor->Batch());
    if (input_tensor->GetQuantParams().size() != channels) {
      MS_LOG(ERROR) << "quant param count " << input_tensor->GetQuantParams().size()
                    << " not equal to channel num " << channels;
      free(dequant_data);
      return RET_ERROR;
    }
    size_t per_channel_size = input_tensor->DataSize() / channels;
    auto quant_param = input_tensor->GetQuantParams();
    for (size_t i = 0; i < channels; i++) {
      auto param = quant_param.at(i);
      auto scale = param.scale;
      auto zero_point = param.zeroPoint;
      for (size_t j = 0; j < per_channel_size; j++) {
        dequant_data[per_channel_size * i + j] =
          static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
      }
    }
  } else {
    auto quant_param = input_tensor->GetQuantParams();
    auto param = quant_param.front();
    auto scale = param.scale;
    auto zero_point = param.zeroPoint;
    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
    }
  }
  input_tensor->SetData(dequant_data);
  return RET_OK;
}

int ArithmeticSelfCPUKernel::Run() {
  void *restore_data = nullptr;
  if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) {
    restore_data = in_tensors_[1]->Data();
    RestoreMulWeight(in_tensors_[1]);
  }
  auto ret = Prepare();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed, ret: " << ret;

@@ -85,6 +135,10 @@ int ArithmeticSelfCPUKernel::Run() {
    MS_LOG(ERROR) << "ArithmeticSelfRun error, error_code[" << ret << "]";
    return ret;
  }
  if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) {
    in_tensors_[1]->FreeData();
    in_tensors_[1]->SetData(restore_data);
  }
  return RET_OK;
}
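Unlike the creator-level variants above, ArithmeticSelfCPUKernel performs the swap inside Run(), so it happens on every inference, and any early return between the swap and the restore (for example the Prepare failure path above) would leave the tensor pointing at the transient buffer. A hedged sketch of a scope guard that would make the restore unconditional (illustrative only, not part of the lite runtime):

#include <functional>
#include <utility>

// Runs the stored callback when the enclosing scope exits, even on early return.
class ScopeGuard {
 public:
  explicit ScopeGuard(std::function<void()> f) : f_(std::move(f)) {}
  ~ScopeGuard() {
    if (f_) f_();
  }
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;

 private:
  std::function<void()> f_;
};

// Usage inside a Run()-style body:
//   void *restore_data = tensor->Data();
//   RestoreMulWeight(tensor);
//   ScopeGuard guard([&] { tensor->FreeData(); tensor->SetData(restore_data); });
//   ... early returns are now safe ...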
@@ -37,6 +37,7 @@ using mindspore::schema::PrimitiveType_Rsqrt;
using mindspore::schema::PrimitiveType_Sin;
using mindspore::schema::PrimitiveType_Sqrt;
using mindspore::schema::PrimitiveType_Square;
static constexpr int kPerTensor = 1;

namespace mindspore::kernel {
class ArithmeticSelfCPUKernel : public LiteKernel {
@@ -169,13 +169,63 @@ int ScaleCPUKernel::Run() {
  }
  return RET_OK;
}

int RestoreScaleWeight(lite::tensor::Tensor *input_tensor) {
  // Same affine dequantization as RestoreFullconnectWeight above.
  MS_ASSERT(input_tensor != nullptr);
  if (input_tensor->data_type() != kNumberTypeUInt8) {
    MS_LOG(ERROR) << "scale weight input type error: " << input_tensor->data_type();
    return RET_ERROR;
  }
  if (input_tensor->GetQuantParams().empty()) {
    MS_LOG(ERROR) << "no quant param";
    return RET_ERROR;
  }
  const auto *quant_data = static_cast<const uint8_t *>(input_tensor->Data());
  auto *dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
  if (dequant_data == nullptr) {
    MS_LOG(ERROR) << "malloc failed";
    return RET_ERROR;
  }

  if (input_tensor->GetQuantParams().size() != kPerTensor) {
    size_t channels = static_cast<size_t>(input_tensor->Batch());
    if (input_tensor->GetQuantParams().size() != channels) {
      MS_LOG(ERROR) << "quant param count " << input_tensor->GetQuantParams().size()
                    << " not equal to channel num " << channels;
      free(dequant_data);
      return RET_ERROR;
    }
    size_t per_channel_size = input_tensor->DataSize() / channels;
    auto quant_param = input_tensor->GetQuantParams();
    for (size_t i = 0; i < channels; i++) {
      auto param = quant_param.at(i);
      auto scale = param.scale;
      auto zero_point = param.zeroPoint;
      for (size_t j = 0; j < per_channel_size; j++) {
        dequant_data[per_channel_size * i + j] =
          static_cast<float>((quant_data[per_channel_size * i + j] - zero_point) * scale);
      }
    }
  } else {
    auto quant_param = input_tensor->GetQuantParams();
    auto param = quant_param.front();
    auto scale = param.scale;
    auto zero_point = param.zeroPoint;
    for (size_t j = 0; j < input_tensor->DataSize(); j++) {
      dequant_data[j] = static_cast<float>((quant_data[j] - zero_point) * scale);
    }
  }
  input_tensor->SetData(dequant_data);
  return RET_OK;
}

kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
                                              const std::vector<lite::tensor::Tensor *> &outputs,
                                              OpParameter *opParameter, const lite::Context *ctx,
                                              const kernel::KernelKey &desc,
                                              const mindspore::lite::PrimitiveC *primitive) {
  MS_ASSERT(desc.type == schema::PrimitiveType_Scale);
  auto *weight_tensor = inputs.at(kWeightIndex);
  auto *restore_data = weight_tensor->Data();
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    RestoreScaleWeight(inputs.at(kWeightIndex));
  }
  if (opParameter == nullptr) {
    MS_LOG(ERROR) << "opParameter is nullptr";
    return nullptr;

@@ -193,6 +243,10 @@ kernel::LiteKernel *CpuScaleFp32KernelCreator(const std::vector<lite::tensor::Te
    delete kernel;
    return nullptr;
  }
  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
    weight_tensor->FreeData();
    weight_tensor->SetData(restore_data);
  }
  return kernel;
}
@@ -21,6 +21,7 @@
#include "src/lite_kernel.h"
#include "nnacl/fp32/scale.h"

static constexpr int kPerTensor = 1;
namespace mindspore::kernel {

class ScaleCPUKernel : public LiteKernel {
@@ -35,6 +35,7 @@ Flags::Flags() {
  AddFlag(&Flags::inputInferenceTypeIn, "inputInferenceType", "Input inference data type. FLOAT | INT8", "FLOAT");
  AddFlag(&Flags::stdDev, "stdDev", "Standard deviation value for aware-quantization", "128");
  AddFlag(&Flags::mean, "mean", "Mean value for aware-quantization", "-0.5");
  AddFlag(&Flags::bitNum, "bitNum", "Weight quantization bitNum", "8");
  AddFlag(&Flags::quantSize, "quantSize", "Weight quantization size threshold", "0");
  AddFlag(&Flags::convWeightQuantChannelThreshold, "convWeightQuantChannelThreshold",
          "Channel number threshold for convolution weight quantization", "16");
@@ -49,13 +49,13 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
    return RET_ERROR;
  }

-  auto inputNode = cnode->input(2);
-  if (!inputNode->isa<Parameter>()) {
+  auto input_node = cnode->input(2);
+  if (!input_node->isa<Parameter>()) {
    return RET_ERROR;
  }

-  auto paramNode = inputNode->cast<ParameterPtr>();
-  if (!paramNode->has_default()) {
+  auto param_node = input_node->cast<ParameterPtr>();
+  if (!param_node->has_default()) {
    return RET_ERROR;
  }

@@ -65,14 +65,26 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
  auto op_type = (schema::PrimitiveType)primitive_c->Type();
  bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false;

-  ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
+  ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
  auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0,
                                     bitNum, true, depthwise);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "QuantFilter failed: " << status;
    return status;
  }
  // set dtype
  param_value->set_tensor_type(kNumberTypeUInt8);
  auto abstractBase = param_node->abstract();
  if (abstractBase == nullptr) {
    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
    return RET_ERROR;
  }
  auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
  abstractTensor->element()->set_type(TypeIdToType(kNumberTypeUInt8));
  primitive_c->SetQuantType(schema::QuantType_WeightQuant);
}

@@ -86,14 +98,14 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
  }

  ParamValueLitePtr param_value = nullptr;
+  ParameterPtr param_node = nullptr;
  for (size_t i = 1; i < node->size(); i++) {
    auto inputNode = node->input(i);
    if (inputNode->isa<Parameter>() == true) {
-      auto paramNode = inputNode->cast<ParameterPtr>();
-      if ((paramNode != nullptr) && (paramNode->has_default() == true)) {
-        param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
+      param_node = inputNode->cast<ParameterPtr>();
+      if ((param_node != nullptr) && (param_node->has_default() == true)) {
+        param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
        if ((param_value == nullptr) || (param_value->tensor_size() == 0)
            || (param_value->tensor_shape().size() != 4)
            || (param_value->tensor_addr() == nullptr)
            || (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
          param_value = nullptr;

@@ -115,12 +127,26 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
      return RET_ERROR;
    }

    std::vector<schema::QuantParamT> quant_params;
    primitive_c->AddInputQuantParam(quant_params);
    auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantFilter failed: " << status;
      return status;
    }
    param_value->set_tensor_type(kNumberTypeUInt8);
    // set dtype
    auto abstractBase = param_node->abstract();
    if (abstractBase == nullptr) {
      MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
      return RET_ERROR;
    }
    if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
      MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
      return RET_ERROR;
    }
    auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
    abstractTensor->element()->set_type(TypeIdToType(kNumberTypeUInt8));
    primitive_c->SetQuantType(schema::QuantType_WeightQuant);
  }
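Both QuantFilter<uint8_t> calls above pass quant_max = 255 and quant_min = 0, i.e. asymmetric 8-bit quantization. A minimal sketch of how such parameters are typically derived and applied (hand-rolled here under that assumption; the actual logic lives in the shared quantizer utilities):

#include <algorithm>
#include <cmath>
#include <cstdint>

struct QuantParam {
  float scale;
  int32_t zeroPoint;
};

// Map [min, max] (widened to contain 0) onto [quant_min, quant_max].
QuantParam ComputeQuantParam(float min, float max, int quant_min = 0, int quant_max = 255) {
  min = std::min(min, 0.0f);  // keep 0.0f exactly representable
  max = std::max(max, 0.0f);
  float scale = (max - min) / static_cast<float>(quant_max - quant_min);
  if (scale == 0.0f) scale = 1.0f;  // all-zero weights: avoid dividing by zero
  int32_t zero_point = quant_min + static_cast<int32_t>(std::round(-min / scale));
  return {scale, zero_point};
}

uint8_t Quantize(float value, const QuantParam &p, int quant_min = 0, int quant_max = 255) {
  int32_t q = p.zeroPoint + static_cast<int32_t>(std::round(value / p.scale));
  return static_cast<uint8_t>(std::clamp(q, quant_min, quant_max));
}

Dequantizing with the same scale and zero point, as the Restore*Weight helpers do, recovers each weight to within half a quantization step, which is what lets the float kernels run on weight-quantized models.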