From 315c6a5742999febc28df42935bf6bb3ed0acc43 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Thu, 30 Jul 2020 16:59:17 +0800 Subject: [PATCH] quant aware training add without bn fold and bn fold --- mindspore/hub.py | 36 +++---- mindspore/nn/layer/quant.py | 139 +++++++++++++++++++++++++-- mindspore/train/quant/quant.py | 61 ++++++++---- mindspore/train/quant/quant_utils.py | 4 +- requirements.txt | 3 +- 5 files changed, 198 insertions(+), 45 deletions(-) diff --git a/mindspore/hub.py b/mindspore/hub.py index 72013c82188..f60e841729e 100644 --- a/mindspore/hub.py +++ b/mindspore/hub.py @@ -32,13 +32,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo" OFFICIAL_NAME = "official" -DEFAULT_CACHE_DIR = '~/.cache' -MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', - 'lenet', 'resnet', 'ssd', 'vgg', 'yolo'] +DEFAULT_CACHE_DIR = '.cache' +MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo'] MODEL_TARGET_NLP = ['bert', 'mass', 'transformer'] -def _packing_targz(output_filename, savepath="./"): +def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR): """ Packing the input filename to filename.tar.gz in source dir. """ @@ -49,7 +48,7 @@ def _packing_targz(output_filename, savepath="./"): raise OSError("Cannot tar file {} for - {}".format(output_filename, e)) -def _unpacking_targz(input_filename, savepath="./"): +def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR): """ Unpacking the input filename to dirs. """ @@ -69,14 +68,14 @@ def _remove_path_if_exists(path): def _create_path_if_not_exists(path): - if os.path.exists(path): + if not os.path.exists(path): if os.path.isfile(path): os.remove(path) else: os.mkdir(path) -def _get_weights_file(url, hash_md5=None, savepath='./'): +def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR): """ get checkpoint weight from giving url. @@ -103,7 +102,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): download_md5 = m.hexdigest() return download_md5 == hash_md5 - _create_path_if_not_exists(savepath) + _remove_path_if_exists(os.path.realpath(savepath)) + _create_path_if_not_exists(os.path.realpath(savepath)) ckpt_name = os.path.basename(url.split("/")[-1]) # identify file exist or not file_path = os.path.join(savepath, ckpt_name) @@ -112,8 +112,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): print('File already exists!') return file_path - file_path = file_path[:-7] if ".tar.gz" in file_path else file_path - _remove_path_if_exists(file_path) + file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path + _remove_path_if_exists(file_path_) # download the checkpoint file print('Downloading data from url {}'.format(url)) @@ -126,14 +126,12 @@ def _get_weights_file(url, hash_md5=None, savepath='./'): print('\nDownload finished!') # untar file_path - _unpacking_targz(file_path) + _unpacking_targz(file_path, os.path.realpath(savepath)) - # # get the file size - file_path = os.path.join(savepath, ckpt_name) filesize = os.path.getsize(file_path) # turn the file size to Mb format print('File size = %.2f Mb' % (filesize / 1024 / 1024)) - return file_path + return file_path_ def _get_url_paths(url, ext='.tar.gz'): @@ -150,7 +148,7 @@ def _get_url_paths(url, ext='.tar.gz'): def _get_file_from_url(base_url, base_name): idx = 0 - urls = _get_url_paths(base_url) + urls = _get_url_paths(base_url + "/") files = [url.split('/')[-1] for url in urls] for i, name in enumerate(files): if re.match(base_name + '*', name) is not None: @@ -172,8 +170,8 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs): dataset (string, optional): Dataset to train the network. Default: 'cifar10'. Example: - >>> mindspore.hub.load(network, network_name='lenet', - **{'device_target': 'ascend', 'dataset':'cifar10', 'version': 'beta0.5'}) + >>> hub.load(network, network_name='lenet', + **{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'}) """ if not isinstance(network, nn.Cell): logger.error("Failed to combine the net and the parameters.") @@ -195,9 +193,11 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs): model_type = "cv" elif network_name.split("_")[0] in MODEL_TARGET_NLP: model_type = "nlp" + else: + raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0])) download_base_url = "/".join([DOWNLOAD_BASIC_URL, - OFFICIAL_NAME, model_type]) + OFFICIAL_NAME, model_type, network_name]) download_file_name = "_".join( [network_name, device_target, version, dataset, OFFICIAL_NAME]) download_url = _get_file_from_url(download_base_url, download_file_name) diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index cde81a76441..1dfae36a21f 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -39,7 +39,8 @@ __all__ = [ 'Conv2dBnAct', 'DenseBnAct', 'FakeQuantWithMinMax', - 'Conv2dBatchNormQuant', + 'Conv2dBnFoldQuant', + 'Conv2dBnWithoutFoldQuant', 'Conv2dQuant', 'DenseQuant', 'ActQuant', @@ -393,7 +394,7 @@ class FakeQuantWithMinMax(Cell): return out -class Conv2dBatchNormQuant(Cell): +class Conv2dBnFoldQuant(Cell): r""" 2D convolution with BatchNormal op folded layer. @@ -418,7 +419,7 @@ class Conv2dBatchNormQuant(Cell): mean vector. Default: 'zeros'. var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the variance vector. Default: 'ones'. - fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. + fake (bool): Conv2dBnFoldQuant Cell add FakeQuantWithMinMax op or not. Default: True. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. symmetric (bool): Quantization algorithm use symmetric or not. Default: False. @@ -433,7 +434,7 @@ class Conv2dBatchNormQuant(Cell): Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Examples: - >>> batchnorm_quant = nn.Conv2dBatchNormQuant(1, 6, kernel_size= (2, 2), stride=(1, 1), pad_mode="valid", + >>> batchnorm_quant = nn.Conv2dBnFoldQuant(1, 6, kernel_size= (2, 2), stride=(1, 1), pad_mode="valid", >>> dilation=(1, 1)) >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32) >>> result = batchnorm_quant(input_x) @@ -462,8 +463,8 @@ class Conv2dBatchNormQuant(Cell): narrow_range=False, quant_delay=0, freeze_bn=100000): - """init Conv2dBatchNormQuant layer""" - super(Conv2dBatchNormQuant, self).__init__() + """init Conv2dBnFoldQuant layer""" + super(Conv2dBnFoldQuant, self).__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = twice(kernel_size) @@ -580,6 +581,132 @@ class Conv2dBatchNormQuant(Cell): return out +class Conv2dBnWithoutFoldQuant(Cell): + r""" + 2D convolution + batchnorm without fold with fake quant op layer. + + For a more Detailed overview of Conv2d op. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. + stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1. + pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1. + group (int): Split filter into groups, `in_ channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + has_bn (bool): Specifies to used batchnorm or not. Default: False. + eps (float): Parameters for BatchNormal. Default: 1e-5. + momentum (float): Parameters for BatchNormal op. Default: 0.997. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid", + >>> dilation=(1, 1)) + >>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mstype.float32) + >>> result = conv2d_quant(input_x) + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + has_bn=True, + eps=1e-5, + momentum=0.997, + weight_init='normal', + bias_init='zeros', + per_channel=False, + num_bits=8, + symmetric=False, + narrow_range=False, + quant_delay=0): + super(Conv2dBnWithoutFoldQuant, self).__init__() + if isinstance(kernel_size, int): + self.kernel_size = (kernel_size, kernel_size) + else: + self.kernel_size = kernel_size + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + self.has_bias = has_bias + self.stride = twice(stride) + self.dilation = twice(dilation) + self.pad_mode = pad_mode + self.padding = padding + self.group = group + self.quant_delay = quant_delay + + weight_shape = [out_channels, in_channels // group, *self.kernel_size] + self.weight = Parameter(initializer(weight_init, weight_shape), name='weight') + + self.bias_add = P.BiasAdd() + if check_bool(has_bias): + self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') + else: + self.bias = None + + self.conv = P.Conv2D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation, + group=self.group) + self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, + max_init=6, + ema=False, + per_channel=per_channel, + channel_axis=0, + num_channels=out_channels, + num_bits=num_bits, + symmetric=symmetric, + narrow_range=narrow_range, + quant_delay=quant_delay) + self.has_bn = validator.check_bool("has_bn", has_bn) + if has_bn: + self.batchnorm = BatchNorm2d(out_channels) + + def construct(self, x): + weight = self.fake_quant_weight(self.weight) + out = self.conv(x, weight) + if self.has_bias: + out = self.bias_add(out, self.bias) + if self.has_bn: + out = self.batchnorm(out) + return out + + def extend_repr(self): + s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \ + 'pad_mode={}, padding={}, dilation={}, group={}, ' \ + 'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride, + self.pad_mode, self.padding, self.dilation, self.group, + self.has_bias, self.quant_delay) + return s + + class Conv2dQuant(Cell): r""" 2D convolution with fake quant op layer. diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index d73dce25ae4..29062bc9aaa 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -170,22 +170,23 @@ class ConvertToQuantNetwork: if subcell.has_bn: if self.bn_fold: bn_inner = subcell.batchnorm - conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels, - conv_inner.out_channels, - kernel_size=conv_inner.kernel_size, - stride=conv_inner.stride, - pad_mode=conv_inner.pad_mode, - padding=conv_inner.padding, - dilation=conv_inner.dilation, - group=conv_inner.group, - eps=bn_inner.eps, - quant_delay=self.weight_qdelay, - freeze_bn=self.freeze_bn, - per_channel=self.weight_channel, - num_bits=self.weight_bits, - fake=True, - symmetric=self.weight_symmetric, - narrow_range=self.weight_range) + conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=bn_inner.momentum, + quant_delay=self.weight_qdelay, + freeze_bn=self.freeze_bn, + per_channel=self.weight_channel, + num_bits=self.weight_bits, + fake=True, + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) # change original network BatchNormal OP parameters to quant network conv_inner.gamma = subcell.batchnorm.gamma conv_inner.beta = subcell.batchnorm.beta @@ -195,7 +196,31 @@ class ConvertToQuantNetwork: subcell.batchnorm = None subcell.has_bn = False else: - raise ValueError("Only support Batchnorm fold mode.") + bn_inner = subcell.batchnorm + conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels, + conv_inner.out_channels, + kernel_size=conv_inner.kernel_size, + stride=conv_inner.stride, + pad_mode=conv_inner.pad_mode, + padding=conv_inner.padding, + dilation=conv_inner.dilation, + group=conv_inner.group, + eps=bn_inner.eps, + momentum=bn_inner.momentum, + has_bn=True, + quant_delay=self.weight_qdelay, + per_channel=self.weight_channel, + num_bits=self.weight_bits, + symmetric=self.weight_symmetric, + narrow_range=self.weight_range) + # change original network BatchNormal OP parameters to quant network + conv_inner.batchnorm.gamma = subcell.batchnorm.gamma + conv_inner.batchnorm.beta = subcell.batchnorm.beta + conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean + conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance + del subcell.batchnorm + subcell.batchnorm = None + subcell.has_bn = False else: conv_inner = quant.Conv2dQuant(conv_inner.in_channels, conv_inner.out_channels, @@ -354,7 +379,7 @@ class ExportToQuantInferNetwork: if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)): if cell_core.has_bias: bias = cell_core.bias.data.asnumpy() - elif isinstance(cell_core, quant.Conv2dBatchNormQuant): + elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant)): weight, bias = quant_utils.fold_batchnorm(weight, cell_core) # apply the quant diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index 5d524391bef..1e2481ceaa0 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -176,13 +176,13 @@ def scale_zp_from_data(op, minq, maxq, data_type): def fold_batchnorm(weight, cell_quant): r""" - Fold the batchnorm in `Conv2dBatchNormQuant` to weight. + Fold the batchnorm in `Conv2dBnFoldQuant` to weight. Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive. Args: weight (numpy.ndarray): Weight of `cell_quant`. - cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`. + cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`. Returns: weight (numpy.ndarray): Folded weight. diff --git a/requirements.txt b/requirements.txt index 4038e63ea77..5fe70c0492a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ setuptools >= 40.8.0 matplotlib >= 3.1.3 # for ut test opencv-python >= 4.1.2.30 # for ut test sklearn >= 0.0 # for st test -pandas >= 1.0.2 # for ut test \ No newline at end of file +pandas >= 1.0.2 # for ut test +bs4