fix quantization aware training auto create graph bug

This commit is contained in:
chenzomi 2020-06-29 17:59:08 +08:00
parent 6ef1a731db
commit c831d3eb60
26 changed files with 642 additions and 495 deletions

View File

@ -81,6 +81,7 @@ class Cell:
self.enable_hook = False self.enable_hook = False
self._bprop_debug = False self._bprop_debug = False
self._is_run = False self._is_run = False
self.cell_type = None
@property @property
def is_run(self): def is_run(self):
@ -140,6 +141,14 @@ class Cell:
for cell_name, cell in cells_name: for cell_name, cell in cells_name:
cell._param_prefix = cell_name cell._param_prefix = cell_name
def update_cell_type(self, cell_type):
"""
Update current cell type mainly identify if quantization aware training network.
After invoked, can set the cell type to 'cell_type'.
"""
self.cell_type = cell_type
@cell_init_args.setter @cell_init_args.setter
def cell_init_args(self, value): def cell_init_args(self, value):
if not isinstance(value, str): if not isinstance(value, str):

View File

@ -17,6 +17,7 @@ from mindspore import log as logger
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore._checkparam import ParamValidator as validator, Rel
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
from mindspore._extends import cell_attr_register from mindspore._extends import cell_attr_register
from ..cell import Cell from ..cell import Cell
@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv):
self.weight, self.weight,
self.bias) self.bias)
return s return s
class DepthwiseConv2d(Cell):
r"""
2D depthwise convolution layer.
Applies a 2D depthwise convolution over an input tensor which is typically of shape:
math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
For each batch of shape:math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
.. math::
out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
:math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape
:math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
to split the input in the channel dimension.
If the 'pad_mode' is set to be "valid", the output height and width will be
:math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
(\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value if for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
- same: Adopts the way of completion. Output height and width will be the same as the input.
Total number of padding will be calculated for horizontal and vertical
direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the
last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
must be 0.
- valid: Adopts the way of discarding. The possibly largest height and width of output will be return
without padding. Extra pixels will be discarded. If this mode is set, `padding`
must be 0.
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
Tensor borders. `padding` should be greater than or equal to 0.
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
to use for dilated convolution. If set to be :math:`k > 1`, there will
be :math:`k - 1` pixels skipped for each sampling location. Its value should
be greater or equal to 1 and bounded by the height and width of the
input. Default: 1.
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = nn.DepthwiseConv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros'):
super(DepthwiseConv2d, self).__init__()
self.kernel_size = twice(kernel_size)
self.stride = twice(stride)
self.dilation = twice(dilation)
self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels)
validator.check_integer('group', group, in_channels, Rel.EQ)
validator.check_integer('group', group, out_channels, Rel.EQ)
validator.check_integer('group', group, 1, Rel.GE)
self.pad_mode = pad_mode
self.padding = padding
self.dilation = dilation
self.group = group
self.has_bias = has_bias
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation)
self.bias_add = P.BiasAdd()
weight_shape = [1, in_channels, *self.kernel_size]
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
if check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
if bias_init != 'zeros':
logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.")
self.bias = None
def construct(self, x):
out = self.conv(x, self.weight)
if self.has_bias:
out = self.bias_add(out, self.bias)
return out
def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
'pad_mode={}, padding={}, dilation={}, group={},' \
'has_bias={}, weight_init={}, bias_init={}'.format(
self.in_channels, self.out_channels, self.kernel_size, self.stride,
self.pad_mode, self.padding, self.dilation, self.group,
self.has_bias, self.weight_init, self.bias_init)
if self.has_bias:
s += ', bias={}'.format(self.bias)
return s

View File

@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Aware quantization.""" """Quantization aware."""
from functools import partial from functools import partial
import numpy as np import numpy as np
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops import functional as F from mindspore.ops import functional as F
@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
from mindspore._checkparam import check_int_positive, check_bool, twice from mindspore._checkparam import check_int_positive, check_bool, twice
from mindspore._checkparam import Validator as validator, Rel from mindspore._checkparam import Rel
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
import mindspore.context as context import mindspore.context as context
from .normalization import BatchNorm2d from .normalization import BatchNorm2d
from .activation import get_activation from .activation import get_activation
from ..cell import Cell from ..cell import Cell
@ -82,7 +82,7 @@ class Conv2dBnAct(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'. Initializer for more details. Default: 'zeros'.
batchnorm (bool): Specifies to used batchnorm or not. Default: None. has_bn (bool): Specifies to used batchnorm or not. Default: False.
activation (string): Specifies activation type. The optional values are as following: activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@ -94,7 +94,7 @@ class Conv2dBnAct(Cell):
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples: Examples:
>>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU') >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32) >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape >>> net(input).shape
(1, 240, 1024, 640) (1, 240, 1024, 640)
@ -112,28 +112,39 @@ class Conv2dBnAct(Cell):
has_bias=False, has_bias=False,
weight_init='normal', weight_init='normal',
bias_init='zeros', bias_init='zeros',
batchnorm=None, has_bn=False,
activation=None): activation=None):
super(Conv2dBnAct, self).__init__() super(Conv2dBnAct, self).__init__()
self.conv = conv.Conv2d(
in_channels, if context.get_context('device_target') == "Ascend" and group > 1:
self.conv = conv.DepthwiseConv2d(in_channels,
out_channels, out_channels,
kernel_size, kernel_size=kernel_size,
stride, stride=stride,
pad_mode, pad_mode=pad_mode,
padding, padding=padding,
dilation, dilation=dilation,
group, group=group,
has_bias, has_bias=has_bias,
weight_init, weight_init=weight_init,
bias_init) bias_init=bias_init)
self.has_bn = batchnorm is not None else:
self.conv = conv.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init)
self.has_bn = has_bn
self.has_act = activation is not None self.has_act = activation is not None
self.batchnorm = batchnorm if has_bn:
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels) self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation) self.activation = get_activation(activation)
def construct(self, x): def construct(self, x):
@ -160,7 +171,7 @@ class DenseBnAct(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
batchnorm (bool): Specifies to used batchnorm or not. Default: None. has_bn (bool): Specifies to used batchnorm or not. Default: False.
activation (string): Specifies activation type. The optional values are as following: activation (string): Specifies activation type. The optional values are as following:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid', 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None. 'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@ -172,7 +183,7 @@ class DenseBnAct(Cell):
Tensor of shape :math:`(N, out\_channels)`. Tensor of shape :math:`(N, out\_channels)`.
Examples: Examples:
>>> net = nn.Dense(3, 4) >>> net = nn.DenseBnAct(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input) >>> net(input)
""" """
@ -183,7 +194,7 @@ class DenseBnAct(Cell):
weight_init='normal', weight_init='normal',
bias_init='zeros', bias_init='zeros',
has_bias=True, has_bias=True,
batchnorm=None, has_bn=False,
activation=None): activation=None):
super(DenseBnAct, self).__init__() super(DenseBnAct, self).__init__()
self.dense = basic.Dense( self.dense = basic.Dense(
@ -192,12 +203,10 @@ class DenseBnAct(Cell):
weight_init, weight_init,
bias_init, bias_init,
has_bias) has_bias)
self.has_bn = batchnorm is not None self.has_bn = has_bn
self.has_act = activation is not None self.has_act = activation is not None
if batchnorm is True: if has_bn:
self.batchnorm = BatchNorm2d(out_channels) self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation) self.activation = get_activation(activation)
def construct(self, x): def construct(self, x):
@ -271,20 +280,20 @@ class BatchNormFoldCell(Cell):
class FakeQuantWithMinMax(Cell): class FakeQuantWithMinMax(Cell):
r""" r"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. Quantization aware op. This OP provide Fake quantization observer function on data with min and max.
Args: Args:
min_init (int, float): The dimension of channel or 1(layer). Default: -6. min_init (int, float): The dimension of channel or 1(layer). Default: -6.
max_init (int, float): The dimension of channel or 1(layer). Default: 6. max_init (int, float): The dimension of channel or 1(layer). Default: 6.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False. ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1. channel_axis (int): Quantization by channel axis. Default: 1.
out_channels (int): declarate the min and max channel size, Default: 1. num_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - The input of FakeQuantWithMinMax. - **x** (Tensor) - The input of FakeQuantWithMinMax.
@ -301,24 +310,27 @@ class FakeQuantWithMinMax(Cell):
def __init__(self, def __init__(self,
min_init=-6, min_init=-6,
max_init=6, max_init=6,
num_bits=8,
ema=False, ema=False,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
channel_axis=1, channel_axis=1,
out_channels=1, num_channels=1,
quant_delay=0, num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
"""init FakeQuantWithMinMax layer""" """init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__() super(FakeQuantWithMinMax, self).__init__()
validator.check_type("min_init", min_init, [int, float])
validator.check_type("max_init", max_init, [int, float])
validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
self.min_init = min_init self.min_init = min_init
self.max_init = max_init self.max_init = max_init
self.num_bits = num_bits self.num_bits = num_bits
self.ema = ema self.ema = ema
self.ema_decay = ema_decay self.ema_decay = ema_decay
self.per_channel = per_channel self.per_channel = per_channel
self.out_channels = out_channels self.num_channels = num_channels
self.channel_axis = channel_axis self.channel_axis = channel_axis
self.quant_delay = quant_delay self.quant_delay = quant_delay
self.symmetric = symmetric self.symmetric = symmetric
@ -327,54 +339,54 @@ class FakeQuantWithMinMax(Cell):
# init tensor min and max for fake quant op # init tensor min and max for fake quant op
if self.per_channel: if self.per_channel:
min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) min_array = np.array([self.min_init] * self.num_channels).astype(np.float32)
max_array = np.array([self.max_init for i in range(0, self.out_channels)]).astype(np.float32) max_array = np.array([self.max_init] * self.num_channels).astype(np.float32)
else: else:
min_array = np.array([self.min_init]).reshape(1).astype(np.float32) min_array = np.array([self.min_init]).astype(np.float32)
max_array = np.array([self.max_init]).reshape(1).astype(np.float32) max_array = np.array([self.max_init]).astype(np.float32)
self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False) self.minq = Parameter(Tensor(min_array), name='quant_min', requires_grad=False)
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False) self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
# init fake quant relative op # init fake quant relative op
if per_channel: if per_channel:
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis) quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
ema_fun = partial(Q.FakeQuantMinMaxPerChannelUpdate, channel_axis=self.channel_axis) ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
else: else:
quant_fun = Q.FakeQuantPerLayer quant_fun = Q.FakeQuantPerLayer
ema_fun = Q.FakeQuantMinMaxPerLayerUpdate ema_fun = Q.MinMaxUpdatePerLayer
self.ema_update = ema_fun(ema=self.ema, ema_decay=self.ema_decay)
if self.is_ascend: if self.is_ascend:
self.fake_quant = quant_fun(num_bits=self.num_bits, self.fake_quant_train = quant_fun(num_bits=self.num_bits,
symmetric=self.symmetric, symmetric=self.symmetric,
narrow_range=self.narrow_range) narrow_range=self.narrow_range)
self.fake_quant_infer = self.fake_quant_train
else: else:
self.fake_quant = quant_fun(num_bits=self.num_bits, quant_fun = partial(quant_fun,
ema=self.ema, ema=self.ema,
ema_decay=ema_decay, ema_decay=ema_decay,
quant_delay=quant_delay, num_bits=self.num_bits,
symmetric=self.symmetric, symmetric=self.symmetric,
narrow_range=self.narrow_range) narrow_range=self.narrow_range,
self.ema_update = ema_fun(num_bits=self.num_bits, quant_delay=quant_delay)
ema=self.ema, self.fake_quant_train = quant_fun(training=True)
ema_decay=self.ema_decay, self.fake_quant_infer = quant_fun(training=False)
symmetric=self.symmetric,
narrow_range=self.narrow_range)
def extend_repr(self): def extend_repr(self):
s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \ s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
'quant_delay={}, min_init={}, max_init={}'.format( 'quant_delay={}, min_init={}, max_init={}'.format(
self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel, self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
self.channel_axis, self.out_channels, self.quant_delay, self.min_init, self.max_init) self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
return s return s
def construct(self, x): def construct(self, x):
if self.is_ascend and self.training: if self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq) min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up) P.Assign()(self.minq, min_up)
P.Assign()(self.maxq, max_up) P.Assign()(self.maxq, max_up)
out = self.fake_quant_train(x, self.minq, self.maxq)
else: else:
out = self.fake_quant(x, self.minq, self.maxq) out = self.fake_quant_infer(x, self.minq, self.maxq)
return out return out
@ -391,8 +403,8 @@ class Conv2dBatchNormQuant(Cell):
stride (int): Specifies stride for all spatial dimensions with the same value. stride (int): Specifies stride for all spatial dimensions with the same value.
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0. padding: (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameters for BatchNormal. Default: 1e-5. eps (float): Parameters for BatchNormal. Default: 1e-5.
momentum (int): Parameters for BatchNormal op. Default: 0.997. momentum (float): Parameters for BatchNormal op. Default: 0.997.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'normal'. convolution kernel. Default: 'normal'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
@ -403,13 +415,13 @@ class Conv2dBatchNormQuant(Cell):
mean vector. Default: 'zeros'. mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'ones'. variance vector. Default: 'ones'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
Inputs: Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -440,13 +452,13 @@ class Conv2dBatchNormQuant(Cell):
gamma_init='ones', gamma_init='ones',
mean_init='zeros', mean_init='zeros',
var_init='ones', var_init='ones',
quant_delay=0,
freeze_bn=100000,
fake=True, fake=True,
num_bits=8,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0,
freeze_bn=100000):
"""init Conv2dBatchNormQuant layer""" """init Conv2dBatchNormQuant layer"""
super(Conv2dBatchNormQuant, self).__init__() super(Conv2dBatchNormQuant, self).__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -503,12 +515,13 @@ class Conv2dBatchNormQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
ema=False, ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel, per_channel=per_channel,
out_channels=out_channels, channel_axis=channel_axis,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
self.correct_mul = Q.CorrectionMul(channel_axis) self.correct_mul = Q.CorrectionMul(channel_axis)
if context.get_context('device_target') == "Ascend": if context.get_context('device_target') == "Ascend":
@ -582,11 +595,11 @@ class Conv2dQuant(Cell):
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: 'normal'. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -613,11 +626,11 @@ class Conv2dQuant(Cell):
has_bias=False, has_bias=False,
weight_init='normal', weight_init='normal',
bias_init='zeros', bias_init='zeros',
quant_delay=0,
num_bits=8,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(Conv2dQuant, self).__init__() super(Conv2dQuant, self).__init__()
if isinstance(kernel_size, int): if isinstance(kernel_size, int):
self.kernel_size = (kernel_size, kernel_size) self.kernel_size = (kernel_size, kernel_size)
@ -653,12 +666,13 @@ class Conv2dQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
ema=False, ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel, per_channel=per_channel,
out_channels=out_channels, channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
def construct(self, x): def construct(self, x):
weight = self.fake_quant_weight(self.weight) weight = self.fake_quant_weight(self.weight)
@ -692,11 +706,11 @@ class DenseQuant(Cell):
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
@ -718,19 +732,19 @@ class DenseQuant(Cell):
bias_init='zeros', bias_init='zeros',
has_bias=True, has_bias=True,
activation=None, activation=None,
num_bits=8,
quant_delay=0,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(DenseQuant, self).__init__() super(DenseQuant, self).__init__()
self.in_channels = check_int_positive(in_channels) self.in_channels = check_int_positive(in_channels)
self.out_channels = check_int_positive(out_channels) self.out_channels = check_int_positive(out_channels)
self.has_bias = check_bool(has_bias) self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape[1] != in_channels: weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error") raise ValueError("weight_init shape error")
self.weight = Parameter(initializer( self.weight = Parameter(initializer(
@ -738,7 +752,7 @@ class DenseQuant(Cell):
if self.has_bias: if self.has_bias:
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer( self.bias = Parameter(initializer(
@ -752,12 +766,13 @@ class DenseQuant(Cell):
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
ema=False, ema=False,
num_bits=num_bits,
quant_delay=quant_delay,
per_channel=per_channel, per_channel=per_channel,
out_channels=out_channels, channel_axis=0,
num_channels=out_channels,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
def construct(self, x): def construct(self, x):
"""Use operators to construct to Dense layer.""" """Use operators to construct to Dense layer."""
@ -780,13 +795,16 @@ class DenseQuant(Cell):
return str_info return str_info
class _QuantActivation(Cell): class _QuantActivation(Cell):
r""" r"""
Base class for Quant activation function. Add Fake Quant OP after activation OP. Base class for Quant activation function. Add Fake Quant OP after activation OP.
""" """
def get_origin(self): def get_origin(self):
raise NotImplementedError raise NotImplementedError
class ReLUQuant(_QuantActivation): class ReLUQuant(_QuantActivation):
r""" r"""
ReLUQuant activation function. Add Fake Quant OP after Relu OP. ReLUQuant activation function. Add Fake Quant OP after Relu OP.
@ -794,12 +812,12 @@ class ReLUQuant(_QuantActivation):
For a more Detailed overview of ReLU op. For a more Detailed overview of ReLU op.
Args: Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - The input of ReLUQuant. - **x** (Tensor) - The input of ReLUQuant.
@ -814,22 +832,22 @@ class ReLUQuant(_QuantActivation):
""" """
def __init__(self, def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(ReLUQuant, self).__init__() super(ReLUQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0, self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
per_channel=per_channel,
ema_decay=ema_decay, ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu = P.ReLU() self.relu = P.ReLU()
def construct(self, x): def construct(self, x):
@ -850,12 +868,12 @@ class ReLU6Quant(_QuantActivation):
For a more Detailed overview of ReLU6 op. For a more Detailed overview of ReLU6 op.
Args: Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - The input of ReLU6Quant. - **x** (Tensor) - The input of ReLU6Quant.
@ -870,22 +888,22 @@ class ReLU6Quant(_QuantActivation):
""" """
def __init__(self, def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(ReLU6Quant, self).__init__() super(ReLU6Quant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=0, self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
per_channel=per_channel,
ema_decay=ema_decay, ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.relu6 = P.ReLU6() self.relu6 = P.ReLU6()
def construct(self, x): def construct(self, x):
@ -896,6 +914,7 @@ class ReLU6Quant(_QuantActivation):
def get_origin(self): def get_origin(self):
return self.relu6 return self.relu6
class HSwishQuant(_QuantActivation): class HSwishQuant(_QuantActivation):
r""" r"""
HSwishQuant activation function. Add Fake Quant OP after HSwish OP. HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
@ -903,12 +922,12 @@ class HSwishQuant(_QuantActivation):
For a more Detailed overview of HSwish op. For a more Detailed overview of HSwish op.
Args: Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - The input of HSwishQuant. - **x** (Tensor) - The input of HSwishQuant.
@ -923,31 +942,31 @@ class HSwishQuant(_QuantActivation):
""" """
def __init__(self, def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(HSwishQuant, self).__init__() super(HSwishQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
per_channel=per_channel,
ema_decay=ema_decay, ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
per_channel=per_channel,
ema_decay=ema_decay, ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSwish() self.act = P.HSwish()
def construct(self, x): def construct(self, x):
@ -959,6 +978,7 @@ class HSwishQuant(_QuantActivation):
def get_origin(self): def get_origin(self):
return self.act return self.act
class HSigmoidQuant(_QuantActivation): class HSigmoidQuant(_QuantActivation):
r""" r"""
HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP. HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
@ -966,12 +986,12 @@ class HSigmoidQuant(_QuantActivation):
For a more Detailed overview of HSigmoid op. For a more Detailed overview of HSigmoid op.
Args: Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - The input of HSigmoidQuant. - **x** (Tensor) - The input of HSigmoidQuant.
@ -986,30 +1006,31 @@ class HSigmoidQuant(_QuantActivation):
""" """
def __init__(self, def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(HSigmoidQuant, self).__init__() super(HSigmoidQuant, self).__init__()
self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6, self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
ema_decay=ema_decay,
per_channel=per_channel, per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6, self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
per_channel=per_channel,
ema_decay=ema_decay, ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.act = P.HSigmoid() self.act = P.HSigmoid()
def construct(self, x): def construct(self, x):
@ -1021,6 +1042,7 @@ class HSigmoidQuant(_QuantActivation):
def get_origin(self): def get_origin(self):
return self.act return self.act
class TensorAddQuant(Cell): class TensorAddQuant(Cell):
r""" r"""
Add Fake Quant OP after TensorAdd OP. Add Fake Quant OP after TensorAdd OP.
@ -1028,12 +1050,12 @@ class TensorAddQuant(Cell):
For a more Detailed overview of TensorAdd op. For a more Detailed overview of TensorAdd op.
Args: Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - The input of TensorAddQuant. - **x** (Tensor) - The input of TensorAddQuant.
@ -1049,22 +1071,22 @@ class TensorAddQuant(Cell):
""" """
def __init__(self, def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(TensorAddQuant, self).__init__() super(TensorAddQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
per_channel=per_channel,
ema_decay=ema_decay, ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.add = P.TensorAdd() self.add = P.TensorAdd()
def construct(self, x1, x2): def construct(self, x1, x2):
@ -1080,12 +1102,12 @@ class MulQuant(Cell):
For a more Detailed overview of Mul op. For a more Detailed overview of Mul op.
Args: Args:
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False. per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs: Inputs:
- **x** (Tensor) - The input of MulQuant. - **x** (Tensor) - The input of MulQuant.
@ -1096,22 +1118,22 @@ class MulQuant(Cell):
""" """
def __init__(self, def __init__(self,
num_bits=8,
quant_delay=0,
ema_decay=0.999, ema_decay=0.999,
per_channel=False, per_channel=False,
num_bits=8,
symmetric=False, symmetric=False,
narrow_range=False): narrow_range=False,
quant_delay=0):
super(MulQuant, self).__init__() super(MulQuant, self).__init__()
self.fake_quant_act = FakeQuantWithMinMax(min_init=-6, self.fake_quant_act = FakeQuantWithMinMax(min_init=-6,
max_init=6, max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True, ema=True,
per_channel=per_channel,
ema_decay=ema_decay, ema_decay=ema_decay,
per_channel=per_channel,
num_bits=num_bits,
symmetric=symmetric, symmetric=symmetric,
narrow_range=narrow_range) narrow_range=narrow_range,
quant_delay=quant_delay)
self.mul = P.Mul() self.mul = P.Mul()
def construct(self, x1, x2): def construct(self, x1, x2):
@ -1173,12 +1195,13 @@ class QuantBlock(Cell):
self.has_bias = bias is None self.has_bias = bias is None
self.activation = activation self.activation = activation
self.has_act = activation is None self.has_act = activation is None
self.bias_add = P.BiasAdd()
def construct(self, x): def construct(self, x):
x = self.quant(x) x = self.quant(x)
x = self.core_op(x, self.weight) x = self.core_op(x, self.weight)
if self.has_bias: if self.has_bias:
output = self.bias_add(output, self.bias) x = self.bias_add(x, self.bias)
if self.has_act: if self.has_act:
x = self.activation(x) x = self.activation(x)
x = self.dequant(x, self.dequant_scale) x = self.dequant(x, self.dequant_scale)

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Generate bprop for aware quantization ops""" """Generate bprop for quantization aware ops"""
from .. import operations as P from .. import operations as P
from ..operations import _quant_ops as Q from ..operations import _quant_ops as Q
@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
return bprop return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerLayerUpdate) @bprop_getters.register(Q.MinMaxUpdatePerLayer)
def get_bprop_fakequant_with_minmax_per_layer_update(self): def get_bprop_fakequant_with_minmax_per_layer_update(self):
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend""" """Generate bprop for MinMaxUpdatePerLayer for Ascend"""
def bprop(x, x_min, x_max, out, dout): def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max) return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
return bprop return bprop
@bprop_getters.register(Q.FakeQuantMinMaxPerChannelUpdate) @bprop_getters.register(Q.MinMaxUpdatePerChannel)
def get_bprop_fakequant_with_minmax_per_channel_update(self): def get_bprop_fakequant_with_minmax_per_channel_update(self):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend""" """Generate bprop for MinMaxUpdatePerChannel for Ascend"""
def bprop(x, x_min, x_max, out, dout): def bprop(x, x_min, x_max, out, dout):
return zeros_like(x), zeros_like(x_min), zeros_like(x_max) return zeros_like(x), zeros_like(x_min), zeros_like(x_max)

View File

@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("batchnorm_fold2") \ .kernel_name("batchnorm_fold2") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "beta", None, "required", None) \ .input(1, "beta", None, "required", None) \
.input(2, "gamma", None, "required", None) \ .input(2, "gamma", None, "required", None) \

View File

@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("batchnorm_fold2_grad") \ .kernel_name("batchnorm_fold2_grad") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.input(1, "dout_reduce", None, "required", None) \ .input(1, "dout_reduce", None, "required", None) \
.input(2, "dout_x_reduce", None, "required", None) \ .input(2, "dout_x_reduce", None, "required", None) \

View File

@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("batchnorm_fold2_grad_reduce") \ .kernel_name("batchnorm_fold2_grad_reduce") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \ .input(1, "x", None, "required", None) \
.output(0, "dout_reduce", True, "required", "all") \ .output(0, "dout_reduce", True, "required", "all") \

View File

@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("correction_mul") \ .kernel_name("correction_mul") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "batch_std", None, "required", None) \ .input(1, "batch_std", None, "required", None) \

View File

@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("correction_mul_grad") \ .kernel_name("correction_mul_grad") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.input(1, "x", None, "required", None) \ .input(1, "x", None, "required", None) \
@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("correction_mul_grad_reduce") \ .kernel_name("correction_mul_grad_reduce") \
.partial_flag(True) \ .partial_flag(True) \
.op_pattern("formatAgnostic") \
.attr("channel_axis", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \
.input(0, "dout", None, "required", None) \ .input(0, "dout", None, "required", None) \
.output(0, "d_batch_std", True, "required", "all") \ .output(0, "d_batch_std", True, "required", "all") \

View File

@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
min_dtype = min_val.get("dtype") min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape") max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype") max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape) util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape) util.check_tensor_shape_size(max_shape)
@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
quant_min = quant_min + 1 quant_min = quant_min + 1
shape_c = [1] * len(x_shape) shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0] shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1: if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape") shape_c = min_val.get("shape")
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)

View File

@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
min_dtype = min_val.get("dtype") min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape") max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype") max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape) util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis]) util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape) util.check_tensor_shape_size(max_shape)
@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
quant_min = quant_min + 1 quant_min = quant_min + 1
shape_c = [1] * len(x_shape) shape_c = [1] * len(x_shape)
shape_c[channel_axis] = min_val.get("ori_shape")[0] shape_c[channel_axis_] = min_val.get("ori_shape")[0]
if x_format == "NC1HWC0" and channel_axis == 1: if x_format == "NC1HWC0" and channel_axis_ == 1:
shape_c = min_val.get("shape") shape_c = min_val.get("shape")
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype) dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)

View File

@ -1,4 +1,3 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -14,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op""" """MinMaxUpdatePerChannel op"""
import te.lang.cce import te.lang.cce
from te import tvm from te import tvm
from te.platform.fusion_manager import fusion_manager from te.platform.fusion_manager import fusion_manager
@ -22,20 +21,15 @@ from topi import generic
from topi.cce import util from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
minmax_update_perchannel_op_info = TBERegOp("MinMaxUpdatePerChannel") \
fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChannelUpdate") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("fake_quant_min_max_per_channel_update.so") \ .binfile_name("minmax_update_perchannel.so") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("fake_quant_min_max_per_channel_update") \ .kernel_name("minmax_update_perchannel") \
.partial_flag(True) \ .partial_flag(True) \
.attr("ema", "optional", "bool", "all") \ .attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \ .attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.attr("channel_axis", "optional", "int", "all") \ .attr("channel_axis", "optional", "int", "all") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \ .input(1, "min", None, "required", None) \
@ -47,24 +41,27 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan
.get_op_info() .get_op_info()
@op_info_register(fake_quant_min_max_per_channel_update_op_info) @op_info_register(minmax_update_perchannel_op_info)
def _fake_quant_min_max_per_channel_update_tbe(): def _minmax_update_perchannel_tbe():
"""FakeQuantPerChannelUpdate TBE register""" """MinMaxUpdatePerChannel TBE register"""
return return
@fusion_manager.register("fake_quant_min_max_per_channel_update") @fusion_manager.register("minmax_update_perchannel")
def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val, def minmax_update_perchannel_compute(x, min_val, max_val,
ema, ema_decay, quant_min, quant_max, training, channel_axis, ema, ema_decay, channel_axis):
kernel_name="fake_quant_min_max_per_channel_update"): """MinMaxUpdatePerChannel compute"""
"""FakeQuantPerChannelUpdate compute"""
shape_min = te.lang.cce.util.shape_to_list(min_val.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
if not ema: if not ema:
ema_decay = 0.0 ema_decay = 0.0
if training:
# CalMinMax # CalMinMax
if channel_axis == 0:
axis = [1, 2, 3, 4]
else:
axis = [0, 2, 3] axis = [0, 2, 3]
x_min = te.lang.cce.reduce_min(x, axis=axis) x_min = te.lang.cce.reduce_min(x, axis=axis)
x_max = te.lang.cce.reduce_max(x, axis=axis) x_max = te.lang.cce.reduce_max(x, axis=axis)
x_min = te.lang.cce.broadcast(x_min, shape_min) x_min = te.lang.cce.broadcast(x_min, shape_min)
@ -79,11 +76,11 @@ def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
return [min_val, max_val] return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str) @util.check_input_type(dict, dict, dict, dict, dict, bool, float, int, str)
def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up, def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, channel_axis, ema, ema_decay, channel_axis,
kernel_name="fake_quant_min_max_per_channel_update"): kernel_name="minmax_update_perchannel"):
"""FakeQuantPerLayer op""" """MinMaxUpdatePerChannel op"""
x_shape = x.get("ori_shape") x_shape = x.get("ori_shape")
x_format = x.get("format") x_format = x.get("format")
x_dtype = x.get("dtype") x_dtype = x.get("dtype")
@ -91,11 +88,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
min_dtype = min_val.get("dtype") min_dtype = min_val.get("dtype")
max_shape = max_val.get("ori_shape") max_shape = max_val.get("ori_shape")
max_dtype = max_val.get("dtype") max_dtype = max_val.get("dtype")
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
channel_axis_ = 1
else:
channel_axis_ = channel_axis
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
util.check_shape_rule(x_shape) util.check_shape_rule(x_shape)
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis]) util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis]) util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
util.check_tensor_shape_size(x_shape) util.check_tensor_shape_size(x_shape)
util.check_tensor_shape_size(min_shape) util.check_tensor_shape_size(min_shape)
util.check_tensor_shape_size(max_shape) util.check_tensor_shape_size(max_shape)
@ -108,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
util.check_dtype_rule(min_dtype, check_list) util.check_dtype_rule(min_dtype, check_list)
util.check_dtype_rule(max_dtype, check_list) util.check_dtype_rule(max_dtype, check_list)
if symmetric: if channel_axis_ == 0:
quant_min = 0 - 2 ** (num_bits - 1) shape_c = min_val.get("ori_shape")
quant_max = 2 ** (num_bits - 1) - 1
else: else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]] shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype) input_data = tvm.placeholder(x.get("shape"), name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype) min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype) max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data, res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name) ema, ema_decay, channel_axis_)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res_list) sch = generic.auto_schedule(res_list)

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""FakeQuantMinMaxPerLayerUpdate op""" """MinMaxUpdatePerLayer op"""
from functools import reduce as functools_reduce from functools import reduce as functools_reduce
import te.lang.cce import te.lang.cce
from te import tvm from te import tvm
@ -22,20 +22,15 @@ from topi import generic
from topi.cce import util from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
minmax_update_perlayer_op_info = TBERegOp("MinMaxUpdatePerLayer") \
fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.fusion_type("OPAQUE") \ .fusion_type("OPAQUE") \
.async_flag(False) \ .async_flag(False) \
.binfile_name("fake_quant_minmax_update.so") \ .binfile_name("minmax_update_perlayer.so") \
.compute_cost(10) \ .compute_cost(10) \
.kernel_name("fake_quant_minmax_update") \ .kernel_name("minmax_update_perlayer") \
.partial_flag(True) \ .partial_flag(True) \
.attr("ema", "optional", "bool", "all") \ .attr("ema", "optional", "bool", "all") \
.attr("ema_decay", "optional", "float", "all") \ .attr("ema_decay", "optional", "float", "all") \
.attr("symmetric", "optional", "bool", "all") \
.attr("narrow_range", "optional", "bool", "all") \
.attr("training", "optional", "bool", "all") \
.attr("num_bits", "optional", "int", "all") \
.input(0, "x", None, "required", None) \ .input(0, "x", None, "required", None) \
.input(1, "min", None, "required", None) \ .input(1, "min", None, "required", None) \
.input(2, "max", None, "required", None) \ .input(2, "max", None, "required", None) \
@ -46,23 +41,22 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.get_op_info() .get_op_info()
@op_info_register(fake_quant_minmax_update_op_info) @op_info_register(minmax_update_perlayer_op_info)
def _fake_quant_minmax_update_tbe(): def _minmax_update_perlayer_tbe():
"""FakeQuantMinMaxPerLayerUpdate TBE register""" """MinMaxUpdatePerLayer TBE register"""
return return
@fusion_manager.register("fake_quant_minmax_update") @fusion_manager.register("minmax_update_perlayer")
def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training, def minmax_update_perlayer_compute(x, min_val, max_val, ema, ema_decay):
kernel_name="fake_quant_minmax_update"): """MinMaxUpdatePerLayer compute"""
"""FakeQuantMinMaxPerLayerUpdate compute"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)
if not ema: if not ema:
ema_decay = 0.0 ema_decay = 0.0
if training:
# CalMinMax # CalMinMax
axis = tuple(range(len(shape))) axis = tuple(range(len(shape)))
x_min = te.lang.cce.reduce_min(x, axis=axis) x_min = te.lang.cce.reduce_min(x, axis=axis)
@ -79,11 +73,10 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_
return [min_val, max_val] return [min_val, max_val]
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, str) @util.check_input_type(dict, dict, dict, dict, dict, bool, float, str)
def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up, def minmax_update_perlayer(x, min_val, max_val, min_up, max_up,
ema, ema_decay, symmetric, narrow_range, training, num_bits, ema, ema_decay, kernel_name="minmax_update_perlayer"):
kernel_name="fake_quant_minmax_update"): """MinMaxUpdatePerLayer op"""
"""FakeQuantPerLayer op"""
input_shape = x.get("shape") input_shape = x.get("shape")
input_dtype = x.get("dtype") input_dtype = x.get("dtype")
min_shape = min_val.get("ori_shape") min_shape = min_val.get("ori_shape")
@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape) shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0
quant_max = 2 ** num_bits - 1
if narrow_range:
quant_min = quant_min + 1
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res_list = fake_quant_minmax_update_compute(input_data, min_data, max_data, res_list = minmax_update_perlayer_compute(input_data, min_data, max_data, ema, ema_decay)
ema, ema_decay, quant_min, quant_max, training, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res_list) sch = generic.auto_schedule(res_list)

