!4658 fix quant export bug
Merge pull request !4658 from chengxb7532/master
Commit ff2851a25c
@@ -607,7 +607,6 @@ class Conv2dBnWithoutFoldQuant(Cell):
         group (int): Split filter into groups, `in_channels` and `out_channels` should be
             divisible by the number of groups. Default: 1.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
-        has_bn (bool): Specifies to used batchnorm or not. Default: False.
         eps (float): Parameters for BatchNormal. Default: 1e-5.
         momentum (float): Parameters for BatchNormal op. Default: 0.997.
         weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
@@ -641,7 +640,6 @@ class Conv2dBnWithoutFoldQuant(Cell):
                  dilation=1,
                  group=1,
                  has_bias=False,
-                 has_bn=True,
                  eps=1e-5,
                  momentum=0.997,
                  weight_init='normal',
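With `has_bn` removed from both the docstring and the constructor signature, the layer always carries a batchnorm and callers can no longer toggle it off. A minimal construction sketch, assuming the layer is exported as `mindspore.nn.Conv2dBnWithoutFoldQuant`; the argument values are illustrative, and only the absence of `has_bn` reflects this commit:

from mindspore import nn

# Illustrative values; passing has_bn=... would now raise a TypeError.
conv = nn.Conv2dBnWithoutFoldQuant(in_channels=3,
                                   out_channels=16,
                                   kernel_size=3,
                                   eps=1e-5,
                                   momentum=0.997)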
@@ -693,17 +691,14 @@ class Conv2dBnWithoutFoldQuant(Cell):
                                                    symmetric=symmetric,
                                                    narrow_range=narrow_range,
                                                    quant_delay=quant_delay)
-        self.has_bn = validator.check_bool("has_bn", has_bn)
-        if has_bn:
-            self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)
+        self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum)

     def construct(self, x):
         weight = self.fake_quant_weight(self.weight)
         out = self.conv(x, weight)
         if self.has_bias:
             out = self.bias_add(out, self.bias)
-        if self.has_bn:
-            out = self.batchnorm(out)
+        out = self.batchnorm(out)
         return out

     def extend_repr(self):
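The constructor now builds `self.batchnorm` unconditionally, and `construct` always applies it after the (optionally biased) convolution. A standalone NumPy sketch of that order of operations; the per-tensor fake-quant and the matmul "conv" below are illustrative stand-ins, not MindSpore ops:

import numpy as np

def fake_quant(w, num_bits=8):
    # Symmetric per-tensor fake quantization: quantize, then dequantize.
    scale = np.abs(w).max() / (2 ** (num_bits - 1) - 1)
    return np.round(w / scale) * scale

def construct_sketch(x, w, bias, gamma, beta, mean, var, eps=1e-5):
    # Mirrors construct() after this commit: fake-quantize the weight,
    # "convolve" (a 1x1 conv over (channels, pixels) is a matmul),
    # add the bias if present, then apply batchnorm unconditionally.
    out = fake_quant(w) @ x
    if bias is not None:                 # `if self.has_bias`
        out = out + bias[:, None]
    sigma = np.sqrt(var + eps)
    return gamma[:, None] * (out - mean[:, None]) / sigma[:, None] + beta[:, None]

rng = np.random.default_rng(0)
y = construct_sketch(rng.normal(size=(3, 10)), rng.normal(size=(4, 3)), None,
                     np.ones(4), np.zeros(4), np.zeros(4), np.ones(4))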
@@ -208,7 +208,6 @@ class ConvertToQuantNetwork:
                                                   group=conv_inner.group,
                                                   eps=bn_inner.eps,
                                                   momentum=bn_inner.momentum,
-                                                  has_bn=True,
                                                   quant_delay=self.weight_qdelay,
                                                   per_channel=self.weight_channel,
                                                   num_bits=self.weight_bits,
@@ -378,8 +377,10 @@ class ExportToQuantInferNetwork:
         if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
             if cell_core.has_bias:
                 bias = cell_core.bias.data.asnumpy()
-        elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant)):
+        elif isinstance(cell_core, quant.Conv2dBnFoldQuant):
             weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
+        elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
+            weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)

         # apply the quant
         weight = quant_utils.weight2int(weight, scale_w, zp_w)
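This is the export bug being fixed: both BN variants were previously routed through `fold_batchnorm`, but `Conv2dBnWithoutFoldQuant` keeps its statistics inside a child `BatchNorm2d` cell (see the `cell_quant.batchnorm.*` accesses in the new helper below), so it needs its own folding path. The folding itself is the standard batchnorm rewrite; ignoring the optional conv bias for brevity, inference-mode batchnorm after a convolution collapses into the convolution's weight and bias:

y = \gamma \, \frac{Wx - \mu}{\sigma} + \beta
  = \left(\frac{\gamma}{\sigma} W\right) x + \left(\beta - \frac{\gamma \mu}{\sigma}\right),
\qquad \sigma = \sqrt{\operatorname{Var} + \epsilon}

with gamma/sigma broadcast per output channel; this is exactly the `weight * _gamma / _sigma` and `beta - gamma * mean / sigma` computed by `without_fold_batchnorm`.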
@@ -211,3 +211,42 @@ def fold_batchnorm(weight, cell_quant):
     weight = weight * _gamma / _sigma
     bias = beta - gamma * mean / sigma
     return weight, bias
+
+
+def without_fold_batchnorm(weight, cell_quant):
+    r"""
+    Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight.
+
+    Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
+
+    Args:
+        weight (numpy.ndarray): Weight of `cell_quant`.
+        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.
+
+    Returns:
+        weight (numpy.ndarray): without folded weight.
+        bias (numpy.ndarray): without folded bias.
+    """
+    variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
+    mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
+    gamma = cell_quant.batchnorm.gamma.data.asnumpy()
+    beta = cell_quant.batchnorm.beta.data.asnumpy()
+    epsilon = cell_quant.batchnorm.eps
+    sigma = np.sqrt(variance + epsilon)
+
+    if gamma.shape[0] == weight.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(weight.shape[1:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    elif gamma.shape[0] == weight.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(weight.shape[2:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(weight.shape))
+
+    weight = weight * _gamma / _sigma
+    bias = beta - gamma * mean / sigma
+    return weight, bias
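A quick self-contained NumPy check (random data, with a 1x1 convolution modeled as a matrix multiply over (channels, pixels)) that the folded weight and bias reproduce the conv-then-batchnorm output this layer computes at inference time:

import numpy as np

rng = np.random.default_rng(0)
out_ch, in_ch = 4, 3
x = rng.normal(size=(in_ch, 5))              # pixels as columns
w = rng.normal(size=(out_ch, in_ch))         # 1x1 conv == channel matmul
gamma = rng.normal(size=out_ch)
beta = rng.normal(size=out_ch)
mean = rng.normal(size=out_ch)
var = rng.uniform(0.5, 2.0, size=out_ch)
sigma = np.sqrt(var + 1e-5)

# Path 1: convolve, then inference-mode batchnorm (what construct() does).
y_bn = gamma[:, None] * (w @ x - mean[:, None]) / sigma[:, None] + beta[:, None]

# Path 2: fold batchnorm into the weight and bias, as the helper does.
w_fold = w * (gamma / sigma)[:, None]        # weight * _gamma / _sigma
b_fold = beta - gamma * mean / sigma         # beta - gamma * mean / sigma
y_fold = w_fold @ x + b_fold[:, None]

assert np.allclose(y_bn, y_fold)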