From 4a76059f2b5a48ab71afd5c74c3cd8b580ff2ed5 Mon Sep 17 00:00:00 2001 From: chengxianbin Date: Tue, 18 Aug 2020 15:13:31 +0800 Subject: [PATCH] fix quant export bug --- mindspore/nn/layer/quant.py | 9 ++----- mindspore/train/quant/quant.py | 5 ++-- mindspore/train/quant/quant_utils.py | 39 ++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 2aa075b1212..0bb69b44763 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -607,7 +607,6 @@ class Conv2dBnWithoutFoldQuant(Cell): group (int): Split filter into groups, `in_ channels` and `out_channels` should be divisible by the number of groups. Default: 1. has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. - has_bn (bool): Specifies to used batchnorm or not. Default: False. eps (float): Parameters for BatchNormal. Default: 1e-5. momentum (float): Parameters for BatchNormal op. Default: 0.997. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. 
@@ -641,7 +640,6 @@ class Conv2dBnWithoutFoldQuant(Cell): dilation=1, group=1, has_bias=False, - has_bn=True, eps=1e-5, momentum=0.997, weight_init='normal', @@ -693,17 +691,14 @@ class Conv2dBnWithoutFoldQuant(Cell): symmetric=symmetric, narrow_range=narrow_range, quant_delay=quant_delay) - self.has_bn = validator.check_bool("has_bn", has_bn) - if has_bn: - self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum) + self.batchnorm = BatchNorm2d(out_channels, eps=eps, momentum=momentum) def construct(self, x): weight = self.fake_quant_weight(self.weight) out = self.conv(x, weight) if self.has_bias: out = self.bias_add(out, self.bias) - if self.has_bn: - out = self.batchnorm(out) + out = self.batchnorm(out) return out def extend_repr(self): diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 4bd87d14801..2e9769da882 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -208,7 +208,6 @@ class ConvertToQuantNetwork: group=conv_inner.group, eps=bn_inner.eps, momentum=bn_inner.momentum, - has_bn=True, quant_delay=self.weight_qdelay, per_channel=self.weight_channel, num_bits=self.weight_bits, @@ -378,8 +377,10 @@ class ExportToQuantInferNetwork: if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)): if cell_core.has_bias: bias = cell_core.bias.data.asnumpy() - elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant)): + elif isinstance(cell_core, quant.Conv2dBnFoldQuant): weight, bias = quant_utils.fold_batchnorm(weight, cell_core) + elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant): + weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core) # apply the quant weight = quant_utils.weight2int(weight, scale_w, zp_w) diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index d01c7645191..7dac5a27b4e 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -211,3 +211,42 
@@ def fold_batchnorm(weight, cell_quant): weight = weight * _gamma / _sigma bias = beta - gamma * mean / sigma return weight, bias + + +def without_fold_batchnorm(weight, cell_quant): + r""" + Fold the batchnorm in `Conv2dBnWithoutFoldQuant` to weight. + + Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive. + + Args: + weight (numpy.ndarray): Weight of `cell_quant`. + cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`. + + Returns: + weight (numpy.ndarray): without folded weight. + bias (numpy.ndarray): without folded bias. + """ + variance = cell_quant.batchnorm.moving_variance.data.asnumpy() + mean = cell_quant.batchnorm.moving_mean.data.asnumpy() + gamma = cell_quant.batchnorm.gamma.data.asnumpy() + beta = cell_quant.batchnorm.beta.data.asnumpy() + epsilon = cell_quant.batchnorm.eps + sigma = np.sqrt(variance + epsilon) + + if gamma.shape[0] == weight.shape[0]: + # `Conv2d` or `Dense` op weight + shape_list = [-1] + [1] * len(weight.shape[1:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + elif gamma.shape[0] == weight.shape[1]: + # `DepthwiseConv2d` op weight + shape_list = [1, -1] + [1] * len(weight.shape[2:]) + _gamma = gamma.reshape(shape_list) + _sigma = sigma.reshape(shape_list) + else: + raise ValueError("Unsupported weight shape({})".format(weight.shape)) + + weight = weight * _gamma / _sigma + bias = beta - gamma * mean / sigma + return weight, bias