View File

@ -21,12 +21,12 @@ from ..._checkparam import Rel
from ..primitive import PrimitiveWithInfer, prim_attr_register from ..primitive import PrimitiveWithInfer, prim_attr_register
from ...common import dtype as mstype from ...common import dtype as mstype
__all__ = ["FakeQuantPerLayer", __all__ = ["MinMaxUpdatePerLayer",
"MinMaxUpdatePerChannel",
"FakeQuantPerLayer",
"FakeQuantPerLayerGrad", "FakeQuantPerLayerGrad",
"FakeQuantPerChannel", "FakeQuantPerChannel",
"FakeQuantPerChannelGrad", "FakeQuantPerChannelGrad",
"FakeQuantMinMaxPerLayerUpdate",
"FakeQuantMinMaxPerChannelUpdate",
"BatchNormFold", "BatchNormFold",
"BatchNormFoldGrad", "BatchNormFoldGrad",
"CorrectionMul", "CorrectionMul",
@ -38,20 +38,141 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFoldGradD", "BatchNormFoldGradD",
"BatchNormFold2_D", "BatchNormFold2_D",
"BatchNormFold2GradD", "BatchNormFold2GradD",
"BatchNormFold2GradReduce", "BatchNormFold2GradReduce"
] ]
class MinMaxUpdatePerLayer(PrimitiveWithInfer):
r"""
Update min and max per layer.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class MinMaxUpdatePerChannel(PrimitiveWithInfer):
r"""
Update min and max per channel.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
support_x_rank = [2, 4]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_int_range(
'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
if len(x_shape) not in self.support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantPerLayer(PrimitiveWithInfer): class FakeQuantPerLayer(PrimitiveWithInfer):
r""" r"""
Simulate the quantize and dequantize operations in training time. Simulate the quantize and dequantize operations in training time.
Args: Args:
num_bits (int) : Number bits for aware quantilization. Default: 8. num_bits (int) : Number bits for quantization aware. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False. ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware simulate quantization aware funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0. quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -103,8 +224,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer( self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name) 'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type( self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, (int,), self.name) 'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['out']) outputs=['out'])
@ -196,6 +317,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True. training (bool): Training the network or not. Default: True.
channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.
Inputs: Inputs:
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor. - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
@ -213,6 +335,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
>>> result = fake_quant(input_x, _min, _max) >>> result = fake_quant(input_x, _min, _max)
""" """
support_quant_bit = [4, 7, 8] support_quant_bit = [4, 7, 8]
support_x_rank = [2, 4]
@prim_attr_register @prim_attr_register
def __init__(self, def __init__(self,
@ -245,14 +368,15 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer( self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name) 'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_value_type( self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, (int,), self.name) 'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.channel_axis = validator.check_integer( self.channel_axis = validator.check_int_range(
'channel_axis', channel_axis, 0, Rel.GE, self.name) 'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape): def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name) if len(x_shape) not in self.support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer( validator.check_integer(
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name) "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
@ -832,153 +956,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
def infer_dtype(self, dout_type, x_type): def infer_dtype(self, dout_type, x_type):
validator.check("dout type", dout_type, "x type", x_type) validator.check("dout type", dout_type, "x type", x_type)
return dout_type, dout_type return dout_type, dout_type
class FakeQuantMinMaxPerLayerUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perlayer_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'],
outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type
class FakeQuantMinMaxPerChannelUpdate(PrimitiveWithInfer):
r"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
@prim_attr_register
def __init__(self, num_bits=8, ema=False, ema_decay=0.999, symmetric=False, narrow_range=False,
training=True, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
from mindspore.ops._op_impl._custom_op import fake_quant_minmax_perchannel_update
if num_bits not in self.support_quant_bit:
raise ValueError(
f"For '{self.name}' attr \'num_bits\' is not support.")
if ema and not ema_decay:
raise ValueError(
f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.symmetric = validator.check_value_type(
'symmetric', symmetric, (bool,), self.name)
self.narrow_range = validator.check_value_type(
'narrow_range', narrow_range, (bool,), self.name)
self.training = validator.check_value_type(
'training', training, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.num_bits = validator.check_integer(
'num_bits', num_bits, 0, Rel.GT, self.name)
self.channel_axis = validator.check_integer(
'channel axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
min_shape), 1, Rel.EQ, self.name)
return min_shape, max_shape
def infer_dtype(self, x_type, min_type, max_type):
valid_types = (mstype.float16, mstype.float32)
validator.check_tensor_type_same(
{"x": x_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"min": min_type}, valid_types, self.name)
validator.check_tensor_type_same(
{"max": max_type}, valid_types, self.name)
return min_type, max_type

View File

@ -32,7 +32,6 @@ from ...ops.operations import _inner_ops as inner
from ...train import serialization from ...train import serialization
from . import quant_utils from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant, nn.ReLU6: quant.ReLU6Quant,
nn.HSigmoid: quant.HSigmoidQuant, nn.HSigmoid: quant.HSigmoidQuant,
@ -41,15 +40,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
class _AddFakeQuantInput(nn.Cell): class _AddFakeQuantInput(nn.Cell):
""" """
Add FakeQuant at input and output of the Network. Only support one input and one output case. Add FakeQuant OP at input of the network. Only support one input case.
""" """
def __init__(self, network, quant_delay=0): def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInput, self).__init__(auto_prefix=False) super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
self.network = network self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input = quant.FakeQuantWithMinMax(
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input') self.fake_quant_input.update_parameters_name('fake_quant_input')
self.network = network
def construct(self, data): def construct(self, data):
data = self.fake_quant_input(data) data = self.fake_quant_input(data)
@ -59,7 +57,7 @@ class _AddFakeQuantInput(nn.Cell):
class _AddFakeQuantAfterSubCell(nn.Cell): class _AddFakeQuantAfterSubCell(nn.Cell):
""" """
Add FakeQuant after of the sub Cell. Add FakeQuant OP after of the sub Cell.
""" """
def __init__(self, subcell, **kwargs): def __init__(self, subcell, **kwargs):
@ -114,11 +112,12 @@ class ConvertToQuantNetwork:
self.network.update_cell_prefix() self.network.update_cell_prefix()
network = self._convert_subcells2quant(self.network) network = self._convert_subcells2quant(self.network)
network = _AddFakeQuantInput(network) network = _AddFakeQuantInput(network)
self.network.update_cell_type("quant")
return network return network
def _convert_subcells2quant(self, network): def _convert_subcells2quant(self, network):
""" """
convet sub cell to quant cell convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
""" """
cells = network.name_cells() cells = network.name_cells()
change = False change = False
@ -137,19 +136,19 @@ class ConvertToQuantNetwork:
if isinstance(network, nn.SequentialCell) and change: if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells()) network.cell_list = list(network.cells())
# tensoradd to tensoradd quant # add FakeQuant OP after OP in while list
add_list = [] add_list = []
for name in network.__dict__: for name in network.__dict__:
if name[0] == '_': if name[0] == '_':
continue continue
attr = network.__dict__[name] attr = network.__dict__[name]
if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__: if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
add_list.append((name, attr)) add_list.append((name, attr))
for name, prim_op in add_list: for name, prim_op in add_list:
prefix = name prefix = name
add_quant = _AddFakeQuantAfterSubCell(prim_op, add_quant = _AddFakeQuantAfterSubCell(prim_op,
num_bits=self.act_bits, num_bits=self.act_bits,
quant_delay=self.act_delay, quant_delay=self.act_qdelay,
per_channel=self.act_channel, per_channel=self.act_channel,
symmetric=self.act_symmetric, symmetric=self.act_symmetric,
narrow_range=self.act_range) narrow_range=self.act_range)
@ -161,11 +160,11 @@ class ConvertToQuantNetwork:
def _convert_conv(self, subcell): def _convert_conv(self, subcell):
""" """
convet conv cell to quant cell convert Conv2d cell to quant cell
""" """
conv_inner = subcell.conv conv_inner = subcell.conv
bn_inner = subcell.batchnorm bn_inner = subcell.batchnorm
if subcell.batchnorm is not None and self.bn_fold: if subcell.has_bn and self.bn_fold:
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels, conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
conv_inner.out_channels, conv_inner.out_channels,
kernel_size=conv_inner.kernel_size, kernel_size=conv_inner.kernel_size,
@ -175,7 +174,7 @@ class ConvertToQuantNetwork:
dilation=conv_inner.dilation, dilation=conv_inner.dilation,
group=conv_inner.group, group=conv_inner.group,
eps=bn_inner.eps, eps=bn_inner.eps,
momentum=bn_inner.momentum, momentum=1 - bn_inner.momentum,
quant_delay=self.weight_qdelay, quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn, freeze_bn=self.freeze_bn,
per_channel=self.weight_channel, per_channel=self.weight_channel,
@ -183,6 +182,11 @@ class ConvertToQuantNetwork:
fake=True, fake=True,
symmetric=self.weight_symmetric, symmetric=self.weight_symmetric,
narrow_range=self.weight_range) narrow_range=self.weight_range)
# change original network BatchNormal OP parameters to quant network
conv_inner.gamma = subcell.batchnorm.gamma
conv_inner.beta = subcell.batchnorm.beta
conv_inner.moving_mean = subcell.batchnorm.moving_mean
conv_inner.moving_variance = subcell.batchnorm.moving_variance
del subcell.batchnorm del subcell.batchnorm
subcell.batchnorm = None subcell.batchnorm = None
subcell.has_bn = False subcell.has_bn = False
@ -201,6 +205,10 @@ class ConvertToQuantNetwork:
num_bits=self.weight_bits, num_bits=self.weight_bits,
symmetric=self.weight_symmetric, symmetric=self.weight_symmetric,
narrow_range=self.weight_range) narrow_range=self.weight_range)
# change original network Conv2D OP parameters to quant network
conv_inner.weight = subcell.conv.weight
if subcell.conv.has_bias:
conv_inner.bias = subcell.conv.bias
subcell.conv = conv_inner subcell.conv = conv_inner
if subcell.has_act and subcell.activation is not None: if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation) subcell.activation = self._convert_activation(subcell.activation)
@ -227,6 +235,10 @@ class ConvertToQuantNetwork:
per_channel=self.weight_channel, per_channel=self.weight_channel,
symmetric=self.weight_symmetric, symmetric=self.weight_symmetric,
narrow_range=self.weight_range) narrow_range=self.weight_range)
# change original network Dense OP parameters to quant network
dense_inner.weight = subcell.dense.weight
if subcell.dense.has_bias:
dense_inner.bias = subcell.dense.bias
subcell.dense = dense_inner subcell.dense = dense_inner
if subcell.has_act and subcell.activation is not None: if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation) subcell.activation = self._convert_activation(subcell.activation)
@ -234,7 +246,7 @@ class ConvertToQuantNetwork:
subcell.has_act = True subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity, subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits, num_bits=self.act_bits,
quant_delay=self.act_delay, quant_delay=self.act_qdelay,
per_channel=self.act_channel, per_channel=self.act_channel,
symmetric=self.act_symmetric, symmetric=self.act_symmetric,
narrow_range=self.act_range) narrow_range=self.act_range)
@ -244,12 +256,12 @@ class ConvertToQuantNetwork:
act_class = activation.__class__ act_class = activation.__class__
if act_class not in _ACTIVATION_MAP: if act_class not in _ACTIVATION_MAP:
raise ValueError( raise ValueError(
"Unsupported activation in auto Quant: ", act_class) "Unsupported activation in auto quant: ", act_class)
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
quant_delay=self.act_qdelay, quant_delay=self.act_qdelay,
per_channel=self.act_channel, per_channel=self.act_channel,
symmetric=self.weight_symmetric, symmetric=self.act_symmetric,
narrow_range=self.weight_range) narrow_range=self.act_range)
class ExportQuantNetworkDeploy: class ExportQuantNetworkDeploy:
@ -403,7 +415,7 @@ def convert_quant_network(network,
narrow_range=(False, False) narrow_range=(False, False)
): ):
r""" r"""
Create aware quantizaiton training network. Create quantization aware training network.
Args: Args:
network (Cell): Obtain a pipeline through network for saving graph summary. network (Cell): Obtain a pipeline through network for saving graph summary.
@ -417,7 +429,7 @@ def convert_quant_network(network,
then base on per channel otherwise base on per layer. The first element represent weights then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: [False, False] and second element represent data flow. Default: [False, False]
symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on
symmetric otherwise base on assymmetric. The first element represent weights and second symmetric otherwise base on asymmetric. The first element represent weights and second
element represent data flow. Default: [False, False] element represent data flow. Default: [False, False]
narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base
on narrow range otherwise base on off narrow range. The first element represent weights and on narrow range otherwise base on off narrow range. The first element represent weights and
@ -426,6 +438,7 @@ def convert_quant_network(network,
Returns: Returns:
Cell, Network which has change to aware quantization training network cell. Cell, Network which has change to aware quantization training network cell.
""" """
def convert2list(name, value): def convert2list(name, value):
if not isinstance(value, list) and not isinstance(value, tuple): if not isinstance(value, list) and not isinstance(value, tuple):
value = [value] value = [value]

View File

@ -2,13 +2,13 @@
## Description ## Description
Training LeNet with MNIST dataset in MindSpore with quantization aware trainging. Training LeNet with MNIST dataset in MindSpore with quantization aware training.
This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware. This is the simple and basic tutorial for constructing a network in MindSpore with quantization aware.
In this tutorial, you will: In this tutorial, you will:
1. Train a Mindspore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`. 1. Train a MindSpore fusion model for MNIST from scratch using `nn.Conv2dBnAct` and `nn.DenseBnAct`.
2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file. 2. Fine tune the fusion model by applying the quantization aware training auto network converter API `convert_quant_network`, after the network convergence then export a quantization aware model checkpoint file.
3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend. 3. Use the quantization aware model to create an actually quantized model for the Ascend inference backend.
4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples. 4. See the persistence of accuracy in inference backend and a 4x smaller model. To see the latency benefits on mobile, try out the Ascend inference backend examples.
@ -24,10 +24,10 @@ Install MindSpore base on the ascend device and GPU device from [MindSpore](http
```python ```python
pip uninstall -y mindspore-ascend pip uninstall -y mindspore-ascend
pip uninstall -y mindspore-gpu pip uninstall -y mindspore-gpu
pip install mindspore-ascend-0.4.0.whl pip install mindspore-ascend.whl
``` ```
then you will get the following display Then you will get the following display
```bash ```bash
@ -87,7 +87,7 @@ class LeNet5(nn.Cell):
return x return x
``` ```
get the MNIST from scratch dataset. Get the MNIST from scratch dataset.
```Python ```Python
ds_train = create_dataset(os.path.join(args.data_path, "train"), ds_train = create_dataset(os.path.join(args.data_path, "train"),
@ -97,7 +97,7 @@ step_size = ds_train.get_dataset_size()
### Train model ### Train model
Load teh Lenet fusion network, traing network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`. Load the Lenet fusion network, training network using loss `nn.SoftmaxCrossEntropyWithLogits` with optimization `nn.Momentum`.
```Python ```Python
# Define the network # Define the network
@ -133,7 +133,7 @@ After all the following we will get the loss value of each step as following:
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] >>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
``` ```
To save your time, just run this command. Also, you can just run this command instead.
```python ```python
python train.py --data_path MNIST_Data --device_target Ascend python train.py --data_path MNIST_Data --device_target Ascend
@ -165,17 +165,17 @@ Note that the resulting model is quantization aware but not quantized (e.g. the
# define funsion network # define funsion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network # convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network) network = quant.convert_quant_network(network)
``` ```
### load checkpoint ### load checkpoint
after convert to quantization aware network, we can load the checkpoint file. After convert to quantization aware network, we can load the checkpoint file.
```python ```python
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
### train quantization aware model ### train quantization aware model
To save your time, just run this command. Also, you can just run this command instead.
```python ```python
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
@ -210,7 +210,7 @@ Procedure of quantization aware model evaluation is different from normal. Becau
# define funsion network # define funsion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
@ -218,10 +218,10 @@ load_param_into_net(network, param_dict)
network = quant.convert_quant_network(network) network = quant.convert_quant_network(network)
``` ```
To save your time, just run this command. Also, you can just run this command insread.
```python ```python
python eval.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt python eval_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
``` ```
The top1 accuracy would display on shell. The top1 accuracy would display on shell.
@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
Here are some optional parameters: Here are some optional parameters:
```bash ```bash
--device_target {Ascend,GPU,CPU} --device_target {Ascend,GPU}
device where the code will be implemented (default: Ascend) device where the code will be implemented (default: Ascend)
--data_path DATA_PATH --data_path DATA_PATH
path where the dataset is saved path where the dataset is saved

View File

@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')

View File

@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
@ -61,7 +61,7 @@ if __name__ == "__main__":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, model_type="quant") param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
print("============== Starting Testing ==============") print("============== Starting Testing ==============")

View File

@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
@ -56,8 +56,7 @@ if __name__ == "__main__":
# call back and monitor # call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max, keep_checkpoint_max=cfg.keep_checkpoint_max)
model_type=network.type)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model # define model

