forked from mindspore-Ecosystem/mindspore
!3736 quant aware training add without bn fold and bn fold
Merge pull request !3736 from chenzhongming/master
This commit is contained in:
commit
fc259aebcf
|
@ -32,13 +32,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|||
|
||||
DOWNLOAD_BASIC_URL = "http://download.mindspore.cn/model_zoo"
|
||||
OFFICIAL_NAME = "official"
|
||||
DEFAULT_CACHE_DIR = '~/.cache'
|
||||
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet',
|
||||
'lenet', 'resnet', 'ssd', 'vgg', 'yolo']
|
||||
DEFAULT_CACHE_DIR = '.cache'
|
||||
MODEL_TARGET_CV = ['alexnet', 'fasterrcnn', 'googlenet', 'lenet', 'resnet', 'resnet50', 'ssd', 'vgg', 'yolo']
|
||||
MODEL_TARGET_NLP = ['bert', 'mass', 'transformer']
|
||||
|
||||
|
||||
def _packing_targz(output_filename, savepath="./"):
|
||||
def _packing_targz(output_filename, savepath=DEFAULT_CACHE_DIR):
|
||||
"""
|
||||
Packing the input filename to filename.tar.gz in source dir.
|
||||
"""
|
||||
|
@ -49,7 +48,7 @@ def _packing_targz(output_filename, savepath="./"):
|
|||
raise OSError("Cannot tar file {} for - {}".format(output_filename, e))
|
||||
|
||||
|
||||
def _unpacking_targz(input_filename, savepath="./"):
|
||||
def _unpacking_targz(input_filename, savepath=DEFAULT_CACHE_DIR):
|
||||
"""
|
||||
Unpacking the input filename to dirs.
|
||||
"""
|
||||
|
@ -69,14 +68,14 @@ def _remove_path_if_exists(path):
|
|||
|
||||
|
||||
def _create_path_if_not_exists(path):
|
||||
if os.path.exists(path):
|
||||
if not os.path.exists(path):
|
||||
if os.path.isfile(path):
|
||||
os.remove(path)
|
||||
else:
|
||||
os.mkdir(path)
|
||||
|
||||
|
||||
def _get_weights_file(url, hash_md5=None, savepath='./'):
|
||||
def _get_weights_file(url, hash_md5=None, savepath=DEFAULT_CACHE_DIR):
|
||||
"""
|
||||
get checkpoint weight from giving url.
|
||||
|
||||
|
@ -103,7 +102,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'):
|
|||
download_md5 = m.hexdigest()
|
||||
return download_md5 == hash_md5
|
||||
|
||||
_create_path_if_not_exists(savepath)
|
||||
_remove_path_if_exists(os.path.realpath(savepath))
|
||||
_create_path_if_not_exists(os.path.realpath(savepath))
|
||||
ckpt_name = os.path.basename(url.split("/")[-1])
|
||||
# identify file exist or not
|
||||
file_path = os.path.join(savepath, ckpt_name)
|
||||
|
@ -112,8 +112,8 @@ def _get_weights_file(url, hash_md5=None, savepath='./'):
|
|||
print('File already exists!')
|
||||
return file_path
|
||||
|
||||
file_path = file_path[:-7] if ".tar.gz" in file_path else file_path
|
||||
_remove_path_if_exists(file_path)
|
||||
file_path_ = file_path[:-7] if ".tar.gz" in file_path else file_path
|
||||
_remove_path_if_exists(file_path_)
|
||||
|
||||
# download the checkpoint file
|
||||
print('Downloading data from url {}'.format(url))
|
||||
|
@ -126,14 +126,12 @@ def _get_weights_file(url, hash_md5=None, savepath='./'):
|
|||
print('\nDownload finished!')
|
||||
|
||||
# untar file_path
|
||||
_unpacking_targz(file_path)
|
||||
_unpacking_targz(file_path, os.path.realpath(savepath))
|
||||
|
||||
# # get the file size
|
||||
file_path = os.path.join(savepath, ckpt_name)
|
||||
filesize = os.path.getsize(file_path)
|
||||
# turn the file size to Mb format
|
||||
print('File size = %.2f Mb' % (filesize / 1024 / 1024))
|
||||
return file_path
|
||||
return file_path_
|
||||
|
||||
|
||||
def _get_url_paths(url, ext='.tar.gz'):
|
||||
|
@ -150,7 +148,7 @@ def _get_url_paths(url, ext='.tar.gz'):
|
|||
|
||||
def _get_file_from_url(base_url, base_name):
|
||||
idx = 0
|
||||
urls = _get_url_paths(base_url)
|
||||
urls = _get_url_paths(base_url + "/")
|
||||
files = [url.split('/')[-1] for url in urls]
|
||||
for i, name in enumerate(files):
|
||||
if re.match(base_name + '*', name) is not None:
|
||||
|
@ -172,8 +170,8 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs):
|
|||
dataset (string, optional): Dataset to train the network. Default: 'cifar10'.
|
||||
|
||||
Example:
|
||||
>>> mindspore.hub.load(network, network_name='lenet',
|
||||
**{'device_target': 'ascend', 'dataset':'cifar10', 'version': 'beta0.5'})
|
||||
>>> hub.load(network, network_name='lenet',
|
||||
**{'device_target': 'ascend', 'dataset':'mnist', 'version': '0.5.0'})
|
||||
"""
|
||||
if not isinstance(network, nn.Cell):
|
||||
logger.error("Failed to combine the net and the parameters.")
|
||||
|
@ -195,9 +193,11 @@ def load_weights(network, network_name=None, force_reload=True, **kwargs):
|
|||
model_type = "cv"
|
||||
elif network_name.split("_")[0] in MODEL_TARGET_NLP:
|
||||
model_type = "nlp"
|
||||
else:
|
||||
raise ValueError("Unsupported network {} download checkpoint.".format(network_name.split("_")[0]))
|
||||
|
||||
download_base_url = "/".join([DOWNLOAD_BASIC_URL,
|
||||
OFFICIAL_NAME, model_type])
|
||||
OFFICIAL_NAME, model_type, network_name])
|
||||
download_file_name = "_".join(
|
||||
[network_name, device_target, version, dataset, OFFICIAL_NAME])
|
||||
download_url = _get_file_from_url(download_base_url, download_file_name)
|
||||
|
|
|
@ -39,7 +39,8 @@ __all__ = [
|
|||
'Conv2dBnAct',
|
||||
'DenseBnAct',
|
||||
'FakeQuantWithMinMax',
|
||||
'Conv2dBatchNormQuant',
|
||||
'Conv2dBnFoldQuant',
|
||||
'Conv2dBnWithoutFoldQuant',
|
||||
'Conv2dQuant',
|
||||
'DenseQuant',
|
||||
'ActQuant',
|
||||
|
@ -393,7 +394,7 @@ class FakeQuantWithMinMax(Cell):
|
|||
return out
|
||||
|
||||
|
||||
class Conv2dBatchNormQuant(Cell):
|
||||
class Conv2dBnFoldQuant(Cell):
|
||||
r"""
|
||||
2D convolution with BatchNormal op folded layer.
|
||||
|
||||
|
@ -418,7 +419,7 @@ class Conv2dBatchNormQuant(Cell):
|
|||
mean vector. Default: 'zeros'.
|
||||
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
|
||||
variance vector. Default: 'ones'.
|
||||
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
|
||||
fake (bool): Conv2dBnFoldQuant Cell add FakeQuantWithMinMax op or not. Default: True.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
|
@ -433,7 +434,7 @@ class Conv2dBatchNormQuant(Cell):
|
|||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Examples:
|
||||
>>> batchnorm_quant = nn.Conv2dBatchNormQuant(1, 6, kernel_size= (2, 2), stride=(1, 1), pad_mode="valid",
|
||||
>>> batchnorm_quant = nn.Conv2dBnFoldQuant(1, 6, kernel_size= (2, 2), stride=(1, 1), pad_mode="valid",
|
||||
>>> dilation=(1, 1))
|
||||
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mindspore.float32)
|
||||
>>> result = batchnorm_quant(input_x)
|
||||
|
@ -462,8 +463,8 @@ class Conv2dBatchNormQuant(Cell):
|
|||
narrow_range=False,
|
||||
quant_delay=0,
|
||||
freeze_bn=100000):
|
||||
"""init Conv2dBatchNormQuant layer"""
|
||||
super(Conv2dBatchNormQuant, self).__init__()
|
||||
"""init Conv2dBnFoldQuant layer"""
|
||||
super(Conv2dBnFoldQuant, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = twice(kernel_size)
|
||||
|
@ -580,6 +581,132 @@ class Conv2dBatchNormQuant(Cell):
|
|||
return out
|
||||
|
||||
|
||||
class Conv2dBnWithoutFoldQuant(Cell):
|
||||
r"""
|
||||
2D convolution + batchnorm without fold with fake quant op layer.
|
||||
|
||||
For a more Detailed overview of Conv2d op.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||
kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window.
|
||||
stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1.
|
||||
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
|
||||
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
||||
dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1.
|
||||
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
|
||||
divisible by the number of groups. Default: 1.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
||||
has_bn (bool): Specifies to used batchnorm or not. Default: False.
|
||||
eps (float): Parameters for BatchNormal. Default: 1e-5.
|
||||
momentum (float): Parameters for BatchNormal op. Default: 0.997.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||
Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
|
||||
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
|
||||
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
|
||||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Examples:
|
||||
>>> conv2d_quant = nn.Conv2dQuant(1, 6, kernel_size=(2, 2), stride=(1, 1), pad_mode="valid",
|
||||
>>> dilation=(1, 1))
|
||||
>>> input_x = Tensor(np.random.randint(-2, 2, (2, 1, 1, 3)), mstype.float32)
|
||||
>>> result = conv2d_quant(input_x)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
pad_mode='same',
|
||||
padding=0,
|
||||
dilation=1,
|
||||
group=1,
|
||||
has_bias=False,
|
||||
has_bn=True,
|
||||
eps=1e-5,
|
||||
momentum=0.997,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
per_channel=False,
|
||||
num_bits=8,
|
||||
symmetric=False,
|
||||
narrow_range=False,
|
||||
quant_delay=0):
|
||||
super(Conv2dBnWithoutFoldQuant, self).__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
self.kernel_size = (kernel_size, kernel_size)
|
||||
else:
|
||||
self.kernel_size = kernel_size
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
self.has_bias = has_bias
|
||||
self.stride = twice(stride)
|
||||
self.dilation = twice(dilation)
|
||||
self.pad_mode = pad_mode
|
||||
self.padding = padding
|
||||
self.group = group
|
||||
self.quant_delay = quant_delay
|
||||
|
||||
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
|
||||
self.bias_add = P.BiasAdd()
|
||||
if check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.conv = P.Conv2D(out_channel=self.out_channels,
|
||||
kernel_size=self.kernel_size,
|
||||
mode=1,
|
||||
pad_mode=self.pad_mode,
|
||||
pad=self.padding,
|
||||
stride=self.stride,
|
||||
dilation=self.dilation,
|
||||
group=self.group)
|
||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||
max_init=6,
|
||||
ema=False,
|
||||
per_channel=per_channel,
|
||||
channel_axis=0,
|
||||
num_channels=out_channels,
|
||||
num_bits=num_bits,
|
||||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.has_bn = validator.check_bool("has_bn", has_bn)
|
||||
if has_bn:
|
||||
self.batchnorm = BatchNorm2d(out_channels)
|
||||
|
||||
def construct(self, x):
|
||||
weight = self.fake_quant_weight(self.weight)
|
||||
out = self.conv(x, weight)
|
||||
if self.has_bias:
|
||||
out = self.bias_add(out, self.bias)
|
||||
if self.has_bn:
|
||||
out = self.batchnorm(out)
|
||||
return out
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
|
||||
'pad_mode={}, padding={}, dilation={}, group={}, ' \
|
||||
'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
||||
self.pad_mode, self.padding, self.dilation, self.group,
|
||||
self.has_bias, self.quant_delay)
|
||||
return s
|
||||
|
||||
|
||||
class Conv2dQuant(Cell):
|
||||
r"""
|
||||
2D convolution with fake quant op layer.
|
||||
|
|
|
@ -170,22 +170,23 @@ class ConvertToQuantNetwork:
|
|||
if subcell.has_bn:
|
||||
if self.bn_fold:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
quant_delay=self.weight_qdelay,
|
||||
freeze_bn=self.freeze_bn,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
fake=True,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
conv_inner = quant.Conv2dBnFoldQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
quant_delay=self.weight_qdelay,
|
||||
freeze_bn=self.freeze_bn,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
fake=True,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.beta = subcell.batchnorm.beta
|
||||
|
@ -195,7 +196,31 @@ class ConvertToQuantNetwork:
|
|||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
else:
|
||||
raise ValueError("Only support Batchnorm fold mode.")
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBnWithoutFoldQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
has_bn=True,
|
||||
quant_delay=self.weight_qdelay,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.batchnorm.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.batchnorm.beta = subcell.batchnorm.beta
|
||||
conv_inner.batchnorm.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.batchnorm.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
else:
|
||||
conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
|
@ -354,7 +379,7 @@ class ExportToQuantInferNetwork:
|
|||
if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
|
||||
if cell_core.has_bias:
|
||||
bias = cell_core.bias.data.asnumpy()
|
||||
elif isinstance(cell_core, quant.Conv2dBatchNormQuant):
|
||||
elif isinstance(cell_core, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant)):
|
||||
weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
|
||||
|
||||
# apply the quant
|
||||
|
|
|
@ -176,13 +176,13 @@ def scale_zp_from_data(op, minq, maxq, data_type):
|
|||
|
||||
def fold_batchnorm(weight, cell_quant):
|
||||
r"""
|
||||
Fold the batchnorm in `Conv2dBatchNormQuant` to weight.
|
||||
Fold the batchnorm in `Conv2dBnFoldQuant` to weight.
|
||||
|
||||
Calculate from `FakeQuantWithMinMax`'s Parameter or Fake quant primitive.
|
||||
|
||||
Args:
|
||||
weight (numpy.ndarray): Weight of `cell_quant`.
|
||||
cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.
|
||||
cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.
|
||||
|
||||
Returns:
|
||||
weight (numpy.ndarray): Folded weight.
|
||||
|
|
|
@ -12,4 +12,5 @@ setuptools >= 40.8.0
|
|||
matplotlib >= 3.1.3 # for ut test
|
||||
opencv-python >= 4.1.2.30 # for ut test
|
||||
sklearn >= 0.0 # for st test
|
||||
pandas >= 1.0.2 # for ut test
|
||||
pandas >= 1.0.2 # for ut test
|
||||
bs4
|
||||
|
|
Loading…
Reference in New Issue