!1084 add new interface quant combined

Merge pull request !1084 from SanjayChan/04quant
This commit is contained in:
mindspore-ci-bot 2020-05-14 11:19:00 +08:00 committed by Gitee
commit 4bb5c7b39a
8 changed files with 775 additions and 2 deletions

View File

@ -97,7 +97,7 @@ class Cell:
After invoked, can get all the cell's children's name prefix by '_param_prefix'. After invoked, can get all the cell's children's name prefix by '_param_prefix'.
""" """
cells = self.cells_and_names cells = self.cells_and_names()
for cell_name, cell in cells: for cell_name, cell in cells:
cell._param_prefix = cell_name cell._param_prefix = cell_name

View File

@ -0,0 +1,182 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Use combination of Conv, Dense, Relu, Batchnorm."""
from .normalization import BatchNorm2d
from .activation import get_activation
from ..cell import Cell
from . import conv, basic
from ..._checkparam import ParamValidator as validator
__all__ = ['Conv2d', 'Dense']
class Conv2d(Cell):
r"""
A combination of convolution, Batchnorm, activation layer.
For a more Detailed overview of Conv2d op.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value if for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be
greater or equal to 1 but bounded by the height and width of the input. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
or equal to 1 and bounded by the height and width of the input. Default: 1.
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape()
(1, 240, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
batchnorm=None,
activation=None):
super(Conv2d, self).__init__()
self.conv = conv.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
self.batchnorm = batchnorm
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)
def construct(self, x):
x = self.conv(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
class Dense(Cell):
r"""
A combination of Dense, Batchnorm, activation layer.
For a more Detailed overview of Dense op.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.Dense(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
batchnorm=None,
activation=None):
super(Dense, self).__init__()
self.dense = basic.Dense(
in_channels,
out_channels,
weight_init,
bias_init,
has_bias)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)
def construct(self, x):
x = self.dense(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x

View File

@ -191,6 +191,8 @@ class Conv2dBatchNormQuant(Cell):
stride, stride,
pad_mode, pad_mode,
padding=0, padding=0,
dilation=1,
group=1,
eps=1e-5, eps=1e-5,
momentum=0.9, momentum=0.9,
weight_init=None, weight_init=None,
@ -198,7 +200,6 @@ class Conv2dBatchNormQuant(Cell):
gamma_init=None, gamma_init=None,
mean_init=None, mean_init=None,
var_init=None, var_init=None,
group=1,
quant_delay=0, quant_delay=0,
freeze_bn=100000, freeze_bn=100000,
fake=True, fake=True,

View File

@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
quantization.
User can use aware quantization to train a model. Mindspore supports quantization aware training,
which models quantization errors in both the forward and backward passes using fake-quantization
ops. Note that the entire computation is carried out in floating point. At the end of quantization
aware training, Mindspore provides conversion functions to convert the trained model into lower precision.
"""
from .quant import convert_quant_network
__all__ = ["convert_quant_network"]

View File

@ -0,0 +1,262 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""aware quantization."""
import re
from ... import nn
from ... import ops
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel
from ...nn.layer import combined
from ...nn.layer import quant
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant,
nn.HSigmoid: quant.HSigmoidQuant,
nn.HSwish: quant.HSwishQuant}
class _AddFakeQuantInputOutput(nn.Cell):
"""
Add FakeQuant at input and output of the Network. Only support one input and one output case.
"""
def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False)
self.network = network
self.fake_quant_input = quant.FakeQuantWithMinMax(
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input')
self.fake_quant_output = quant.FakeQuantWithMinMax(
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_output.update_parameters_name('fake_quant_output')
def construct(self, data):
data = self.fake_quant_input(data)
output = self.network(data)
output = self.fake_quant_output(output)
return output
class _AddFakeQuantAfterSubCell(nn.Cell):
"""
Add FakeQuant after of the sub Cell.
"""
def __init__(self, subcell, quant_delay=0, num_bits=8):
super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
self.subcell = subcell
self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True)
def construct(self, *data):
output = self.subcell(*data)
output = self.fake_quant_act(output)
return output
class ConvertToQuantNetwork:
"""
Convert network to quantization aware network
"""
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self,
network,
quant_delay=0,
bn_fold=False,
freeze_bn=0,
weight_bits=8,
act_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False):
self.network = validator.check_isinstance(
'network', network, (nn.Cell,))
self.quant_delay = validator.check_integer(
"quant delay", quant_delay, 0, Rel.GE)
self.freeze_bn = validator.check_integer(
"freeze bn", freeze_bn, 0, Rel.GE)
self.weight_bits = validator.check_integer(
"weights bit", weight_bits, 0, Rel.GE)
self.act_bits = validator.check_integer(
"activations bit", act_bits, 0, Rel.GE)
self.bn_fold = validator.check_bool("bn fold", bn_fold)
self.per_channel = validator.check_bool("per channel", per_channel)
self.symmetric = validator.check_bool("symmetric", symmetric)
self.narrow_range = validator.check_bool("narrow range", narrow_range)
def _convert_op_name(self, name):
pattern = re.compile(r'([A-Z]{1})')
name_new = re.sub(pattern, r'_\1', name).lower()
if name_new[0] == '_':
name_new = name_new[1:]
return name_new
def run(self):
self.network.update_cell_prefix()
network = self._convert_subcells2quant(self.network)
return network
def _convert_subcells2quant(self, network):
"""
convet sub cell to quant cell
"""
cells = network.name_cells()
change = False
for name in cells:
subcell = cells[name]
if subcell == network:
continue
elif isinstance(subcell, combined.Conv2d):
prefix = subcell.param_prefix
new_subcell = self._convert_conv(subcell)
new_subcell.update_parameters_name(prefix + '.')
network.insert_child_to_cell(name, new_subcell)
change = True
elif isinstance(subcell, combined.Dense):
prefix = subcell.param_prefix
new_subcell = self._convert_dense(subcell)
new_subcell.update_parameters_name(prefix + '.')
network.insert_child_to_cell(name, new_subcell)
change = True
else:
self._convert_subcells2quant(subcell)
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())
# tensoradd to tensoradd quant
add_list = []
for name in network.__dict__:
if name[0] == '_':
continue
attr = network.__dict__[name]
if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
add_list.append((name, attr))
for name, prim_op in add_list:
prefix = name
add_quant = _AddFakeQuantAfterSubCell(prim_op) # quant.TensorAddQuant()
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
add_quant.update_parameters_name(prefix + '.')
del network.__dict__[name]
network.insert_child_to_cell(name, add_quant)
return network
def _convert_conv(self, subcell):
"""
convet conv cell to combine cell
"""
conv_inner = subcell.conv
bn_inner = subcell.batchnorm
if subcell.batchnorm is not None and self.bn_fold:
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
conv_inner.out_channels,
kernel_size=conv_inner.kernel_size,
stride=conv_inner.stride,
pad_mode=conv_inner.pad_mode,
padding=conv_inner.padding,
dilation=conv_inner.dilation,
group=conv_inner.group,
eps=bn_inner.eps,
momentum=bn_inner.momentum,
quant_delay=self.quant_delay,
freeze_bn=self.freeze_bn,
per_channel=self.per_channel,
num_bits=self.weight_bits,
fake=True,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
del subcell.batchnorm
subcell.batchnorm = None
subcell.has_bn = False
else:
conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
conv_inner.out_channels,
kernel_size=conv_inner.kernel_size,
stride=conv_inner.stride,
pad_mode=conv_inner.pad_mode,
padding=conv_inner.padding,
dilation=conv_inner.dilation,
group=conv_inner.group,
has_bias=conv_inner.has_bias,
quant_delay=self.quant_delay,
per_channel=self.per_channel,
num_bits=self.weight_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
subcell.conv = conv_inner
if subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
else:
subcell = _AddFakeQuantAfterSubCell(subcell)
return subcell
def _convert_dense(self, subcell):
"""
convert dense cell to combine dense cell
"""
dense_inner = subcell.dense
dense_inner = quant.DenseQuant(dense_inner.in_channels,
dense_inner.out_channels,
has_bias=dense_inner.has_bias,
quant_delay=self.quant_delay,
per_channel=self.per_channel,
num_bits=self.weight_bits)
subcell.dense = dense_inner
if subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
return subcell
def _convert_activation(self, activation):
act_class = activation.__class__
if act_class not in _ACTIVATION_MAP:
raise ValueError(
"Unsupported activation in auto Quant: ", act_class)
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)
def convert_quant_network(network,
quant_delay=0,
bn_fold=False,
freeze_bn=0,
weight_bits=8,
act_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False
):
r"""
Create aware quantizaiton training network.
Args:
network (Cell): Obtain a pipeline through network for saving graph summary.
quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0.
weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
act_bits (int): Number of bits to use for quantizing activations. Default: 8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
returns:
Cell, Network which has change to aware quantization training network.
"""
net = ConvertToQuantNetwork(
network, quant_delay, bn_fold, freeze_bn, weight_bits, act_bits, per_channel, symmetric, narrow_range)
return net.run()

View File

@ -0,0 +1,100 @@
"""MobileNetV2"""
from mindspore import nn
from mindspore.ops import operations as P
def make_divisible(input_x, div_by=8):
return int((input_x + div_by) // div_by)
def _conv_bn(in_channel,
out_channel,
ksize,
stride=1):
"""Get a conv2d batchnorm and relu layer."""
return nn.SequentialCell(
[nn.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride),
nn.BatchNorm2d(out_channel)])
class InvertedResidual(nn.Cell):
def __init__(self, inp, oup, stride, expend_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expend_ratio)
self.use_res_connect = self.stride == 1 and inp == oup
if expend_ratio == 1:
self.conv = nn.SequentialCell([
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
nn.Conv2d(hidden_dim, oup, 1, 1),
nn.BatchNorm2d(oup)
])
else:
self.conv = nn.SequentialCell([
nn.Conv2d(inp, hidden_dim, 1, 1),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
nn.Conv2d(hidden_dim, oup, 1, 1),
nn.BatchNorm2d(oup)
])
def construct(self, input_x):
out = self.conv(input_x)
if self.use_res_connect:
out = input_x + out
return out
class MobileNetV2(nn.Cell):
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
inverted_residual_setting = [
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 230, 1, 1],
]
if width_mul > 1.0:
last_channel = make_divisible(last_channel * width_mul)
self.last_channel = last_channel
features = [_conv_bn(3, input_channel, 3, 2)]
for t, c, n, s in inverted_residual_setting:
out_channel = make_divisible(c * width_mul) if t > 1 else c
for i in range(n):
if i == 0:
features.append(block(input_channel, out_channel, s, t))
else:
features.append(block(input_channel, out_channel, 1, t))
input_channel = out_channel
features.append(_conv_bn(input_channel, self.last_channel, 1))
self.features = nn.SequentialCell(features)
self.mean = P.ReduceMean(keep_dims=False)
self.classifier = nn.Dense(self.last_channel, num_class)
def construct(self, input_x):
out = input_x
out = self.features(out)
out = self.mean(out, (2, 3))
out = self.classifier(out)
return out

View File

@ -0,0 +1,108 @@
"""mobile net v2"""
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.nn.layer import combined
def make_divisible(input_x, div_by=8):
return int((input_x + div_by) // div_by)
def _conv_bn(in_channel,
out_channel,
ksize,
stride=1):
"""Get a conv2d batchnorm and relu layer."""
return nn.SequentialCell(
[combined.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride,
batchnorm=True)])
class InvertedResidual(nn.Cell):
def __init__(self, inp, oup, stride, expend_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expend_ratio)
self.use_res_connect = self.stride == 1 and inp == oup
if expend_ratio == 1:
self.conv = nn.SequentialCell([
combined.Conv2d(hidden_dim,
hidden_dim,
3,
stride,
group=hidden_dim,
batchnorm=True,
activation='relu6'),
combined.Conv2d(hidden_dim, oup, 1, 1,
batchnorm=True)
])
else:
self.conv = nn.SequentialCell([
combined.Conv2d(inp, hidden_dim, 1, 1,
batchnorm=True,
activation='relu6'),
combined.Conv2d(hidden_dim,
hidden_dim,
3,
stride,
group=hidden_dim,
batchnorm=True,
activation='relu6'),
combined.Conv2d(hidden_dim, oup, 1, 1,
batchnorm=True)
])
self.add = P.TensorAdd()
def construct(self, input_x):
out = self.conv(input_x)
if self.use_res_connect:
out = self.add(input_x, out)
return out
class MobileNetV2(nn.Cell):
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
inverted_residual_setting = [
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 230, 1, 1],
]
if width_mul > 1.0:
last_channel = make_divisible(last_channel * width_mul)
self.last_channel = last_channel
features = [_conv_bn(3, input_channel, 3, 2)]
for t, c, n, s in inverted_residual_setting:
out_channel = make_divisible(c * width_mul) if t > 1 else c
for i in range(n):
if i == 0:
features.append(block(input_channel, out_channel, s, t))
else:
features.append(block(input_channel, out_channel, 1, t))
input_channel = out_channel
features.append(_conv_bn(input_channel, self.last_channel, 1))
self.features = nn.SequentialCell(features)
self.mean = P.ReduceMean(keep_dims=False)
self.classifier = combined.Dense(self.last_channel, num_class)
def construct(self, input_x):
out = input_x
out = self.features(out)
out = self.mean(out, (2, 3))
out = self.classifier(out)
return out

View File

@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" tests for quant """
import numpy as np
from mindspore import Tensor
from mindspore.train.quant import quant as qat
from mindspore import nn
import mindspore.ops.operations as P
from mindspore.nn.layer import combined
import mindspore.context as context
from mobilenetv2_combined import MobileNetV2
context.set_context(mode=context.GRAPH_MODE)
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = combined.Conv2d(
1, 6, kernel_size=5, batchnorm=True, activation='relu6')
self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu')
self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu')
self.fc2 = combined.Dense(120, 84, activation='relu')
self.fc3 = combined.Dense(84, self.num_class)
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flattern = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.bn(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.max_pool2d(x)
x = self.flattern(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
def test_qat_lenet():
net = LeNet5()
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
def test_qat_mobile():
net = MobileNetV2()
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
net(img)
def test_qat_mobile_train():
net = MobileNetV2(num_class=10)
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
label = Tensor(np.ones((1, 10)).astype(np.float32))
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
optimizer = nn.Momentum(net.trainable_params(),
learning_rate=0.1, momentum=0.9)
net = nn.WithLossCell(net, loss)
net = nn.TrainOneStepCell(net, optimizer)
net(img, label)