!17432 [MS][LITE][CPU] lsq quant_aware_training method test on mslite
From: @zhang__sss Reviewed-by: @zlq2020,@hangangqiang Signed-off-by: @hangangqiang
Commit: b96dac51bd
@@ -23,8 +23,11 @@ from ..._checkparam import Validator
 from ...common import Tensor
 from ...common import dtype as mstype
 from ...common.api import _executor
+from ...common.parameter import Parameter
+from ...nn import Cell
 from ...nn.layer import quant
 from ...ops import operations as P
+from ...ops import functional as F
 from ...ops.operations import _inner_ops as inner
 from ..quant import quant_utils
 from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell
@@ -32,6 +35,151 @@ from ..quant.qat import _AddFakeQuantInput, _AddFakeQuantAfterSubCell
 
 __all__ = ["ExportToQuantInferNetwork"]
 
+
+class QuantBlock(Cell):
+    r"""
+    A quant block of Conv/Dense, activation layer for Ascend deploy.
+
+    Calculate Conv or Dense in Int8, with Quant and DeQuant.
+
+    Note:
+        This block is only for deploy, and not trainable.
+
+    Args:
+        core_op (Primitive): The int8 Conv or Dense operation to run.
+        weight (Tensor): The quantized weight of the layer.
+        quant_op (Primitive): The operation that quantizes the float input to int8.
+        dequant_op (Primitive): The operation that dequantizes the int32 output to float.
+        dequant_scale (Tensor): The scale used by `dequant_op`.
+        bias (Tensor): The bias of the layer. Default: None.
+        activation (Cell): The activation applied to the output of the layer, e.g. ReLU. Default: None.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor of shape :math:`(N, out\_channels)`.
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 quant_op,
+                 dequant_op,
+                 dequant_scale,
+                 bias=None,
+                 activation=None):
+        super(QuantBlock, self).__init__()
+        self.core_op = core_op
+        self.weight = weight
+        self.quant = quant_op
+        self.dequant = dequant_op
+        self.dequant_scale = dequant_scale
+        self.bias = bias
+        self.has_bias = bias is not None
+        self.activation = activation
+        self.has_act = activation is not None
+        self.bias_add = P.BiasAdd()
+        self.sub = P.Sub()
+        self.weight_offset = Parameter(np.zeros(shape=weight.shape, dtype=np.int8), name='weight_offset')
+
+    def construct(self, x):
+        x = self.quant(x)
+        if self.has_bias:
+            weight = self.sub(self.weight, self.weight_offset)
+            x = self.core_op(x, weight)
+            x = self.bias_add(x, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
+        x = self.dequant(x, self.dequant_scale)
+        x = F.cast(x, mstype.float32)
+        if self.has_act:
+            x = self.activation(x)
+        return x
+
+    def extend_repr(self):
+        s = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
+        if self.has_bias:
+            s += f', bias=shape[{self.bias.shape}]'
+        if self.has_act:
+            s += f', activation={self.activation}'
+        s += f', dequant={self.dequant}'
+        return s
+
+
+class QuantMindirBlock(Cell):
+    """A quant binary block of Conv/Dense, activation layer for export MINDIR model.
+
+    Args:
+        core_op (Cell): The operation cell.
+        weight (Tensor): The weight of the cell.
+        bias (Tensor): The bias of the cell. Default: None.
+        activation (Cell): The activation applied to the output of the layer, e.g. ReLU. Default: None.
+        param_dict (dict): The information of the cell. Default: None.
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 bias=None,
+                 activation=None,
+                 param_dict=None):
+
+        super(QuantMindirBlock, self).__init__()
+        self.core_op = core_op
+        if activation is not None:
+            self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
+        self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
+        self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
+        if param_dict["output_maxq"] is not None:
+            self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
+            self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
+        self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
+        if hasattr(core_op, 'pad_mode'):
+            self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
+        self.core_op.add_prim_attr("act_num_bits", Tensor(8))
+        self.core_op.add_prim_attr("weight_num_bits", Tensor(param_dict["weight_num_bits"]))
+        self.core_op.add_prim_attr("weight_narrow_range", Tensor(param_dict["weight_narrow_range"]))
+        if param_dict["input_narrow_range"] is not None:
+            self.core_op.add_prim_attr("input_narrow_range", Tensor(param_dict["input_narrow_range"]))
+        if param_dict["output_narrow_range"] is not None:
+            self.core_op.add_prim_attr("output_narrow_range", Tensor(param_dict["output_narrow_range"]))
+        if param_dict["input_maxq"] == 'None':
+            self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
+            self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
+        elif param_dict["input_maxq"] is not None:
+            self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
+            self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
+
+        self.weight = weight
+        self.bias = bias
+        self.has_bias = bias is not None
+        self.activation = activation
+        self.has_act = activation is not None
+        self.bias_add = P.BiasAdd()
+
+    def construct(self, x):
+        if self.has_bias:
+            x = self.core_op(x, self.weight)
+            x = self.bias_add(x, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
+        if self.has_act:
+            x = self.activation(x)
+        return x
+
+    def extend_repr(self):
+        s = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
+        if self.has_bias:
+            s += f', bias=shape[{self.bias.shape}]'
+        if self.has_act:
+            s += f', activation={self.activation}'
+        return s
+
+
 class ExportToQuantInferNetwork:
     """
     Convert quantization aware network to infer network.
@@ -92,15 +240,20 @@ class ExportToQuantInferNetwork:
         param_dict["output_minq"] = None
         param_dict["input_maxq"] = None
         param_dict["input_minq"] = None
+        param_dict["input_narrow_range"] = None
+        param_dict["output_narrow_range"] = None
+        param_dict["weight_narrow_range"] = cell_core.fake_quant_weight.narrow_range
         param_dict["mean"] = self.mean
         param_dict["std_dev"] = self.std_dev
         param_dict["symmetric"] = cell_core.fake_quant_weight.symmetric
+        param_dict["weight_num_bits"] = cell_core.fake_quant_weight.num_bits
 
         scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
         if fake_quant_a_out is not None:
             _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
                 quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
+            param_dict["output_narrow_range"] = fake_quant_a_out.narrow_range
 
         info = self.quant_info_table.get(w_minq_name, None)
         if not info:
@@ -120,6 +273,7 @@ class ExportToQuantInferNetwork:
             scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
                 quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_in, np_type)
+            param_dict["input_narrow_range"] = fake_quant_a_in.narrow_range
         else:
             # skip quant layer
             scale_a_in, zp_a_in = 1.0, 0.0
@@ -175,9 +329,9 @@ class ExportToQuantInferNetwork:
         if bias_b is not None:
             bias_b = Tensor(bias_b, mstype.float32)
         if self.is_mindir:
-            block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
+            block = QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
         else:
-            block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
+            block = QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
         return block
 
     def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
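For reference, `QuantBlock.construct` above implements the standard quantized inference pipeline: quantize the float input, run the core Conv/Dense in integer arithmetic, add the bias, then dequantize back to float32. A minimal NumPy sketch of that pipeline follows; it is an illustration, not the MindSpore `inner.Quant`/`inner.Dequant` ops, and a Dense-style matmul stands in for the core op.

    import numpy as np

    def quant_block_forward(x, weight_q, bias_q, in_scale, in_zp, dequant_scale, activation=None):
        # Quant: float input -> int8 (mirrors self.quant(x))
        x_q = np.clip(np.round(x / in_scale + in_zp), -128, 127).astype(np.int8)
        # Core op on integers, accumulating in int32 (mirrors self.core_op)
        acc = x_q.astype(np.int32) @ weight_q.astype(np.int32).T
        if bias_q is not None:            # bias assumed pre-quantized to int32
            acc = acc + bias_q            # mirrors self.bias_add
        y = acc.astype(np.float32) * dequant_scale   # mirrors self.dequant + F.cast
        return activation(y) if activation is not None else y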
@@ -138,6 +138,10 @@ def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
     """Calculate quantization params (scale, zero point, max and min) from `FakeQuantWithMinMaxObserver`."""
     minq = cell.minq.data.asnumpy()
     maxq = cell.maxq.data.asnumpy()
+    # make sure maxq > 0 and minq <= 0
+    if cell.mode == 'LEARNED_SCALE':
+        maxq = np.abs(maxq)
+        minq = -np.abs(minq)
 
     scale, zp = cal_quantization_params(
         minq, maxq, data_type,
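The new `LEARNED_SCALE` branch normalizes the learned bounds so that `maxq >= 0` and `minq <= 0` before the scale/zero-point computation, as the learned-scale (LSQ) formulation expects. `cal_quantization_params` itself is not part of this diff; the following is a sketch of the standard asymmetric formula it presumably follows, with hypothetical argument names.

    import numpy as np

    def cal_quantization_params_sketch(minq, maxq, num_bits=8, narrow_range=False):
        # Integer range the float interval [minq, maxq] is mapped onto
        quant_max = (1 << (num_bits - 1)) - 1
        quant_min = -(1 << (num_bits - 1)) + (1 if narrow_range else 0)
        scale = (maxq - minq) / (quant_max - quant_min)
        zp = np.round(quant_min - minq / scale)   # integer that represents 0.0
        return scale, zp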
@@ -69,14 +69,10 @@ T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
   const auto zeroPoint = quantParam->zeroPoint;
   const auto numBit = quantParam->numBits;
   const auto narrowRange = quantParam->narrowRange;
-  double maxLimitTemp = static_cast<float>((1 << (unsigned int)numBit) - 1);
-  const double maxLimit = static_cast<float>(maxLimitTemp - zeroPoint + std::numeric_limits<T>::min()) * scale;
-  double minLimit;
-  if (narrowRange) {
-    minLimit = static_cast<float>(std::numeric_limits<T>::min() + 1 - zeroPoint) * scale;
-  } else {
-    minLimit = static_cast<float>(std::numeric_limits<T>::min() - zeroPoint) * scale;
-  }
+  const int32_t quantMax = (1 << (unsigned int)(numBit - 1)) - 1;
+  const int32_t quantMin = -1 * (1 << (unsigned int)(numBit - 1)) + (narrowRange ? 1 : 0);
+  const double maxLimit = static_cast<float>(quantMax - zeroPoint) * scale;
+  const double minLimit = static_cast<float>(quantMin - zeroPoint) * scale;
 
   return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
     double tmp;
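The old code derived its clipping limits from `std::numeric_limits<T>` and so was pinned to the int8 range regardless of `numBit`; the replacement derives `quantMax`/`quantMin` from `numBit` itself, which matters once weights can be quantized below 8 bits. A quick Python check of the new formula:

    def quant_range(num_bits, narrow_range=False):
        quant_max = (1 << (num_bits - 1)) - 1
        quant_min = -(1 << (num_bits - 1)) + (1 if narrow_range else 0)
        return quant_min, quant_max

    assert quant_range(8) == (-128, 127)   # matches the old int8 limits
    assert quant_range(8, narrow_range=True) == (-127, 127)
    assert quant_range(4) == (-8, 7)       # old code would still have clipped at int8 limits here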
@@ -118,7 +118,7 @@ static STATUS CompressTensor(schema::TensorT *tensor_input, const std::unique_pt
     }
   }
 
-  if (bit_num != 8 && bit_num != 16 && !repetition_packed) {
+  if (bit_num != 8 && bit_num != 16 && !repetition_packed && dst_node->quantType != schema::QuantType_QUANT_NONE) {
     auto status = DoBitPack(bit_num, tensor_input);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "do bit pack failed. " << status;
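`DoBitPack` is not shown in this diff; it stores each sub-byte quantized value in `bit_num` bits, so for example a 4-bit tensor occupies half the bytes of int8 storage. A hypothetical Python sketch of such packing for non-negative values:

    def bit_pack(values, bit_num):
        # Pack each value into bit_num bits of a contiguous byte stream.
        buf, nbits, out = 0, 0, bytearray()
        for v in values:
            buf = (buf << bit_num) | (v & ((1 << bit_num) - 1))
            nbits += bit_num
            while nbits >= 8:
                nbits -= 8
                out.append((buf >> nbits) & 0xFF)
        if nbits:                                  # flush the final partial byte
            out.append((buf << (8 - nbits)) & 0xFF)
        return bytes(out)

    assert bit_pack([1, 2, 3, 4], 4) == bytes([0x12, 0x34])   # four 4-bit values -> two bytes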
@@ -27,7 +27,8 @@
 namespace mindspore {
 namespace lite {
 namespace {
-int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t numbits) {
+int ConvertInputQuantParam(const PrimitivePtr &prim, bool input_narrow_range, bool weight_narrow_range,
+                           int32_t act_numbits, int32_t weight_numbits) {
   auto quant_param_holder = prim->GetAttr("quant_params")->cast<lite::QuantParamHolderPtr>();
   std::vector<schema::QuantParamT> quants;
   schema::QuantParamT quant_param;
@@ -40,8 +41,8 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
   auto *max_buf = static_cast<float *>(input_max_ptr->data_c());
   quant_param.min = *min_buf;
   quant_param.max = *max_buf;
-  auto ret =
-    lite::quant::CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, narrow_range, numbits);
+  auto ret = lite::quant::CalQuantizationParams(&quant_param, quant_param.min, quant_param.max, input_narrow_range,
+                                                act_numbits);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Can't calculate quant parameters";
     return ret;
@@ -64,8 +65,8 @@ int ConvertInputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
   schema::QuantParamT tmp_quant_param;
   tmp_quant_param.min = *min_buf;
   tmp_quant_param.max = *max_buf;
-  auto ret =
-    lite::quant::CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max, true, numbits);
+  auto ret = lite::quant::CalQuantizationParams(&tmp_quant_param, tmp_quant_param.min, tmp_quant_param.max,
+                                                weight_narrow_range, weight_numbits);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Can't calculate quant parameters";
     return ret;
@@ -104,39 +105,77 @@ int ConvertOutputQuantParam(const PrimitivePtr &prim, bool narrow_range, int32_t
   return lite::RET_OK;
 }
 
-int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
-  auto narrow_range = prim->GetAttr("narrow_range");
-  bool narrow_range_param = false;
+int GetNarrowRange(const PrimitivePtr &prim, const std::string &narrow_range_str, bool *narrow_range_param) {
+  auto narrow_range = prim->GetAttr(narrow_range_str);
   if (narrow_range != nullptr) {
     if (utils::isa<tensor::TensorPtr>(narrow_range)) {
       auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>();
-      narrow_range_param = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
+      *narrow_range_param = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
     } else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) {
-      narrow_range_param = GetValue<bool>(narrow_range);
+      *narrow_range_param = GetValue<bool>(narrow_range);
     } else {
       MS_LOG(ERROR) << "valueptr is invalid.";
       return lite::RET_ERROR;
     }
   }
-  auto num_bits = prim->GetAttr("num_bits");
-  int32_t num_bits_param = 8;
+  return lite::RET_OK;
+}
+
+int GetNumBits(const PrimitivePtr &prim, const std::string &num_bits_str, int *num_bits_param) {
+  auto num_bits = prim->GetAttr(num_bits_str);
   if (num_bits != nullptr) {
     if (utils::isa<tensor::TensorPtr>(num_bits)) {
       auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>();
-      num_bits_param = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
+      *num_bits_param = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
     } else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) {
-      num_bits_param = GetValue<int64_t>(num_bits);
+      *num_bits_param = GetValue<int64_t>(num_bits);
     } else {
       MS_LOG(ERROR) << "valueptr is invalid.";
       return lite::RET_ERROR;
     }
   }
-  auto status = ConvertInputQuantParam(prim, narrow_range_param, num_bits_param);
+  return lite::RET_OK;
+}
+
+int ConvertQuantParam(const PrimitivePtr &prim, const std::vector<AnfNodePtr> &inputs) {
+  bool input_narrow_range_param = false;
+  auto status = GetNarrowRange(prim, "input_narrow_range", &input_narrow_range_param);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "get input narrow range failed.";
+    return status;
+  }
+  bool weight_narrow_range_param = true;
+  status = GetNarrowRange(prim, "weight_narrow_range", &weight_narrow_range_param);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "get weight narrow range failed.";
+    return status;
+  }
+  bool output_narrow_range_param = false;
+  status = GetNarrowRange(prim, "output_narrow_range", &output_narrow_range_param);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "get output narrow range failed.";
+    return status;
+  }
+
+  int32_t act_num_bits_param = 8;
+  status = GetNumBits(prim, "act_num_bits", &act_num_bits_param);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "get activation num_bits failed.";
+    return status;
+  }
+  int32_t weight_num_bits_param = 8;
+  status = GetNumBits(prim, "weight_num_bits", &weight_num_bits_param);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "get weight num_bits failed.";
+    return status;
+  }
+  status = ConvertInputQuantParam(prim, input_narrow_range_param, weight_narrow_range_param, act_num_bits_param,
+                                  weight_num_bits_param);
   if (status != lite::RET_OK) {
     MS_LOG(ERROR) << "compute int quant param failed.";
     return status;
   }
-  status = ConvertOutputQuantParam(prim, narrow_range_param, num_bits_param);
+  status = ConvertOutputQuantParam(prim, output_narrow_range_param, act_num_bits_param);
   if (status != lite::RET_OK) {
     MS_LOG(ERROR) << "compute output quant param failed.";
     return status;
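The refactor splits the old monolithic attribute handling into `GetNarrowRange`/`GetNumBits`, so input, weight, and output quantization can carry independent settings; note the defaults: weight narrow range defaults to true, input/output to false, and both bit widths to 8. A Python sketch of the same lookup-with-default pattern, with illustrative names:

    def get_attr(attrs, name, default):
        # Fall back to the caller's default when the attribute is absent.
        value = attrs.get(name)
        return default if value is None else value

    attrs = {"weight_num_bits": 4}   # e.g. an LSQ network with 4-bit weights
    input_narrow_range = get_attr(attrs, "input_narrow_range", False)
    weight_narrow_range = get_attr(attrs, "weight_narrow_range", True)
    act_num_bits = get_attr(attrs, "act_num_bits", 8)
    weight_num_bits = get_attr(attrs, "weight_num_bits", 8)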
@@ -175,6 +175,11 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
         MS_LOG(ERROR) << "compute tensor to int8 prechannel failed.";
         return RET_ERROR;
       }
+      int bit_num = tensor->quantParams.front()->numBits;
+      if (DoBitPack(bit_num, tensor.get()) != RET_OK) {
+        MS_LOG(ERROR) << "bit pack failed.";
+        return RET_ERROR;
+      }
       index++;
       continue;
     }
@@ -183,6 +188,11 @@ STATUS TensorQuantPass::Run(schema::MetaGraphT *graph) {
     if (quantParam->dstDtype == TypeId::kNumberTypeInt8 || quantParam->dstDtype == TypeId::kNumberTypeUInt8 ||
         quantParam->dstDtype == TypeId::kNumberTypeFloat32 || quantParam->dstDtype == TypeId::kNumberTypeFloat) {
       status = ComputeDataToInt8(tensor, index);
+      int bit_num = tensor->quantParams.front()->numBits;
+      if (DoBitPack(bit_num, tensor.get()) != RET_OK) {
+        MS_LOG(ERROR) << "bit pack failed.";
+        return RET_ERROR;
+      }
     } else if (quantParam->dstDtype == TypeId::kNumberTypeInt32) {
       // quant bias data
       status = ComputeDataToInt32(tensor);
@@ -297,8 +297,8 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl
     return RET_OK;
   }
 
-  const int8_t quantMin = std::numeric_limits<int8_t>::min() + (narrowRange ? 1 : 0);
-  const int8_t quantMax = std::numeric_limits<int8_t>::max();
+  const int8_t quantMax = (1 << (unsigned int)(numBits - 1)) - 1;
+  const int8_t quantMin = -1 * (1 << (unsigned int)(numBits - 1)) + (narrowRange ? 1 : 0);
   auto quantMinFloat = static_cast<double>(quantMin);
   auto quantMaxFloat = static_cast<double>(quantMax);
   if (fabs(quantMaxFloat - quantMinFloat) <= 0.0f) {
@@ -20,7 +20,6 @@ import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.ops.primitive import Primitive
 from mindspore.ops import operations as P
-from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
@@ -1638,144 +1637,3 @@ class MulQuant(Cell):
         x = self.mul(x1, x2)
         x = self.fake_quant_act(x)
         return x
-
-
-class QuantBlock(Cell):
-    r"""
-    A quant block of Conv/Dense, activation layer for Ascend deploy.
-
-    Calculate Conv or Dense in Int8, with Quant and DeQuant.
-
-    Notes:
-        This block is only for deploy, and not trainable.
-
-    Args:
-        in_channels (int): The number of channels in the input space.
-        out_channels (int): The number of channels in the output space.
-        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
-            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
-        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
-            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
-        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
-        activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
-        activation (string): Specifies activation type. The optional values are as following:
-            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
-            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
-
-    Inputs:
-        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
-
-    Outputs:
-        Tensor of shape :math:`(N, out\_channels)`.
-    """
-
-    def __init__(self,
-                 core_op,
-                 weight,
-                 quant_op,
-                 dequant_op,
-                 dequant_scale,
-                 bias=None,
-                 activation=None):
-        super(QuantBlock, self).__init__()
-        self.core_op = core_op
-        self.weight = weight
-        self.quant = quant_op
-        self.dequant = dequant_op
-        self.dequant_scale = dequant_scale
-        self.bias = bias
-        self.has_bias = bias is not None
-        self.activation = activation
-        self.has_act = activation is not None
-        self.bias_add = P.BiasAdd()
-        self.sub = P.Sub()
-        self.weight_offset = Parameter(np.zeros(shape=weight.shape, dtype=np.int8), name='weight_offset')
-
-    def construct(self, x):
-        x = self.quant(x)
-        if self.has_bias:
-            weight = self.sub(self.weight, self.weight_offset)
-            x = self.core_op(x, weight)
-            x = self.bias_add(x, self.bias)
-        else:
-            x = self.core_op(x, self.weight)
-        x = self.dequant(x, self.dequant_scale)
-        x = F.cast(x, mstype.float32)
-        if self.has_act:
-            x = self.activation(x)
-        return x
-
-    def extend_repr(self):
-        s = f'quant={self.quant}, core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
-        if self.has_bias:
-            s += f', bias=shape[{self.bias.shape}]'
-        if self.has_act:
-            s += f', activation={self.activation}'
-        s += f', dequant={self.dequant}'
-        return s
-
-
-class QuantMindirBlock(Cell):
-    """A quant binary block of Conv/Dense, activation layer for export MINDIR model.
-
-    Args:
-        core_op (Cell): The operation cell.
-        weight (Tensor): The weight of the cell.
-        bias (Tensor): The bias of the cell. Default: None.
-        activation (str): The regularization function applied to the output of the layer, eg. 'relu'. Default: None.
-        param_dict (dict): The information of the cell.
-    """
-
-    def __init__(self,
-                 core_op,
-                 weight,
-                 bias=None,
-                 activation=None,
-                 param_dict=None):
-
-        super(QuantMindirBlock, self).__init__()
-        self.core_op = core_op
-        if activation is not None:
-            self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
-        self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
-        self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
-        if param_dict["output_maxq"] is not None:
-            self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
-            self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
-        self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
-        if hasattr(core_op, 'pad_mode'):
-            self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
-        self.core_op.add_prim_attr("num_bits", Tensor(8))
-        self.core_op.add_prim_attr("narrow_range", Tensor(False))
-        if param_dict["input_maxq"] == 'None':
-            self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
-            self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
-        elif param_dict["input_maxq"] is not None:
-            self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
-            self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
-
-        self.weight = weight
-        self.bias = bias
-        self.has_bias = bias is not None
-        self.activation = activation
-        self.has_act = activation is not None
-        self.bias_add = P.BiasAdd()
-
-    def construct(self, x):
-        if self.has_bias:
-            x = self.core_op(x, self.weight)
-            x = self.bias_add(x, self.bias)
-        else:
-            x = self.core_op(x, self.weight)
-        if self.has_act:
-            x = self.activation(x)
-        return x
-
-    def extend_repr(self):
-        s = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
-        if self.has_bias:
-            s += f', bias=shape[{self.bias.shape}]'
-        if self.has_act:
-            s += f', activation={self.activation}'
-        return s
@@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
 
 # Set include directory and library directory
+set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
 set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
 set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
 
 # Header path
 include_directories(${ACL_LIB_DIR}/include/)
+include_directories(${FWKACL_LIB_DIR}/include/)
 include_directories(${ATLAS_ACL_LIB_DIR}/include/)
 include_directories(${PROJECT_SRC_ROOT}/../inc)
 
 # add host lib path
-link_directories(${ACL_LIB_DIR})
-find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
+link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
+find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
 
 add_executable(main utils.cpp
         SampleProcess.cpp
@@ -36,7 +36,7 @@ if [ $# == 4 ]; then
     device_id=$4
 fi
 
-echo "mindir name: "$model
+echo "air name: "$model
 echo "dataset path: "$data_path
 echo "label path: "$label_path
 echo "device id: "$device_id
@@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
 
 # Set include directory and library directory
+set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
 set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
 set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
 
 # Header path
 include_directories(${ACL_LIB_DIR}/include/)
+include_directories(${FWKACL_LIB_DIR}/include/)
 include_directories(${ATLAS_ACL_LIB_DIR}/include/)
 include_directories(${PROJECT_SRC_ROOT}/../inc)
 
 # add host lib path
-link_directories(${ACL_LIB_DIR})
-find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
+link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
+find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
 
 add_executable(main utils.cpp
         SampleProcess.cpp
@@ -36,7 +36,7 @@ if [ $# == 4 ]; then
    device_id=$4
 fi
 
-echo "mindir name: "$model
+echo "air name: "$model
 echo "dataset path: "$data_path
 echo "label path: "$label_path
 echo "device id: "$device_id
@@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
 
 # Set include directory and library directory
+set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
 set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
 set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
 
 # Header path
 include_directories(${ACL_LIB_DIR}/include/)
+include_directories(${FWKACL_LIB_DIR}/include/)
 include_directories(${ATLAS_ACL_LIB_DIR}/include/)
 include_directories(${PROJECT_SRC_ROOT}/../inc)
 
 # add host lib path
-link_directories(${ACL_LIB_DIR})
-find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
+link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
+find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
 
 add_executable(main utils.cpp
         SampleProcess.cpp
@@ -36,7 +36,7 @@ if [ $# == 4 ]; then
    device_id=$4
 fi
 
-echo "mindir name: "$model
+echo "air name: "$model
 echo "dataset path: "$data_path
 echo "label path: "$label_path
 echo "device id: "$device_id
@@ -21,17 +21,19 @@ set(PROJECT_SRC_ROOT ${CMAKE_CURRENT_LIST_DIR}/)
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${PROJECT_SRC_ROOT}/out)
 
 # Set include directory and library directory
+set(FWKACL_LIB_DIR $ENV{ASCEND_HOME}/fwkacllib)
 set(ACL_LIB_DIR $ENV{ASCEND_HOME}/acllib)
 set(ATLAS_ACL_LIB_DIR $ENV{ASCEND_HOME}/ascend-toolkit/latest/acllib)
 
 # Header path
 include_directories(${ACL_LIB_DIR}/include/)
+include_directories(${FWKACL_LIB_DIR}/include/)
 include_directories(${ATLAS_ACL_LIB_DIR}/include/)
 include_directories(${PROJECT_SRC_ROOT}/../inc)
 
 # add host lib path
-link_directories(${ACL_LIB_DIR})
-find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
+link_directories(${ACL_LIB_DIR} ${FWKACL_LIB_DIR})
+find_library(acl libascendcl.so ${ACL_LIB_DIR}/lib64 ${FWKACL_LIB_DIR}/lib64 ${ATLAS_ACL_LIB_DIR}/lib64)
 
 add_executable(main utils.cpp
         SampleProcess.cpp
@@ -38,7 +38,7 @@ if [ $# == 6 ]; then
    device_id=$6
 fi
 
-echo "mindir name: "$model
+echo "air name: "$model
 echo "dataset path: "$data_path
 echo "annotation path: "$anno_path
 echo "image shape path: "$image_shape_path