View File

@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", parser.add_argument('--device_target', type=str, default="Ascend",
choices=['Ascend', 'GPU', 'CPU'], choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
@ -50,11 +50,13 @@ if __name__ == "__main__":
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, network.type) param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define network loss # define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
@ -64,8 +66,7 @@ if __name__ == "__main__":
# call back and monitor # call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max, keep_checkpoint_max=cfg.keep_checkpoint_max)
model_type="quant")
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model # define model

View File

@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0 export DEVICE_ID=0
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
if [ -d "eval" ]; if [ -d "../eval" ];
then then
rm -rf ../eval rm -rf ../eval
fi fi

View File

@ -62,7 +62,7 @@ run_gpu()
BASEPATH=$(cd "`dirname $0`" || exit; pwd) BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ]; if [ -d "../train" ];
then then
rm -rf ../train rm -rf ../train
fi fi

View File

@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0 export DEVICE_ID=0
export RANK_ID=0 export RANK_ID=0
export RANK_SIZE=1 export RANK_SIZE=1
if [ -d "eval" ]; if [ -d "../eval" ];
then then
rm -rf ../eval rm -rf ../eval
fi fi

View File

@ -60,7 +60,7 @@ run_gpu()
BASEPATH=$(cd "`dirname $0`" || exit; pwd) BASEPATH=$(cd "`dirname $0`" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH export PYTHONPATH=${BASEPATH}:$PYTHONPATH
if [ -d "train" ]; if [ -d "../train" ];
then then
rm -rf ../train rm -rf ../train
fi fi

View File

@ -17,7 +17,7 @@ def _conv_bn(in_channel,
out_channel, out_channel,
kernel_size=ksize, kernel_size=ksize,
stride=stride, stride=stride,
batchnorm=True)]) has_bn=True)])
class InvertedResidual(nn.Cell): class InvertedResidual(nn.Cell):
@ -35,25 +35,25 @@ class InvertedResidual(nn.Cell):
3, 3,
stride, stride,
group=hidden_dim, group=hidden_dim,
batchnorm=True, has_bn=True,
activation='relu6'), activation='relu6'),
nn.Conv2dBnAct(hidden_dim, oup, 1, 1, nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
batchnorm=True) has_bn=True)
]) ])
else: else:
self.conv = nn.SequentialCell([ self.conv = nn.SequentialCell([
nn.Conv2dBnAct(inp, hidden_dim, 1, 1, nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
batchnorm=True, has_bn=True,
activation='relu6'), activation='relu6'),
nn.Conv2dBnAct(hidden_dim, nn.Conv2dBnAct(hidden_dim,
hidden_dim, hidden_dim,
3, 3,
stride, stride,
group=hidden_dim, group=hidden_dim,
batchnorm=True, has_bn=True,
activation='relu6'), activation='relu6'),
nn.Conv2dBnAct(hidden_dim, oup, 1, 1, nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
batchnorm=True) has_bn=True)
]) ])
self.add = P.TensorAdd() self.add = P.TensorAdd()

View File

@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
def __init__(self, num_class=10): def __init__(self, num_class=10):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.num_class = num_class self.num_class = num_class
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid") self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid") self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu')