forked from mindspore-Ecosystem/mindspore
!2718 fix quantization aware training auto create graph bug
Merge pull request !2718 from chenzhongming/master
Commit: f1a9a7ceb1
@@ -22,7 +22,6 @@ message Checkpoint {
     required TensorProto tensor = 2;
   }
   repeated Value value = 1;
-  required string model_type = 2;
 }
@@ -81,6 +81,7 @@ class Cell:
         self.enable_hook = False
         self._bprop_debug = False
         self._is_run = False
+        self.cell_type = None
 
     @property
     def is_run(self):

@@ -140,6 +141,14 @@ class Cell:
         for cell_name, cell in cells_name:
             cell._param_prefix = cell_name
 
+    def update_cell_type(self, cell_type):
+        """
+        Update current cell type mainly identify if quantization aware training network.
+
+        After invoked, can set the cell type to 'cell_type'.
+        """
+        self.cell_type = cell_type
+
     @cell_init_args.setter
     def cell_init_args(self, value):
         if not isinstance(value, str):
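Note: the two hunks above give every `Cell` a `cell_type` attribute (initialized to `None`) and an `update_cell_type` setter; the quantization-aware conversion further down uses it to tag a converted network. A minimal usage sketch, assuming this commit is applied (the layer choice is illustrative only):

    >>> from mindspore import nn
    >>> net = nn.Dense(3, 4)            # any Cell subclass
    >>> print(net.cell_type)
    None
    >>> net.update_cell_type("quant")
    >>> net.cell_type
    'quant'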
@@ -17,11 +17,12 @@ from mindspore import log as logger
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore._checkparam import ParamValidator as validator, Rel
 from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
 from mindspore._extends import cell_attr_register
 from ..cell import Cell
 
-__all__ = ['Conv2d', 'Conv2dTranspose']
+__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']
 
 class _Conv(Cell):
     """
@@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv):
                                          self.weight,
                                          self.bias)
         return s
+
+
+class DepthwiseConv2d(Cell):
+    r"""
+    2D depthwise convolution layer.
+
+    Applies a 2D depthwise convolution over an input tensor which is typically of shape
+    :math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
+    For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
+
+    .. math::
+
+        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
+
+    where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
+    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
+    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
+    of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
+    :math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape
+    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
+    to split the input in the channel dimension.
+
+    If the 'pad_mode' is set to be "valid", the output height and width will be
+    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
+    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
+    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
+    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
+
+    The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
+    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
+
+    Args:
+        in_channels (int): The number of input channel :math:`C_{in}`.
+        out_channels (int): The number of output channel :math:`C_{out}`.
+        kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height
+            and width of the 2D convolution window. Single int means the value is for both height and width of
+            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
+            width of the kernel.
+        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
+            the height and width of movement are both strides, or a tuple of two int numbers that
+            represent height and width of movement respectively. Default: 1.
+        pad_mode (str): Specifies padding mode. The optional values are
+            "same", "valid", "pad". Default: "same".
+
+            - same: Adopts the way of completion. Output height and width will be the same as the input.
+              Total number of padding will be calculated for horizontal and vertical
+              direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the
+              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
+              must be 0.
+
+            - valid: Adopts the way of discarding. The possibly largest height and width of output will be return
+              without padding. Extra pixels will be discarded. If this mode is set, `padding`
+              must be 0.
+
+            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
+              Tensor borders. `padding` should be greater than or equal to 0.
+
+        padding (int): Implicit paddings on both sides of the input. Default: 0.
+        dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
+            to use for dilated convolution. If set to be :math:`k > 1`, there will
+            be :math:`k - 1` pixels skipped for each sampling location. Its value should
+            be greater or equal to 1 and bounded by the height and width of the
+            input. Default: 1.
+        group (int): Split filter into groups, `in_channels` and `out_channels` should be
+            divisible by the number of groups. Default: 1.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
+            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
+            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
+            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
+            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
+            Initializer for more details. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
+            Initializer and string are the same as 'weight_init'. Refer to the values of
+            Initializer for more details. Default: 'zeros'.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> net = nn.DepthwiseConv2d(120, 240, 4, has_bias=False, weight_init='normal')
+        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
+        >>> net(input).shape
+        (1, 240, 1024, 640)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 pad_mode='same',
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 has_bias=False,
+                 weight_init='normal',
+                 bias_init='zeros'):
+        super(DepthwiseConv2d, self).__init__()
+        self.kernel_size = twice(kernel_size)
+        self.stride = twice(stride)
+        self.dilation = twice(dilation)
+        self.in_channels = check_int_positive(in_channels)
+        self.out_channels = check_int_positive(out_channels)
+        validator.check_integer('group', group, in_channels, Rel.EQ)
+        validator.check_integer('group', group, out_channels, Rel.EQ)
+        validator.check_integer('group', group, 1, Rel.GE)
+        self.pad_mode = pad_mode
+        self.padding = padding
+        self.dilation = dilation
+        self.group = group
+        self.has_bias = has_bias
+        self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
+                                            kernel_size=self.kernel_size,
+                                            pad_mode=self.pad_mode,
+                                            pad=self.padding,
+                                            stride=self.stride,
+                                            dilation=self.dilation)
+        self.bias_add = P.BiasAdd()
+        weight_shape = [1, in_channels, *self.kernel_size]
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
+        if check_bool(has_bias):
+            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
+        else:
+            if bias_init != 'zeros':
+                logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.")
+            self.bias = None
+
+    def construct(self, x):
+        out = self.conv(x, self.weight)
+        if self.has_bias:
+            out = self.bias_add(out, self.bias)
+        return out
+
+    def extend_repr(self):
+        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
+            'pad_mode={}, padding={}, dilation={}, group={},' \
+            'has_bias={}, weight_init={}, bias_init={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.has_bias, self.weight_init, self.bias_init)
+
+        if self.has_bias:
+            s += ', bias={}'.format(self.bias)
+        return s
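Note: the "valid" output-size formula in the new docstring can be checked numerically. A small sketch in plain Python (not part of the diff; `valid_out_size` is a hypothetical helper):

    import math

    def valid_out_size(in_size, ks, stride=1, dilation=1, padding=0):
        # floor(1 + (in + 2*padding - ks - (ks - 1)*(dilation - 1)) / stride)
        return math.floor(1 + (in_size + 2 * padding - ks - (ks - 1) * (dilation - 1)) / stride)

    print(valid_out_size(1024, 4), valid_out_size(640, 4))  # 1021 637 for pad_mode="valid"
    # The docstring example keeps 1024 x 640 because it uses the default pad_mode="same".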
@@ -16,6 +16,7 @@
 
 from functools import partial
 import numpy as np
+
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
-from mindspore._checkparam import Validator as validator, Rel
-from mindspore.nn.cell import Cell
-from mindspore.nn.layer.activation import get_activation
+from mindspore._checkparam import Rel
 import mindspore.context as context
 
 from .normalization import BatchNorm2d
 from .activation import get_activation
 from ..cell import Cell
@@ -82,7 +82,7 @@ class Conv2dBnAct(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
             Initializer and string are the same as 'weight_init'. Refer to the values of
             Initializer for more details. Default: 'zeros'.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
+        has_bn (bool): Specifies to used batchnorm or not. Default: False.
         activation (string): Specifies activation type. The optional values are as following:
             'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
             'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@@ -94,7 +94,7 @@ class Conv2dBnAct(Cell):
         Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
 
     Examples:
-        >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU')
+        >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
         >>> net(input).shape
         (1, 240, 1024, 640)
@@ -112,28 +112,39 @@ class Conv2dBnAct(Cell):
                  has_bias=False,
                  weight_init='normal',
                  bias_init='zeros',
-                 batchnorm=None,
+                 has_bn=False,
                  activation=None):
         super(Conv2dBnAct, self).__init__()
-        self.conv = conv.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            pad_mode,
-            padding,
-            dilation,
-            group,
-            has_bias,
-            weight_init,
-            bias_init)
-        self.has_bn = batchnorm is not None
+        if context.get_context('device_target') == "Ascend" and group > 1:
+            self.conv = conv.DepthwiseConv2d(in_channels,
+                                             out_channels,
+                                             kernel_size=kernel_size,
+                                             stride=stride,
+                                             pad_mode=pad_mode,
+                                             padding=padding,
+                                             dilation=dilation,
+                                             group=group,
+                                             has_bias=has_bias,
+                                             weight_init=weight_init,
+                                             bias_init=bias_init)
+        else:
+            self.conv = conv.Conv2d(in_channels,
+                                    out_channels,
+                                    kernel_size=kernel_size,
+                                    stride=stride,
+                                    pad_mode=pad_mode,
+                                    padding=padding,
+                                    dilation=dilation,
+                                    group=group,
+                                    has_bias=has_bias,
+                                    weight_init=weight_init,
+                                    bias_init=bias_init)
+
+        self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
-        self.batchnorm = batchnorm
-        if batchnorm is True:
+        if has_bn:
             self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
         self.activation = get_activation(activation)
 
     def construct(self, x):
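Note: with `batchnorm` replaced by the boolean `has_bn`, and the Ascend / `group > 1` branch now selecting `DepthwiseConv2d` internally, a call site changes roughly as sketched below (argument values are illustrative):

    >>> # before: Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU')
    >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')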
@@ -160,7 +171,7 @@ class DenseBnAct(Cell):
             same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
+        has_bn (bool): Specifies to used batchnorm or not. Default: False.
         activation (string): Specifies activation type. The optional values are as following:
             'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
             'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@@ -183,7 +194,7 @@ class DenseBnAct(Cell):
                  weight_init='normal',
                  bias_init='zeros',
                  has_bias=True,
-                 batchnorm=None,
+                 has_bn=False,
                  activation=None):
         super(DenseBnAct, self).__init__()
         self.dense = basic.Dense(
@@ -192,12 +203,10 @@ class DenseBnAct(Cell):
             weight_init,
             bias_init,
             has_bias)
-        self.has_bn = batchnorm is not None
+        self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
-        if batchnorm is True:
+        if has_bn:
             self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
         self.activation = get_activation(activation)
 
     def construct(self, x):
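Note: `DenseBnAct` follows the same renaming; passing a `BatchNorm2d` instance is no longer accepted, only the boolean flag. A minimal sketch (sizes are illustrative):

    >>> fc = DenseBnAct(1024, 10, has_bn=True, activation='relu')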
@@ -312,6 +321,10 @@ class FakeQuantWithMinMax(Cell):
                  quant_delay=0):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
+        validator.check_type("min_init", min_init, [int, float])
+        validator.check_type("max_init", max_init, [int, float])
+        validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
+        validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
         self.min_init = min_init
         self.max_init = max_init
         self.num_bits = num_bits
@@ -1183,12 +1196,13 @@ class QuantBlock(Cell):
         self.has_bias = bias is None
         self.activation = activation
         self.has_act = activation is None
+        self.bias_add = P.BiasAdd()
 
     def construct(self, x):
         x = self.quant(x)
         x = self.core_op(x, self.weight)
         if self.has_bias:
-            output = self.bias_add(output, self.bias)
+            x = self.bias_add(x, self.bias)
         if self.has_act:
             x = self.activation(x)
         x = self.dequant(x, self.dequant_scale)
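Note: the `construct` fix above replaces a reference to an undefined local `output` with `x`, and `__init__` now creates the `P.BiasAdd()` operator it uses. The corrected data flow, condensed into a sketch from the class above:

    def construct(self, x):
        x = self.quant(x)
        x = self.core_op(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)   # previously assigned to an undefined `output`
        if self.has_act:
            x = self.activation(x)
        x = self.dequant(x, self.dequant_scale)
        return x                              # return statement assumed; it lies outside the hunk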
@@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "x", None, "required", None) \
     .input(1, "beta", None, "required", None) \
     .input(2, "gamma", None, "required", None) \

@@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2_grad") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "dout", None, "required", None) \
     .input(1, "dout_reduce", None, "required", None) \
     .input(2, "dout_x_reduce", None, "required", None) \

@@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2_grad_reduce") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
     .output(0, "dout_reduce", True, "required", "all") \

@@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \
     .compute_cost(10) \
     .kernel_name("correction_mul") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "x", None, "required", None) \
     .input(1, "batch_std", None, "required", None) \

@@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
     .compute_cost(10) \
     .kernel_name("correction_mul_grad") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
@@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
     .compute_cost(10) \
     .kernel_name("correction_mul_grad_reduce") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "dout", None, "required", None) \
     .output(0, "d_batch_std", True, "required", "all") \
@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
     quant_min = quant_min + 1
 
     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
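Note: the new branch handles `Dense` weights that arrive expanded from 2-D `[co, ci]` to 4-D `[1, co, ci, 1]`; the per-channel dimension is then axis 1 rather than axis 0. A NumPy sketch of the shape test (shapes are illustrative):

    import numpy as np

    x = np.zeros((1, 10, 20, 1))     # Dense weight [co=10, ci=20] expanded to 4-D
    min_val = np.zeros((10,))        # one min value per output channel
    channel_axis = 0

    # mirrors the condition added above
    if channel_axis == 0 and x.shape[0] != min_val.shape[0] and x.shape[1] == min_val.shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis
    print(channel_axis_)             # 1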
@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     quant_min = quant_min + 1
 
     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
@@ -88,11 +88,15 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -105,7 +109,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)
 
-    if channel_axis == 0:
+    if channel_axis_ == 0:
         shape_c = min_val.get("ori_shape")
     else:
         shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
@@ -113,7 +117,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
     max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
     res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
-                                                ema, ema_decay, channel_axis)
+                                                ema, ema_decay, channel_axis_)
 
     with tvm.target.cce():
         sch = generic.auto_schedule(res_list)
@@ -106,7 +106,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
     Args:
         ema (bool): Use EMA algorithm update value min and max. Default: False.
         ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
-        channel_axis (int): Channel axis for per channel compute. Default: 1.
+        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
 
     Inputs:
         - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
@@ -123,11 +123,13 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
         >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
     """
     support_quant_bit = [4, 7, 8]
+    ascend_support_x_rank = [2, 4]
 
     @prim_attr_register
     def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
         """init FakeQuantPerChannelUpdate OP for Ascend"""
-        if context.get_context('device_target') == "Ascend":
+        self.is_ascend = context.get_context('device_target') == "Ascend"
+        if self.is_ascend:
             from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
         if ema and not ema_decay:
             raise ValueError(
@@ -136,13 +138,18 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
         self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
         self.ema_decay = validator.check_number_range(
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.channel_axis = validator.check_integer(
-            'channel axis', channel_axis, 0, Rel.GE, self.name)
+        if self.is_ascend:
+            self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
+        else:
+            self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
         self.init_prim_io_names(
             inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
 
     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
+            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
+        if not self.is_ascend:
+            validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
         validator.check_integer("min shape", len(
@@ -221,8 +228,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
         self.num_bits = validator.check_integer(
             'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type(
-            'quant_delay', quant_delay, (int,), self.name)
+        self.quant_delay = validator.check_integer(
+            'quant_delay', quant_delay, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                 outputs=['out'])
 
@@ -314,6 +321,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
         training (bool): Training the network or not. Default: True.
+        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
 
     Inputs:
         - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
@@ -331,6 +339,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         >>> result = fake_quant(input_x, _min, _max)
     """
     support_quant_bit = [4, 7, 8]
+    ascend_support_x_rank = [2, 4]
 
     @prim_attr_register
     def __init__(self,
@@ -343,7 +352,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
                  training=True,
                  channel_axis=1):
         """init FakeQuantPerChannel OP"""
-        if context.get_context('device_target') == "Ascend":
+        self.is_ascend = context.get_context('device_target') == "Ascend"
+        if self.is_ascend:
             from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
         if num_bits not in self.support_quant_bit:
             raise ValueError(
@@ -363,14 +373,19 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
         self.num_bits = validator.check_integer(
             'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type(
-            'quant_delay', quant_delay, (int,), self.name)
-        self.channel_axis = validator.check_integer(
-            'channel_axis', channel_axis, 0, Rel.GE, self.name)
+        self.quant_delay = validator.check_integer(
+            'quant_delay', quant_delay, 0, Rel.GE, self.name)
+        if self.is_ascend:
+            self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
+        else:
+            self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
 
     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
+            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
+        if not self.is_ascend:
+            validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
         validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
         validator.check_integer(
             "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
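Note: on the Ascend backend the per-channel primitives now accept only rank-2 or rank-4 inputs and `channel_axis` in {0, 1}; other backends keep the looser checks. A sketch of the effective validation, paraphrased from the hunks above (names are hypothetical):

    ASCEND_SUPPORT_X_RANK = (2, 4)

    def check_per_channel_args(is_ascend, x_shape, channel_axis):
        if is_ascend:
            assert channel_axis in (0, 1), "Ascend supports channel_axis 0 or 1"
            assert len(x_shape) in ASCEND_SUPPORT_X_RANK, "Ascend supports 2-D or 4-D x"
        else:
            assert channel_axis >= 0
            assert len(x_shape) >= 1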
@@ -21,7 +21,7 @@ import time
 
 import mindspore.context as context
 from mindspore import log as logger
-from mindspore._checkparam import check_bool, check_string, check_int_non_negative
+from mindspore._checkparam import check_bool, check_int_non_negative
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
 from ._callback import Callback, set_cur_net
@@ -86,7 +86,6 @@ class CheckpointConfig:
             Can't be used with keep_checkpoint_max at the same time.
         integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
             Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
-        model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
 
     Raises:
         ValueError: If the input_param is None or 0.
@@ -101,8 +100,7 @@ class CheckpointConfig:
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
-                 integrated_save=True,
-                 model_type="normal"):
+                 integrated_save=True):
 
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -116,8 +114,6 @@ class CheckpointConfig:
             keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
         if keep_checkpoint_per_n_minutes:
             keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
-        if model_type:
-            model_type = check_string(model_type, ["normal", "fusion", "quant"])
 
         self._save_checkpoint_steps = save_checkpoint_steps
         self._save_checkpoint_seconds = save_checkpoint_seconds
@@ -132,7 +128,6 @@ class CheckpointConfig:
         if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
             self._keep_checkpoint_max = 1
 
-        self._model_type = model_type
         self._integrated_save = check_bool(integrated_save)
 
     @property
@@ -160,18 +155,12 @@ class CheckpointConfig:
         """Get the value of _integrated_save."""
         return self._integrated_save
 
-    @property
-    def model_type(self):
-        """Get the value of model_type."""
-        return self._model_type
-
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
                              'save_checkpoint_seconds': self._save_checkpoint_seconds,
                              'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
-                             'model_type': self._model_type}
+                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
 
         return checkpoint_policy
 
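Note: `CheckpointConfig` and the checkpoint policy no longer carry a `model_type`; configuration reduces to the remaining keyword arguments. A minimal sketch (values are illustrative):

    >>> from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
    >>> config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
    >>> ckpt_cb = ModelCheckpoint(prefix='lenet', directory='./ckpt', config=config)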
@@ -236,7 +225,7 @@ class ModelCheckpoint(Callback):
             graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
             _save_graph(cb_params.train_network, graph_file_name)
             self._graph_saved = True
-        self._save_ckpt(cb_params, self._config.model_type)
+        self._save_ckpt(cb_params)
 
     def end(self, run_context):
         """
@@ -247,7 +236,7 @@ class ModelCheckpoint(Callback):
         """
         cb_params = run_context.original_args()
         _to_save_last_ckpt = True
-        self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
+        self._save_ckpt(cb_params, _to_save_last_ckpt)
 
         from mindspore.parallel._cell_wrapper import destroy_allgather_cell
         destroy_allgather_cell()
@@ -266,7 +255,7 @@ class ModelCheckpoint(Callback):
 
         return False
 
-    def _save_ckpt(self, cb_params, model_type, force_to_save=False):
+    def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
             return
@@ -302,7 +291,7 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
+            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
 
             if os.path.exists(gen_file):
                 shutil.move(gen_file, cur_file)
@@ -86,7 +86,7 @@ class LossMonitor(Callback):
 
         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
-                  "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
+                  "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
                       cb_params.cur_epoch_num, cb_params.epoch_num,
                       cur_step_in_epoch, int(cb_params.batch_num),
                       step_loss, np.mean(self.losses),
@@ -42,15 +42,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
 
 class _AddFakeQuantInput(nn.Cell):
     """
-    Add FakeQuant at input and output of the Network. Only support one input and one output case.
+    Add FakeQuant OP at input of the network. Only support one input case.
     """
 
     def __init__(self, network, quant_delay=0):
         super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
+        self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
+        self.fake_quant_input.update_parameters_name('fake_quant_input.')
         self.network = network
-        self.fake_quant_input = quant.FakeQuantWithMinMax(
-            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
-        self.fake_quant_input.update_parameters_name('fake_quant_input')
 
     def construct(self, data):
         data = self.fake_quant_input(data)
@@ -60,7 +59,7 @@ class _AddFakeQuantInput(nn.Cell):
 
 class _AddFakeQuantAfterSubCell(nn.Cell):
     """
-    Add FakeQuant after of the sub Cell.
+    Add FakeQuant OP after of the sub Cell.
     """
 
     def __init__(self, subcell, **kwargs):
@@ -115,11 +114,12 @@ class ConvertToQuantNetwork:
         self.network.update_cell_prefix()
         network = self._convert_subcells2quant(self.network)
         network = _AddFakeQuantInput(network)
+        self.network.update_cell_type("quant")
         return network
 
     def _convert_subcells2quant(self, network):
         """
-        convet sub cell to quant cell
+        convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
         """
         cells = network.name_cells()
         change = False
@@ -138,13 +138,13 @@ class ConvertToQuantNetwork:
         if isinstance(network, nn.SequentialCell) and change:
             network.cell_list = list(network.cells())
 
-        # tensoradd to tensoradd quant
+        # add FakeQuant OP after OP in while list
         add_list = []
         for name in network.__dict__:
             if name[0] == '_':
                 continue
             attr = network.__dict__[name]
-            if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
+            if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                 add_list.append((name, attr))
         for name, prim_op in add_list:
             prefix = name
@@ -164,11 +164,11 @@ class ConvertToQuantNetwork:
 
     def _convert_conv(self, subcell):
         """
-        convet conv cell to quant cell
+        convert Conv2d cell to quant cell
         """
         conv_inner = subcell.conv
-        bn_inner = subcell.batchnorm
-        if subcell.batchnorm is not None and self.bn_fold:
+        if subcell.has_bn and self.bn_fold:
+            bn_inner = subcell.batchnorm
             conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
                                                     conv_inner.out_channels,
                                                     kernel_size=conv_inner.kernel_size,
@@ -178,7 +178,7 @@ class ConvertToQuantNetwork:
                                                     dilation=conv_inner.dilation,
                                                     group=conv_inner.group,
                                                     eps=bn_inner.eps,
-                                                    momentum=bn_inner.momentum,
+                                                    momentum=1 - bn_inner.momentum,
                                                     quant_delay=self.weight_qdelay,
                                                     freeze_bn=self.freeze_bn,
                                                     per_channel=self.weight_channel,
@@ -186,6 +186,11 @@ class ConvertToQuantNetwork:
                                                     fake=True,
                                                     symmetric=self.weight_symmetric,
                                                     narrow_range=self.weight_range)
+            # change original network BatchNormal OP parameters to quant network
+            conv_inner.gamma = subcell.batchnorm.gamma
+            conv_inner.beta = subcell.batchnorm.beta
+            conv_inner.moving_mean = subcell.batchnorm.moving_mean
+            conv_inner.moving_variance = subcell.batchnorm.moving_variance
             del subcell.batchnorm
             subcell.batchnorm = None
             subcell.has_bn = False
@@ -204,6 +209,10 @@ class ConvertToQuantNetwork:
                                           num_bits=self.weight_bits,
                                           symmetric=self.weight_symmetric,
                                           narrow_range=self.weight_range)
+            # change original network Conv2D OP parameters to quant network
+            conv_inner.weight = subcell.conv.weight
+            if subcell.conv.has_bias:
+                conv_inner.bias = subcell.conv.bias
         subcell.conv = conv_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
@@ -230,6 +239,10 @@ class ConvertToQuantNetwork:
                                         per_channel=self.weight_channel,
                                         symmetric=self.weight_symmetric,
                                         narrow_range=self.weight_range)
+        # change original network Dense OP parameters to quant network
+        dense_inner.weight = subcell.dense.weight
+        if subcell.dense.has_bias:
+            dense_inner.bias = subcell.dense.bias
         subcell.dense = dense_inner
         if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
@@ -247,12 +260,12 @@ class ConvertToQuantNetwork:
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
             raise ValueError(
-                "Unsupported activation in auto Quant: ", act_class)
+                "Unsupported activation in auto quant: ", act_class)
         return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
-                                          symmetric=self.weight_symmetric,
-                                          narrow_range=self.weight_range)
+                                          symmetric=self.act_symmetric,
+                                          narrow_range=self.act_range)
 
 
 class ExportQuantNetworkDeploy:
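Note: the `momentum=1 - bn_inner.momentum` change suggests the folded `Conv2dBatchNormQuant` layer and the original `BatchNorm2d` use opposite momentum conventions for their running statistics. Under that assumption the conversion keeps the update identical, as the sketch shows:

    # Two common moving-average conventions for batch-norm statistics:
    #   (a) new = momentum * old + (1 - momentum) * batch_stat
    #   (b) new = (1 - momentum) * old + momentum * batch_stat
    # If the source layer uses (a) and the target layer uses (b), passing
    # momentum_target = 1 - momentum_source makes both compute the same update.
    momentum_source = 0.9
    momentum_target = 1 - momentum_source   # 0.1, matching the hunk above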
@ -40,8 +40,6 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
||||||
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
||||||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
||||||
|
|
||||||
ModelType = ["normal", "fusion", "quant"]
|
|
||||||
|
|
||||||
|
|
||||||
def _special_process_par(par, new_par):
|
def _special_process_par(par, new_par):
|
||||||
"""
|
"""
|
||||||
|
@ -103,7 +101,7 @@ def _update_param(param, new_param):
|
||||||
param.set_parameter_data(type(param.data)(new_param.data))
|
param.set_parameter_data(type(param.data)(new_param.data))
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
|
def save_checkpoint(parameter_list, ckpt_file_name):
|
||||||
"""
|
"""
|
||||||
Saves checkpoint info to a specified file.
|
Saves checkpoint info to a specified file.
|
||||||
|
|
||||||
|
@@ -111,14 +109,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
         parameter_list (list): Parameters list, each element is a dict
                                like {"name":xx, "type":xx, "shape":xx, "data":xx}.
         ckpt_file_name (str): Checkpoint file name.
-        model_type (str): The name of model type. Default: "normal".

     Raises:
         RuntimeError: Failed to save the Checkpoint file.
     """
     logger.info("Execute save checkpoint process.")
     checkpoint_list = Checkpoint()
-    checkpoint_list.model_type = model_type

     try:
         for param in parameter_list:
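With `model_type` gone, saving a checkpoint is the same call for normal, fusion and quantization aware networks. A hedged usage sketch of the updated signature; the dict layout follows the docstring above, and the file name and tensor values are placeholders:

```python
import numpy as np
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint

# Sketch only: one toy parameter entry in the documented dict layout.
param_list = [{"name": "conv1.weight",
               "type": "Float32",
               "shape": [6, 1, 5, 5],
               "data": Tensor(np.zeros([6, 1, 5, 5], np.float32))}]
save_checkpoint(param_list, "lenet_quant.ckpt")   # no model_type argument any more
```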
@@ -147,13 +143,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
     logger.info("Save checkpoint process finish.")


-def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
+def load_checkpoint(ckpt_file_name, net=None):
     """
     Loads checkpoint info from a specified file.

     Args:
         ckpt_file_name (str): Checkpoint file name.
-        model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
         net (Cell): Cell network. Default: None

     Returns:
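Loading mirrors saving: the checkpoint no longer carries a model type, so the caller only passes the file name (and optionally the network). A hedged sketch, assuming `network` has already been built as in the example scripts further down and using a placeholder file name:

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# `network` is assumed to be the (already converted) quantization aware net.
param_dict = load_checkpoint("checkpoint_lenet.ckpt")
load_param_into_net(network, param_dict)
```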
@@ -165,9 +160,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
     if not isinstance(ckpt_file_name, str):
         raise ValueError("The ckpt_file_name must be string.")

-    if model_type not in ModelType:
-        raise ValueError(f"The model_type is not in {ModelType}.")
-
     if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
         raise ValueError("Please input the correct checkpoint file name.")

@@ -186,10 +178,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
         raise ValueError(e.__str__())

     parameter_dict = {}
-    if checkpoint_list.model_type:
-        if model_type != checkpoint_list.model_type:
-            raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
-                checkpoint_list.model_type, model_type))
     try:
         for element in checkpoint_list.value:
             data = element.tensor.tensor_content
@@ -314,14 +302,13 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)


-def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
+def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
     """
     Saves checkpoint for 'ms' backend.

     Args:
         train_network (Network): The train network for training.
         ckpt_file_name (str): The name of checkpoint file.
-        model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
         integrated_save (bool): Whether to integrated save in automatic model parallel scene.
     """

@@ -346,7 +333,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
         each_param["data"] = param_data
         param_list.append(each_param)

-    save_checkpoint(param_list, ckpt_file_name, model_type)
+    save_checkpoint(param_list, ckpt_file_name)


 def _get_merged_param_data(net, param_name, param_data):

@@ -33,7 +33,7 @@ Then you will get the following display
 ```bash
 >>> Found existing installation: mindspore-ascend
 >>> Uninstalling mindspore-ascend:
 >>> Successfully uninstalled mindspore-ascend.
 ```

 ### Prepare Dataset

@@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

 ### train quantization aware model

-Also, you can just run this command instread.
+Also, you can just run this command instead.

 ```python
 python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt

@@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
 Here are some optional parameters:

 ```bash
---device_target {Ascend,GPU,CPU}
+--device_target {Ascend,GPU}
                         device where the code will be implemented (default: Ascend)
 --data_path DATA_PATH
                         path where the dataset is saved

@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')

@@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')

@@ -61,7 +61,7 @@ if __name__ == "__main__":
     model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

     # load quantization aware network checkpoint
-    param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
+    param_dict = load_checkpoint(args.ckpt_path)
     load_param_into_net(network, param_dict)

     print("============== Starting Testing ==============")

@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')

@@ -56,8 +56,7 @@ if __name__ == "__main__":
     # call back and monitor
     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                   keep_checkpoint_max=cfg.keep_checkpoint_max,
-                                   model_type=network.type)
+                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

     # define model

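Checkpoint callbacks follow the same simplification: `CheckpointConfig` no longer takes a `model_type`. A hedged sketch of the resulting callback setup; the numeric values are placeholders for what the script derives from its config and dataset size:

```python
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Sketch only: save every 1875 steps and keep at most 10 files (placeholder values).
config_ckpt = CheckpointConfig(save_checkpoint_steps=1875,
                               keep_checkpoint_max=10)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
```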
@@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion

 parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
 parser.add_argument('--device_target', type=str, default="Ascend",
-                    choices=['Ascend', 'GPU', 'CPU'],
+                    choices=['Ascend', 'GPU'],
                     help='device where the code will be implemented (default: Ascend)')
 parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                     help='path where the dataset is saved')

@@ -50,11 +50,13 @@ if __name__ == "__main__":

     # define fusion network
     network = LeNet5Fusion(cfg.num_classes)

+    # convert fusion network to quantization aware network
+    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
+
     # load quantization aware network checkpoint
     param_dict = load_checkpoint(args.ckpt_path, network.type)
     load_param_into_net(network, param_dict)
-    # convert fusion network to quantization aware network
-    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

     # define network loss
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
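The reordering above is the script-side part of the fix: the fusion network is converted to its quantization aware form first, and only then is the quant checkpoint loaded, so the quant cells already exist when their parameters are restored. A hedged sketch of the corrected flow; the import paths follow the example scripts, and the class count and checkpoint path are placeholders:

```python
from src.lenet_fusion import LeNet5 as LeNet5Fusion
from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net

network = LeNet5Fusion(10)                      # fusion (fold-ready) network, 10 MNIST classes
network = quant.convert_quant_network(network,  # becomes quantization aware first
                                      quant_delay=0, bn_fold=False, freeze_bn=10000)
param_dict = load_checkpoint("checkpoint_lenet.ckpt")
load_param_into_net(network, param_dict)        # quant cells now receive their values
```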
@@ -64,8 +66,7 @@ if __name__ == "__main__":
     # call back and monitor
     time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
     config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                   keep_checkpoint_max=cfg.keep_checkpoint_max,
-                                   model_type="quant")
+                                   keep_checkpoint_max=cfg.keep_checkpoint_max)
     ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

     # define model

@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
 export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1
-if [ -d "eval" ];
+if [ -d "../eval" ];
 then
     rm -rf ../eval
 fi

@@ -62,7 +62,7 @@ run_gpu()
     BASEPATH=$(cd "`dirname $0`" || exit; pwd)
     export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-    if [ -d "train" ];
+    if [ -d "../train" ];
     then
         rm -rf ../train
     fi

@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
 export DEVICE_ID=0
 export RANK_ID=0
 export RANK_SIZE=1
-if [ -d "eval" ];
+if [ -d "../eval" ];
 then
     rm -rf ../eval
 fi

@@ -60,7 +60,7 @@ run_gpu()
     BASEPATH=$(cd "`dirname $0`" || exit; pwd)
     export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-    if [ -d "train" ];
+    if [ -d "../train" ];
     then
         rm -rf ../train
     fi

@@ -31,7 +31,7 @@ def _conv_bn(in_channel,
                            out_channel,
                            kernel_size=ksize,
                            stride=stride,
-                           batchnorm=True)])
+                           has_bn=True)])


 class InvertedResidual(nn.Cell):

@@ -49,25 +49,25 @@ class InvertedResidual(nn.Cell):
                                3,
                                stride,
                                group=hidden_dim,
-                               batchnorm=True,
+                               has_bn=True,
                                activation='relu6'),
                 nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
-                               batchnorm=True)
+                               has_bn=True)
             ])
         else:
             self.conv = nn.SequentialCell([
                 nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
-                               batchnorm=True,
+                               has_bn=True,
                                activation='relu6'),
                 nn.Conv2dBnAct(hidden_dim,
                                hidden_dim,
                                3,
                                stride,
                                group=hidden_dim,
-                               batchnorm=True,
+                               has_bn=True,
                                activation='relu6'),
                 nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
-                               batchnorm=True)
+                               has_bn=True)
             ])
         self.add = P.TensorAdd()

@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
         self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
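The model definitions above only rename a keyword: the fused convolution block now enables its BatchNorm through `has_bn` instead of `batchnorm`. A short sketch of constructing such blocks with the new spelling; the channel sizes mirror the LeNet5 example and are otherwise placeholders:

```python
import mindspore.nn as nn

# 1 -> 6 channels with a 5x5 kernel, BatchNorm enabled via the renamed argument.
conv_block = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True,
                            activation='relu6', pad_mode="valid")
fc_block = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
```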