forked from mindspore-Ecosystem/mindspore
!1084 add new interface quant combined
Merge pull request !1084 from SanjayChan/04quant
This commit is contained in:
commit
4bb5c7b39a
|
@ -97,7 +97,7 @@ class Cell:
|
||||||
|
|
||||||
After invoked, can get all the cell's children's name prefix by '_param_prefix'.
|
After invoked, can get all the cell's children's name prefix by '_param_prefix'.
|
||||||
"""
|
"""
|
||||||
cells = self.cells_and_names
|
cells = self.cells_and_names()
|
||||||
|
|
||||||
for cell_name, cell in cells:
|
for cell_name, cell in cells:
|
||||||
cell._param_prefix = cell_name
|
cell._param_prefix = cell_name
|
||||||
|
|
|
@ -0,0 +1,182 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""Use combination of Conv, Dense, Relu, Batchnorm."""
|
||||||
|
|
||||||
|
from .normalization import BatchNorm2d
|
||||||
|
from .activation import get_activation
|
||||||
|
from ..cell import Cell
|
||||||
|
from . import conv, basic
|
||||||
|
from ..._checkparam import ParamValidator as validator
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['Conv2d', 'Dense']
|
||||||
|
|
||||||
|
class Conv2d(Cell):
    r"""
    A combination of convolution, Batchnorm, activation layer.

    The convolution is always applied; batchnorm and activation are applied
    afterwards (in that order) only when they are configured.

    Args:
        in_channels (int): The number of input channel :math:`C_{in}`.
        out_channels (int): The number of output channel :math:`C_{out}`.
        kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height
            and width of the 2D convolution window. Single int means the value is for both height and width of
            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
            width of the kernel.
        stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be
            greater or equal to 1 but bounded by the height and width of the input. Default: 1.
        pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
        padding (int): Implicit paddings on both sides of the input. Default: 0.
        dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
            there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
            or equal to 1 and bounded by the height and width of the input. Default: 1.
        group (int): Split filter into groups, `in_channels` and `out_channels` should be
            divisible by the number of groups. Default: 1.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
            Initializer for more details. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
            Initializer and string are the same as 'weight_init'. Refer to the values of
            Initializer for more details. Default: 'zeros'.
        batchnorm (Union[bool, BatchNorm2d]): True creates a `BatchNorm2d(out_channels)`; a `BatchNorm2d`
            instance is used as given; None or False disables batchnorm. Default: None.
        activation (string): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

    Examples:
        >>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='relu')
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
        >>> net(input).shape()
        (1, 240, 1024, 640)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 pad_mode='same',
                 padding=0,
                 dilation=1,
                 group=1,
                 has_bias=False,
                 weight_init='normal',
                 bias_init='zeros',
                 batchnorm=None,
                 activation=None):
        super(Conv2d, self).__init__()
        self.conv = conv.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            pad_mode,
            padding,
            dilation,
            group,
            has_bias,
            weight_init,
            bias_init)
        # The argument is documented as a bool, so an explicit False must mean
        # "no batchnorm" instead of being rejected by the isinstance check below.
        if batchnorm is False:
            batchnorm = None
        self.has_bn = batchnorm is not None
        self.has_act = activation is not None
        self.batchnorm = batchnorm
        if batchnorm is True:
            self.batchnorm = BatchNorm2d(out_channels)
        elif batchnorm is not None:
            # A caller-supplied layer must be a BatchNorm2d instance.
            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
        self.activation = get_activation(activation)

    def construct(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
class Dense(Cell):
    r"""
    A combination of Dense, Batchnorm, activation layer.

    The dense layer is always applied; batchnorm and activation are applied
    afterwards (in that order) only when they are configured.

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
        batchnorm (Union[bool, BatchNorm2d]): True creates a `BatchNorm2d(out_channels)`; a `BatchNorm2d`
            instance is used as given; None or False disables batchnorm. Default: None.
        activation (string): Specifies activation type. The optional values are as following:
            'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
            'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = combined.Dense(3, 4)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True,
                 batchnorm=None,
                 activation=None):
        super(Dense, self).__init__()
        self.dense = basic.Dense(
            in_channels,
            out_channels,
            weight_init,
            bias_init,
            has_bias)
        # The argument is documented as a bool, so an explicit False must mean
        # "no batchnorm" instead of being rejected by the isinstance check below.
        if batchnorm is False:
            batchnorm = None
        self.has_bn = batchnorm is not None
        self.has_act = activation is not None
        # Always define the attribute (as Conv2d in this module does); otherwise a
        # caller-supplied batchnorm layer would be validated but never stored.
        self.batchnorm = batchnorm
        if batchnorm is True:
            # NOTE(review): BatchNorm2d is applied to the Dense output, which the
            # docstring describes as (N, out_channels) - confirm this is intended.
            self.batchnorm = BatchNorm2d(out_channels)
        elif batchnorm is not None:
            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
        self.activation = get_activation(activation)

    def construct(self, x):
        x = self.dense(x)
        if self.has_bn:
            x = self.batchnorm(x)
        if self.has_act:
            x = self.activation(x)
        return x
|
|
@ -191,6 +191,8 @@ class Conv2dBatchNormQuant(Cell):
|
||||||
stride,
|
stride,
|
||||||
pad_mode,
|
pad_mode,
|
||||||
padding=0,
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
group=1,
|
||||||
eps=1e-5,
|
eps=1e-5,
|
||||||
momentum=0.9,
|
momentum=0.9,
|
||||||
weight_init=None,
|
weight_init=None,
|
||||||
|
@ -198,7 +200,6 @@ class Conv2dBatchNormQuant(Cell):
|
||||||
gamma_init=None,
|
gamma_init=None,
|
||||||
mean_init=None,
|
mean_init=None,
|
||||||
var_init=None,
|
var_init=None,
|
||||||
group=1,
|
|
||||||
quant_delay=0,
|
quant_delay=0,
|
||||||
freeze_bn=100000,
|
freeze_bn=100000,
|
||||||
fake=True,
|
fake=True,
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
quantization.
|
||||||
|
|
||||||
|
Users can use quantization aware training to train a model. MindSpore supports quantization aware training,
|
||||||
|
which models quantization errors in both the forward and backward passes using fake-quantization
|
||||||
|
ops. Note that the entire computation is carried out in floating point. At the end of quantization
|
||||||
|
aware training, Mindspore provides conversion functions to convert the trained model into lower precision.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .quant import convert_quant_network
|
||||||
|
|
||||||
|
__all__ = ["convert_quant_network"]
|
|
@ -0,0 +1,262 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""aware quantization."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from ... import nn
|
||||||
|
from ... import ops
|
||||||
|
from ..._checkparam import ParamValidator as validator
|
||||||
|
from ..._checkparam import Rel
|
||||||
|
from ...nn.layer import combined
|
||||||
|
from ...nn.layer import quant
|
||||||
|
|
||||||
|
# Maps each supported float activation cell class to the quantized cell class
# that replaces it during conversion.
_ACTIVATION_MAP = {
    nn.ReLU: quant.ReLUQuant,
    nn.ReLU6: quant.ReLU6Quant,
    nn.HSigmoid: quant.HSigmoidQuant,
    nn.HSwish: quant.HSwishQuant,
}
|
||||||
|
|
||||||
|
|
||||||
|
class _AddFakeQuantInputOutput(nn.Cell):
    """
    Wrap a network with a FakeQuant cell on its input and another on its output.

    Only the single-input, single-output case is supported.
    """

    def __init__(self, network, quant_delay=0):
        super().__init__(auto_prefix=False)
        self.network = network

        # Input-side fake quantization, with explicitly named parameters.
        input_fq = quant.FakeQuantWithMinMax(min_init=-6,
                                             max_init=6,
                                             quant_delay=quant_delay,
                                             ema=True)
        self.fake_quant_input = input_fq
        self.fake_quant_input.update_parameters_name('fake_quant_input')

        # Output-side fake quantization, configured the same way.
        output_fq = quant.FakeQuantWithMinMax(min_init=-6,
                                              max_init=6,
                                              quant_delay=quant_delay,
                                              ema=True)
        self.fake_quant_output = output_fq
        self.fake_quant_output.update_parameters_name('fake_quant_output')

    def construct(self, data):
        quantized_in = self.fake_quant_input(data)
        raw_out = self.network(quantized_in)
        return self.fake_quant_output(raw_out)
|
||||||
|
|
||||||
|
|
||||||
|
class _AddFakeQuantAfterSubCell(nn.Cell):
    """
    Run the wrapped sub cell and pass its output through a FakeQuant cell.
    """

    def __init__(self, subcell, quant_delay=0, num_bits=8):
        super().__init__(auto_prefix=False)
        self.subcell = subcell
        self.fake_quant_act = quant.FakeQuantWithMinMax(
            min_init=-6, max_init=6, num_bits=num_bits, quant_delay=quant_delay, ema=True)

    def construct(self, *data):
        result = self.subcell(*data)
        return self.fake_quant_act(result)
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertToQuantNetwork:
    """
    Convert network to quantization aware network.

    The conversion replaces `combined.Conv2d` / `combined.Dense` sub cells with
    their quantized counterparts and wraps the primitive ops listed in
    `__quant_op_name__` with a fake-quant cell.

    Args:
        network (Cell): Network to convert; must be an `nn.Cell`.
        quant_delay (int): Quantization delay forwarded to the quant cells. Must be >= 0. Default: 0.
        bn_fold (bool): Use the folded conv+batchnorm quant cell when a batchnorm is present. Default: False.
        freeze_bn (int): Forwarded to `Conv2dBatchNormQuant`. Must be >= 0. Default: 0.
        weight_bits (int): Number of bits for weight quantization. Must be >= 0. Default: 8.
        act_bits (int): Number of bits for activation quantization. Must be >= 0. Default: 8.
        per_channel (bool): Forwarded to the quant cells. Default: False.
        symmetric (bool): Forwarded to the conv quant cells. Default: False.
        narrow_range (bool): Forwarded to the conv quant cells. Default: False.
    """
    # Names of primitive ops that get a fake-quant cell appended to their output.
    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]

    def __init__(self,
                 network,
                 quant_delay=0,
                 bn_fold=False,
                 freeze_bn=0,
                 weight_bits=8,
                 act_bits=8,
                 per_channel=False,
                 symmetric=False,
                 narrow_range=False):
        self.network = validator.check_isinstance(
            'network', network, (nn.Cell,))
        self.quant_delay = validator.check_integer(
            "quant delay", quant_delay, 0, Rel.GE)
        self.freeze_bn = validator.check_integer(
            "freeze bn", freeze_bn, 0, Rel.GE)
        self.weight_bits = validator.check_integer(
            "weights bit", weight_bits, 0, Rel.GE)
        self.act_bits = validator.check_integer(
            "activations bit", act_bits, 0, Rel.GE)
        self.bn_fold = validator.check_bool("bn fold", bn_fold)
        self.per_channel = validator.check_bool("per channel", per_channel)
        self.symmetric = validator.check_bool("symmetric", symmetric)
        self.narrow_range = validator.check_bool("narrow range", narrow_range)

    def _convert_op_name(self, name):
        """Convert a CamelCase op name to snake_case, e.g. 'TensorAdd' -> 'tensor_add'."""
        name_new = re.sub(r'([A-Z])', r'_\1', name).lower()
        # Drop the underscore produced by a leading capital; `startswith` also
        # keeps an empty name from raising IndexError.
        if name_new.startswith('_'):
            name_new = name_new[1:]
        return name_new

    def run(self):
        """Refresh the cell prefixes, convert the network in place and return it."""
        self.network.update_cell_prefix()
        network = self._convert_subcells2quant(self.network)
        return network

    def _convert_subcells2quant(self, network):
        """
        Convert the sub cells of `network` to quant cells, recursing into other cells.

        Returns:
            Cell, the same `network` object after in-place conversion.
        """
        cells = network.name_cells()
        change = False
        for name in cells:
            subcell = cells[name]
            if subcell == network:
                continue
            elif isinstance(subcell, combined.Conv2d):
                prefix = subcell.param_prefix
                new_subcell = self._convert_conv(subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            elif isinstance(subcell, combined.Dense):
                prefix = subcell.param_prefix
                new_subcell = self._convert_dense(subcell)
                new_subcell.update_parameters_name(prefix + '.')
                network.insert_child_to_cell(name, new_subcell)
                change = True
            else:
                self._convert_subcells2quant(subcell)
        # A SequentialCell keeps its own list of cells; refresh it after replacement.
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

        # Wrap supported primitive ops (e.g. TensorAdd) with a fake-quant cell.
        # Collect first, because the attribute dict is modified afterwards.
        add_list = []
        for name in network.__dict__:
            if name.startswith('_'):
                continue
            attr = network.__dict__[name]
            if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            add_quant = _AddFakeQuantAfterSubCell(prim_op)  # quant.TensorAddQuant()
            prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
            add_quant.update_parameters_name(prefix + '.')
            # Remove the plain attribute so the wrapping cell can take over the name.
            del network.__dict__[name]
            network.insert_child_to_cell(name, add_quant)
        return network

    def _convert_conv(self, subcell):
        """
        Convert a combined conv cell to its quantized form.

        With a batchnorm present and `bn_fold` enabled, conv and batchnorm are
        replaced by one `Conv2dBatchNormQuant`; otherwise the conv is replaced
        by `Conv2dQuant`. The activation is converted when present; without
        one, the whole cell is wrapped so a fake-quant follows its output.
        """
        conv_inner = subcell.conv
        bn_inner = subcell.batchnorm
        if subcell.batchnorm is not None and self.bn_fold:
            conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
                                                    conv_inner.out_channels,
                                                    kernel_size=conv_inner.kernel_size,
                                                    stride=conv_inner.stride,
                                                    pad_mode=conv_inner.pad_mode,
                                                    padding=conv_inner.padding,
                                                    dilation=conv_inner.dilation,
                                                    group=conv_inner.group,
                                                    eps=bn_inner.eps,
                                                    momentum=bn_inner.momentum,
                                                    quant_delay=self.quant_delay,
                                                    freeze_bn=self.freeze_bn,
                                                    per_channel=self.per_channel,
                                                    num_bits=self.weight_bits,
                                                    fake=True,
                                                    symmetric=self.symmetric,
                                                    narrow_range=self.narrow_range)
            # The batchnorm now lives inside the folded cell; disable the standalone one.
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
        else:
            conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
                                           conv_inner.out_channels,
                                           kernel_size=conv_inner.kernel_size,
                                           stride=conv_inner.stride,
                                           pad_mode=conv_inner.pad_mode,
                                           padding=conv_inner.padding,
                                           dilation=conv_inner.dilation,
                                           group=conv_inner.group,
                                           has_bias=conv_inner.has_bias,
                                           quant_delay=self.quant_delay,
                                           per_channel=self.per_channel,
                                           num_bits=self.weight_bits,
                                           symmetric=self.symmetric,
                                           narrow_range=self.narrow_range)
        subcell.conv = conv_inner
        if subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        else:
            subcell = _AddFakeQuantAfterSubCell(subcell)
        return subcell

    def _convert_dense(self, subcell):
        """
        Convert a combined dense cell to its quantized form.

        The inner dense layer is replaced by `DenseQuant` and the activation,
        when present, by its quant counterpart.
        """
        # NOTE(review): unlike _convert_conv, no fake-quant is appended when the
        # cell has no activation - confirm this asymmetry is intended.
        dense_inner = subcell.dense
        dense_inner = quant.DenseQuant(dense_inner.in_channels,
                                       dense_inner.out_channels,
                                       has_bias=dense_inner.has_bias,
                                       quant_delay=self.quant_delay,
                                       per_channel=self.per_channel,
                                       num_bits=self.weight_bits)
        subcell.dense = dense_inner
        if subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
        return subcell

    def _convert_activation(self, activation):
        """
        Return the quant cell matching the class of `activation`.

        Raises:
            ValueError: If the activation class has no entry in `_ACTIVATION_MAP`.
        """
        act_class = activation.__class__
        if act_class not in _ACTIVATION_MAP:
            # Build one message string; passing two args renders as a tuple.
            raise ValueError(
                "Unsupported activation in auto Quant: {}".format(act_class))
        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_quant_network(network,
                          quant_delay=0,
                          bn_fold=False,
                          freeze_bn=0,
                          weight_bits=8,
                          act_bits=8,
                          per_channel=False,
                          symmetric=False,
                          narrow_range=False
                          ):
    r"""
    Create a quantization aware training network.

    Args:
        network (Cell): Network to be converted.
        quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
        bn_fold (bool): Flag to use bn fold ops for simulation inference operation. Default: False.
        freeze_bn (int): Number of steps after which BN parameters use total mean and variance. Default: 0.
        weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
        act_bits (int): Number of bits to use for quantizing activations. Default: 8.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
        symmetric (bool): Quantization algorithm uses symmetric or not. Default: False.
        narrow_range (bool): Quantization algorithm uses narrow range or not. Default: False.

    Returns:
        Cell, the network converted to a quantization aware training network.
    """
    converter = ConvertToQuantNetwork(network,
                                      quant_delay,
                                      bn_fold,
                                      freeze_bn,
                                      weight_bits,
                                      act_bits,
                                      per_channel,
                                      symmetric,
                                      narrow_range)
    converted = converter.run()
    return converted
|
|
@ -0,0 +1,100 @@
|
||||||
|
"""MobileNetV2"""
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(input_x, div_by=8):
    """
    Round `input_x` to the nearest multiple of `div_by`.

    The result is never smaller than `div_by`, so a channel count cannot
    collapse to zero.

    Args:
        input_x (Union[int, float]): Value to round, e.g. a scaled channel count.
        div_by (int): Divisor the result must be a multiple of. Default: 8.

    Returns:
        int, a multiple of `div_by`.
    """
    # Adding half the divisor before flooring rounds to the nearest multiple;
    # multiplying back returns a channel count rather than the bare quotient.
    rounded = int(input_x + div_by / 2) // div_by * div_by
    return max(div_by, rounded)
|
||||||
|
|
||||||
|
|
||||||
|
def _conv_bn(in_channel,
             out_channel,
             ksize,
             stride=1):
    """Build a SequentialCell made of a Conv2d followed by a BatchNorm2d."""
    conv_layer = nn.Conv2d(in_channel, out_channel, kernel_size=ksize, stride=stride)
    bn_layer = nn.BatchNorm2d(out_channel)
    return nn.SequentialCell([conv_layer, bn_layer])
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidual(nn.Cell):
    """
    Block built from an optional 1x1 expansion conv, a 3x3 grouped conv and a 1x1 projection.

    The input is added to the output when `stride` is 1 and `inp == oup`.
    """

    def __init__(self, inp, oup, stride, expend_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expend_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expend_ratio != 1:
            # Expansion stage; skipped when the ratio is 1 (hidden_dim == inp).
            layers.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(),
            ])
        layers.extend([
            # One group per channel for the 3x3 conv.
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(),
            # 1x1 projection, with no activation after its batchnorm.
            nn.Conv2d(hidden_dim, oup, 1, 1),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.SequentialCell(layers)

    def construct(self, input_x):
        result = self.conv(input_x)
        if self.use_res_connect:
            result = input_x + result
        return result
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV2(nn.Cell):
    """
    MobileNetV2 assembled from a stem conv, InvertedResidual blocks, a final 1x1 conv,
    a mean over axes (2, 3) and a Dense classifier.

    `input_size` is accepted but not used by this implementation.
    """

    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
        super().__init__()
        input_channel = 32
        last_channel = 1280
        # Each row: (expansion ratio t, output channels c, repeats n, first stride s).
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 230, 1, 1],
        ]
        if width_mul > 1.0:
            last_channel = make_divisible(last_channel * width_mul)
        self.last_channel = last_channel

        features = [_conv_bn(3, input_channel, 3, 2)]
        for t, c, n, s in inverted_residual_setting:
            out_channel = make_divisible(c * width_mul) if t > 1 else c
            for i in range(n):
                # Only the first block of a stage uses the configured stride.
                block_stride = s if i == 0 else 1
                features.append(InvertedResidual(input_channel, out_channel, block_stride, t))
                input_channel = out_channel
        features.append(_conv_bn(input_channel, self.last_channel, 1))

        self.features = nn.SequentialCell(features)
        self.mean = P.ReduceMean(keep_dims=False)
        self.classifier = nn.Dense(self.last_channel, num_class)

    def construct(self, input_x):
        feature_map = self.features(input_x)
        pooled = self.mean(feature_map, (2, 3))
        logits = self.classifier(pooled)
        return logits
|
|
@ -0,0 +1,108 @@
|
||||||
|
"""mobile net v2"""
|
||||||
|
from mindspore import nn
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.nn.layer import combined
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(input_x, div_by=8):
    """
    Round `input_x` to the nearest multiple of `div_by`.

    The result is never smaller than `div_by`, so a channel count cannot
    collapse to zero.

    Args:
        input_x (Union[int, float]): Value to round, e.g. a scaled channel count.
        div_by (int): Divisor the result must be a multiple of. Default: 8.

    Returns:
        int, a multiple of `div_by`.
    """
    # Adding half the divisor before flooring rounds to the nearest multiple;
    # multiplying back returns a channel count rather than the bare quotient.
    rounded = int(input_x + div_by / 2) // div_by * div_by
    return max(div_by, rounded)
|
||||||
|
|
||||||
|
|
||||||
|
def _conv_bn(in_channel,
             out_channel,
             ksize,
             stride=1):
    """Build a SequentialCell holding one combined Conv2d with batchnorm enabled."""
    conv_bn_layer = combined.Conv2d(in_channel,
                                    out_channel,
                                    kernel_size=ksize,
                                    stride=stride,
                                    batchnorm=True)
    return nn.SequentialCell([conv_bn_layer])
|
||||||
|
|
||||||
|
|
||||||
|
class InvertedResidual(nn.Cell):
    """
    Block built from combined Conv2d cells: an optional 1x1 expansion, a 3x3 grouped
    conv and a 1x1 projection.

    The input is added to the output through a TensorAdd primitive when `stride`
    is 1 and `inp == oup`.
    """

    def __init__(self, inp, oup, stride, expend_ratio):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expend_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expend_ratio != 1:
            # Expansion stage; skipped when the ratio is 1 (hidden_dim == inp).
            layers.append(combined.Conv2d(inp, hidden_dim, 1, 1,
                                          batchnorm=True,
                                          activation='relu6'))
        # One group per channel for the 3x3 conv.
        layers.append(combined.Conv2d(hidden_dim,
                                      hidden_dim,
                                      3,
                                      stride,
                                      group=hidden_dim,
                                      batchnorm=True,
                                      activation='relu6'))
        # 1x1 projection, batchnorm only.
        layers.append(combined.Conv2d(hidden_dim, oup, 1, 1,
                                      batchnorm=True))
        self.conv = nn.SequentialCell(layers)
        self.add = P.TensorAdd()

    def construct(self, input_x):
        result = self.conv(input_x)
        if self.use_res_connect:
            result = self.add(input_x, result)
        return result
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV2(nn.Cell):
    """
    MobileNetV2 assembled from combined cells: a stem conv, InvertedResidual blocks,
    a final 1x1 conv, a mean over axes (2, 3) and a combined Dense classifier.

    `input_size` is accepted but not used by this implementation.
    """

    def __init__(self, num_class=1000, input_size=224, width_mul=1.):
        super().__init__()
        input_channel = 32
        last_channel = 1280
        # Each row: (expansion ratio t, output channels c, repeats n, first stride s).
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 230, 1, 1],
        ]
        if width_mul > 1.0:
            last_channel = make_divisible(last_channel * width_mul)
        self.last_channel = last_channel

        features = [_conv_bn(3, input_channel, 3, 2)]
        for t, c, n, s in inverted_residual_setting:
            out_channel = make_divisible(c * width_mul) if t > 1 else c
            for i in range(n):
                # Only the first block of a stage uses the configured stride.
                block_stride = s if i == 0 else 1
                features.append(InvertedResidual(input_channel, out_channel, block_stride, t))
                input_channel = out_channel
        features.append(_conv_bn(input_channel, self.last_channel, 1))

        self.features = nn.SequentialCell(features)
        self.mean = P.ReduceMean(keep_dims=False)
        self.classifier = combined.Dense(self.last_channel, num_class)

    def construct(self, input_x):
        feature_map = self.features(input_x)
        pooled = self.mean(feature_map, (2, 3))
        logits = self.classifier(pooled)
        return logits
|
|
@ -0,0 +1,94 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
""" tests for quant """
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.train.quant import quant as qat
|
||||||
|
from mindspore import nn
|
||||||
|
import mindspore.ops.operations as P
|
||||||
|
from mindspore.nn.layer import combined
|
||||||
|
import mindspore.context as context
|
||||||
|
from mobilenetv2_combined import MobileNetV2
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
|
|
||||||
|
class LeNet5(nn.Cell):
    """
    Lenet network built from combined Conv2d / Dense cells.

    Args:
        num_class (int): Num classes. Default: 10.

    Returns:
        Tensor, output tensor

    Examples:
        >>> LeNet5(num_class=10)

    """

    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        # conv1 already bundles batchnorm and relu6 inside the combined cell.
        self.conv1 = combined.Conv2d(
            1, 6, kernel_size=5, batchnorm=True, activation='relu6')
        self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu')
        # NOTE(review): 16 * 5 * 5 presumes a particular input size and padding -
        # confirm against the data actually fed to this network.
        self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu')
        self.fc2 = combined.Dense(120, 84, activation='relu')
        self.fc3 = combined.Dense(84, self.num_class)
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flattern = nn.Flatten()

    def construct(self, x):
        # No separate bn/relu step: this class defines no `self.bn` or `self.relu`,
        # and conv1 applies batchnorm and activation itself.
        x = self.conv1(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.max_pool2d(x)
        x = self.flattern(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_qat_lenet():
    """Converting LeNet5 to a quantization aware network must not raise."""
    lenet = LeNet5()
    qat.convert_quant_network(lenet,
                              quant_delay=0,
                              bn_fold=False,
                              freeze_bn=10000,
                              weight_bits=8,
                              act_bits=8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_qat_mobile():
    """Convert MobileNetV2 to a quantization aware network and run one forward pass."""
    mobilenet = MobileNetV2()
    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
    quant_net = qat.convert_quant_network(mobilenet,
                                          quant_delay=0,
                                          bn_fold=False,
                                          freeze_bn=10000,
                                          weight_bits=8,
                                          act_bits=8)
    quant_net(img)
|
||||||
|
|
||||||
|
|
||||||
|
def test_qat_mobile_train():
    """Convert a 10-class MobileNetV2 and run a single training step on it."""
    mobilenet = MobileNetV2(num_class=10)
    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
    label = Tensor(np.ones((1, 10)).astype(np.float32))
    quant_net = qat.convert_quant_network(mobilenet,
                                          quant_delay=0,
                                          bn_fold=False,
                                          freeze_bn=10000,
                                          weight_bits=8,
                                          act_bits=8)

    loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
    optimizer = nn.Momentum(quant_net.trainable_params(),
                            learning_rate=0.1, momentum=0.9)
    # Chain network -> loss -> optimizer step, then execute once.
    loss_net = nn.WithLossCell(quant_net, loss)
    train_net = nn.TrainOneStepCell(loss_net, optimizer)
    train_net(img, label)
|
Loading…
Reference in New Issue