forked from mindspore-Ecosystem/mindspore
!2718 fix quantization aware training auto create graph bug
Merge pull request !2718 from chenzhongming/master
This commit is contained in:
commit
f1a9a7ceb1
|
@ -22,7 +22,6 @@ message Checkpoint {
|
|||
required TensorProto tensor = 2;
|
||||
}
|
||||
repeated Value value = 1;
|
||||
required string model_type = 2;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -81,6 +81,7 @@ class Cell:
|
|||
self.enable_hook = False
|
||||
self._bprop_debug = False
|
||||
self._is_run = False
|
||||
self.cell_type = None
|
||||
|
||||
@property
|
||||
def is_run(self):
|
||||
|
@ -140,6 +141,14 @@ class Cell:
|
|||
for cell_name, cell in cells_name:
|
||||
cell._param_prefix = cell_name
|
||||
|
||||
def update_cell_type(self, cell_type):
|
||||
"""
|
||||
Update current cell type mainly identify if quantization aware training network.
|
||||
|
||||
After invoked, can set the cell type to 'cell_type'.
|
||||
"""
|
||||
self.cell_type = cell_type
|
||||
|
||||
@cell_init_args.setter
|
||||
def cell_init_args(self, value):
|
||||
if not isinstance(value, str):
|
||||
|
|
|
@ -17,11 +17,12 @@ from mindspore import log as logger
|
|||
from mindspore.ops import operations as P
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore._checkparam import ParamValidator as validator, Rel
|
||||
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
|
||||
from mindspore._extends import cell_attr_register
|
||||
from ..cell import Cell
|
||||
|
||||
__all__ = ['Conv2d', 'Conv2dTranspose']
|
||||
__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']
|
||||
|
||||
class _Conv(Cell):
|
||||
"""
|
||||
|
@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv):
|
|||
self.weight,
|
||||
self.bias)
|
||||
return s
|
||||
|
||||
|
||||
class DepthwiseConv2d(Cell):
|
||||
r"""
|
||||
2D depthwise convolution layer.
|
||||
|
||||
Applies a 2D depthwise convolution over an input tensor which is typically of shape:
|
||||
math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
|
||||
For each batch of shape:math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
|
||||
|
||||
where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
|
||||
from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
|
||||
filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
|
||||
of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
|
||||
:math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape
|
||||
:math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
|
||||
to split the input in the channel dimension.
|
||||
|
||||
If the 'pad_mode' is set to be "valid", the output height and width will be
|
||||
:math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
|
||||
(\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
|
||||
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
|
||||
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
|
||||
|
||||
The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
|
||||
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||
kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height
|
||||
and width of the 2D convolution window. Single int means the value if for both height and width of
|
||||
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
|
||||
width of the kernel.
|
||||
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
|
||||
the height and width of movement are both strides, or a tuple of two int numbers that
|
||||
represent height and width of movement respectively. Default: 1.
|
||||
pad_mode (str): Specifies padding mode. The optional values are
|
||||
"same", "valid", "pad". Default: "same".
|
||||
|
||||
- same: Adopts the way of completion. Output height and width will be the same as the input.
|
||||
Total number of padding will be calculated for horizontal and vertical
|
||||
direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the
|
||||
last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
|
||||
must be 0.
|
||||
|
||||
- valid: Adopts the way of discarding. The possibly largest height and width of output will be return
|
||||
without padding. Extra pixels will be discarded. If this mode is set, `padding`
|
||||
must be 0.
|
||||
|
||||
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
|
||||
Tensor borders. `padding` should be greater than or equal to 0.
|
||||
|
||||
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
||||
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
|
||||
to use for dilated convolution. If set to be :math:`k > 1`, there will
|
||||
be :math:`k - 1` pixels skipped for each sampling location. Its value should
|
||||
be greater or equal to 1 and bounded by the height and width of the
|
||||
input. Default: 1.
|
||||
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
|
||||
divisible by the number of groups. Default: 1.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
|
||||
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
|
||||
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
|
||||
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
|
||||
Initializer for more details. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
|
||||
Initializer and string are the same as 'weight_init'. Refer to the values of
|
||||
Initializer for more details. Default: 'zeros'.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.DepthwiseConv2d(120, 240, 4, has_bias=False, weight_init='normal')
|
||||
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
|
||||
>>> net(input).shape
|
||||
(1, 240, 1024, 640)
|
||||
"""
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
pad_mode='same',
|
||||
padding=0,
|
||||
dilation=1,
|
||||
group=1,
|
||||
has_bias=False,
|
||||
weight_init='normal',
|
||||
bias_init='zeros'):
|
||||
super(DepthwiseConv2d, self).__init__()
|
||||
self.kernel_size = twice(kernel_size)
|
||||
self.stride = twice(stride)
|
||||
self.dilation = twice(dilation)
|
||||
self.in_channels = check_int_positive(in_channels)
|
||||
self.out_channels = check_int_positive(out_channels)
|
||||
validator.check_integer('group', group, in_channels, Rel.EQ)
|
||||
validator.check_integer('group', group, out_channels, Rel.EQ)
|
||||
validator.check_integer('group', group, 1, Rel.GE)
|
||||
self.pad_mode = pad_mode
|
||||
self.padding = padding
|
||||
self.dilation = dilation
|
||||
self.group = group
|
||||
self.has_bias = has_bias
|
||||
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||
kernel_size=self.kernel_size,
|
||||
pad_mode=self.pad_mode,
|
||||
pad=self.padding,
|
||||
stride=self.stride,
|
||||
dilation=self.dilation)
|
||||
self.bias_add = P.BiasAdd()
|
||||
weight_shape = [1, in_channels, *self.kernel_size]
|
||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||
if check_bool(has_bias):
|
||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||
else:
|
||||
if bias_init != 'zeros':
|
||||
logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.")
|
||||
self.bias = None
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv(x, self.weight)
|
||||
if self.has_bias:
|
||||
out = self.bias_add(out, self.bias)
|
||||
return out
|
||||
|
||||
def extend_repr(self):
|
||||
s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
|
||||
'pad_mode={}, padding={}, dilation={}, group={},' \
|
||||
'has_bias={}, weight_init={}, bias_init={}'.format(
|
||||
self.in_channels, self.out_channels, self.kernel_size, self.stride,
|
||||
self.pad_mode, self.padding, self.dilation, self.group,
|
||||
self.has_bias, self.weight_init, self.bias_init)
|
||||
|
||||
if self.has_bias:
|
||||
s += ', bias={}'.format(self.bias)
|
||||
return s
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter
|
|||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore._checkparam import check_int_positive, check_bool, twice
|
||||
from mindspore._checkparam import Validator as validator, Rel
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.layer.activation import get_activation
|
||||
from mindspore._checkparam import Rel
|
||||
import mindspore.context as context
|
||||
|
||||
from .normalization import BatchNorm2d
|
||||
from .activation import get_activation
|
||||
from ..cell import Cell
|
||||
|
@ -82,7 +82,7 @@ class Conv2dBnAct(Cell):
|
|||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
|
||||
Initializer and string are the same as 'weight_init'. Refer to the values of
|
||||
Initializer for more details. Default: 'zeros'.
|
||||
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
|
||||
has_bn (bool): Specifies to used batchnorm or not. Default: False.
|
||||
activation (string): Specifies activation type. The optional values are as following:
|
||||
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
|
||||
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
|
||||
|
@ -94,7 +94,7 @@ class Conv2dBnAct(Cell):
|
|||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Examples:
|
||||
>>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU')
|
||||
>>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
|
||||
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
|
||||
>>> net(input).shape
|
||||
(1, 240, 1024, 640)
|
||||
|
@ -112,28 +112,39 @@ class Conv2dBnAct(Cell):
|
|||
has_bias=False,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
batchnorm=None,
|
||||
has_bn=False,
|
||||
activation=None):
|
||||
super(Conv2dBnAct, self).__init__()
|
||||
self.conv = conv.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
pad_mode,
|
||||
padding,
|
||||
dilation,
|
||||
group,
|
||||
has_bias,
|
||||
weight_init,
|
||||
bias_init)
|
||||
self.has_bn = batchnorm is not None
|
||||
|
||||
if context.get_context('device_target') == "Ascend" and group > 1:
|
||||
self.conv = conv.DepthwiseConv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad_mode=pad_mode,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
group=group,
|
||||
has_bias=has_bias,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init)
|
||||
else:
|
||||
self.conv = conv.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
pad_mode=pad_mode,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
group=group,
|
||||
has_bias=has_bias,
|
||||
weight_init=weight_init,
|
||||
bias_init=bias_init)
|
||||
|
||||
self.has_bn = validator.check_bool("has_bn", has_bn)
|
||||
self.has_act = activation is not None
|
||||
self.batchnorm = batchnorm
|
||||
if batchnorm is True:
|
||||
if has_bn:
|
||||
self.batchnorm = BatchNorm2d(out_channels)
|
||||
elif batchnorm is not None:
|
||||
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
|
||||
self.activation = get_activation(activation)
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -160,7 +171,7 @@ class DenseBnAct(Cell):
|
|||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
|
||||
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
|
||||
has_bn (bool): Specifies to used batchnorm or not. Default: False.
|
||||
activation (string): Specifies activation type. The optional values are as following:
|
||||
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
|
||||
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
|
||||
|
@ -183,7 +194,7 @@ class DenseBnAct(Cell):
|
|||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True,
|
||||
batchnorm=None,
|
||||
has_bn=False,
|
||||
activation=None):
|
||||
super(DenseBnAct, self).__init__()
|
||||
self.dense = basic.Dense(
|
||||
|
@ -192,12 +203,10 @@ class DenseBnAct(Cell):
|
|||
weight_init,
|
||||
bias_init,
|
||||
has_bias)
|
||||
self.has_bn = batchnorm is not None
|
||||
self.has_bn = validator.check_bool("has_bn", has_bn)
|
||||
self.has_act = activation is not None
|
||||
if batchnorm is True:
|
||||
if has_bn:
|
||||
self.batchnorm = BatchNorm2d(out_channels)
|
||||
elif batchnorm is not None:
|
||||
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
|
||||
self.activation = get_activation(activation)
|
||||
|
||||
def construct(self, x):
|
||||
|
@ -312,6 +321,10 @@ class FakeQuantWithMinMax(Cell):
|
|||
quant_delay=0):
|
||||
"""init FakeQuantWithMinMax layer"""
|
||||
super(FakeQuantWithMinMax, self).__init__()
|
||||
validator.check_type("min_init", min_init, [int, float])
|
||||
validator.check_type("max_init", max_init, [int, float])
|
||||
validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
|
||||
validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
|
||||
self.min_init = min_init
|
||||
self.max_init = max_init
|
||||
self.num_bits = num_bits
|
||||
|
@ -1183,12 +1196,13 @@ class QuantBlock(Cell):
|
|||
self.has_bias = bias is None
|
||||
self.activation = activation
|
||||
self.has_act = activation is None
|
||||
self.bias_add = P.BiasAdd()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.quant(x)
|
||||
x = self.core_op(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
x = self.bias_add(x, self.bias)
|
||||
if self.has_act:
|
||||
x = self.activation(x)
|
||||
x = self.dequant(x, self.dequant_scale)
|
||||
|
|
|
@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold2") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "x", None, "required", None) \
|
||||
.input(1, "beta", None, "required", None) \
|
||||
.input(2, "gamma", None, "required", None) \
|
||||
|
|
|
@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold2_grad") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.input(1, "dout_reduce", None, "required", None) \
|
||||
.input(2, "dout_x_reduce", None, "required", None) \
|
||||
|
|
|
@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("batchnorm_fold2_grad_reduce") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.input(1, "x", None, "required", None) \
|
||||
.output(0, "dout_reduce", True, "required", "all") \
|
||||
|
|
|
@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("correction_mul") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.attr("channel_axis", "optional", "int", "all") \
|
||||
.input(0, "x", None, "required", None) \
|
||||
.input(1, "batch_std", None, "required", None) \
|
||||
|
|
|
@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("correction_mul_grad") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.attr("channel_axis", "optional", "int", "all") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.input(1, "x", None, "required", None) \
|
||||
|
@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
|
|||
.compute_cost(10) \
|
||||
.kernel_name("correction_mul_grad_reduce") \
|
||||
.partial_flag(True) \
|
||||
.op_pattern("formatAgnostic") \
|
||||
.attr("channel_axis", "optional", "int", "all") \
|
||||
.input(0, "dout", None, "required", None) \
|
||||
.output(0, "d_batch_std", True, "required", "all") \
|
||||
|
|
|
@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
|
|||
min_dtype = min_val.get("dtype")
|
||||
max_shape = max_val.get("ori_shape")
|
||||
max_dtype = max_val.get("dtype")
|
||||
|
||||
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
|
||||
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
|
||||
channel_axis_ = 1
|
||||
else:
|
||||
channel_axis_ = channel_axis
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(x_shape)
|
||||
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
|
||||
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
|
||||
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
|
||||
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
|
||||
util.check_tensor_shape_size(x_shape)
|
||||
util.check_tensor_shape_size(min_shape)
|
||||
util.check_tensor_shape_size(max_shape)
|
||||
|
@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
|
|||
quant_min = quant_min + 1
|
||||
|
||||
shape_c = [1] * len(x_shape)
|
||||
shape_c[channel_axis] = min_val.get("ori_shape")[0]
|
||||
if x_format == "NC1HWC0" and channel_axis == 1:
|
||||
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
|
||||
if x_format == "NC1HWC0" and channel_axis_ == 1:
|
||||
shape_c = min_val.get("shape")
|
||||
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
|
||||
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
|
||||
|
|
|
@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
|
|||
min_dtype = min_val.get("dtype")
|
||||
max_shape = max_val.get("ori_shape")
|
||||
max_dtype = max_val.get("dtype")
|
||||
|
||||
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
|
||||
if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
|
||||
channel_axis_ = 1
|
||||
else:
|
||||
channel_axis_ = channel_axis
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(x_shape)
|
||||
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
|
||||
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
|
||||
util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
|
||||
util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
|
||||
util.check_tensor_shape_size(x_shape)
|
||||
util.check_tensor_shape_size(min_shape)
|
||||
util.check_tensor_shape_size(max_shape)
|
||||
|
@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
|
|||
quant_min = quant_min + 1
|
||||
|
||||
shape_c = [1] * len(x_shape)
|
||||
shape_c[channel_axis] = min_val.get("ori_shape")[0]
|
||||
if x_format == "NC1HWC0" and channel_axis == 1:
|
||||
shape_c[channel_axis_] = min_val.get("ori_shape")[0]
|
||||
if x_format == "NC1HWC0" and channel_axis_ == 1:
|
||||
shape_c = min_val.get("shape")
|
||||
dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
|
||||
input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
|
||||
|
|
|
@ -88,11 +88,15 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
|
|||
min_dtype = min_val.get("dtype")
|
||||
max_shape = max_val.get("ori_shape")
|
||||
max_dtype = max_val.get("dtype")
|
||||
|
||||
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
|
||||
if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
|
||||
channel_axis_ = 1
|
||||
else:
|
||||
channel_axis_ = channel_axis
|
||||
util.check_kernel_name(kernel_name)
|
||||
util.check_shape_rule(x_shape)
|
||||
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
|
||||
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
|
||||
util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
|
||||
util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
|
||||
util.check_tensor_shape_size(x_shape)
|
||||
util.check_tensor_shape_size(min_shape)
|
||||
util.check_tensor_shape_size(max_shape)
|
||||
|
@ -105,7 +109,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
|
|||
util.check_dtype_rule(min_dtype, check_list)
|
||||
util.check_dtype_rule(max_dtype, check_list)
|
||||
|
||||
if channel_axis == 0:
|
||||
if channel_axis_ == 0:
|
||||
shape_c = min_val.get("ori_shape")
|
||||
else:
|
||||
shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
|
||||
|
@ -113,7 +117,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
|
|||
min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
|
||||
max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
|
||||
res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
|
||||
ema, ema_decay, channel_axis)
|
||||
ema, ema_decay, channel_axis_)
|
||||
|
||||
with tvm.target.cce():
|
||||
sch = generic.auto_schedule(res_list)
|
||||
|
|
|
@ -106,7 +106,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
|||
Args:
|
||||
ema (bool): Use EMA algorithm update value min and max. Default: False.
|
||||
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
|
||||
channel_axis (int): Channel asis for per channel compute. Default: 1.
|
||||
channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
|
||||
|
@ -123,11 +123,13 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
|||
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
|
||||
"""
|
||||
support_quant_bit = [4, 7, 8]
|
||||
ascend_support_x_rank = [2, 4]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
|
||||
"""init FakeQuantPerChannelUpdate OP for Ascend"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
self.is_ascend = context.get_context('device_target') == "Ascend"
|
||||
if self.is_ascend:
|
||||
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
|
||||
if ema and not ema_decay:
|
||||
raise ValueError(
|
||||
|
@ -136,13 +138,18 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
|
|||
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
|
||||
self.ema_decay = validator.check_number_range(
|
||||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.channel_axis = validator.check_integer(
|
||||
'channel axis', channel_axis, 0, Rel.GE, self.name)
|
||||
if self.is_ascend:
|
||||
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
|
||||
else:
|
||||
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
|
||||
self.init_prim_io_names(
|
||||
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
|
||||
|
||||
def infer_shape(self, x_shape, min_shape, max_shape):
|
||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
|
||||
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
|
||||
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
|
||||
if not self.is_ascend:
|
||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
|
||||
validator.check("min shape", min_shape, "max shape",
|
||||
max_shape, Rel.EQ, self.name)
|
||||
validator.check_integer("min shape", len(
|
||||
|
@ -221,8 +228,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
|
|||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.num_bits = validator.check_integer(
|
||||
'num_bits', num_bits, 0, Rel.GT, self.name)
|
||||
self.quant_delay = validator.check_value_type(
|
||||
'quant_delay', quant_delay, (int,), self.name)
|
||||
self.quant_delay = validator.check_integer(
|
||||
'quant_delay', quant_delay, 0, Rel.GE, self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'],
|
||||
outputs=['out'])
|
||||
|
||||
|
@ -314,6 +321,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
|
||||
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
|
||||
training (bool): Training the network or not. Default: True.
|
||||
channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
|
||||
|
@ -331,6 +339,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
>>> result = fake_quant(input_x, _min, _max)
|
||||
"""
|
||||
support_quant_bit = [4, 7, 8]
|
||||
ascend_support_x_rank = [2, 4]
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self,
|
||||
|
@ -343,7 +352,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
training=True,
|
||||
channel_axis=1):
|
||||
"""init FakeQuantPerChannel OP"""
|
||||
if context.get_context('device_target') == "Ascend":
|
||||
self.is_ascend = context.get_context('device_target') == "Ascend"
|
||||
if self.is_ascend:
|
||||
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
|
||||
if num_bits not in self.support_quant_bit:
|
||||
raise ValueError(
|
||||
|
@ -363,14 +373,19 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.num_bits = validator.check_integer(
|
||||
'num_bits', num_bits, 0, Rel.GT, self.name)
|
||||
self.quant_delay = validator.check_value_type(
|
||||
'quant_delay', quant_delay, (int,), self.name)
|
||||
self.channel_axis = validator.check_integer(
|
||||
'channel_axis', channel_axis, 0, Rel.GE, self.name)
|
||||
self.quant_delay = validator.check_integer(
|
||||
'quant_delay', quant_delay, 0, Rel.GE, self.name)
|
||||
if self.is_ascend:
|
||||
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
|
||||
else:
|
||||
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
|
||||
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
|
||||
|
||||
def infer_shape(self, x_shape, min_shape, max_shape):
|
||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
|
||||
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
|
||||
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
|
||||
if not self.is_ascend:
|
||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
|
||||
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
|
||||
validator.check_integer(
|
||||
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
|
||||
|
|
|
@ -21,7 +21,7 @@ import time
|
|||
|
||||
import mindspore.context as context
|
||||
from mindspore import log as logger
|
||||
from mindspore._checkparam import check_bool, check_string, check_int_non_negative
|
||||
from mindspore._checkparam import check_bool, check_int_non_negative
|
||||
from mindspore.train._utils import _make_directory
|
||||
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
|
||||
from ._callback import Callback, set_cur_net
|
||||
|
@ -86,7 +86,6 @@ class CheckpointConfig:
|
|||
Can't be used with keep_checkpoint_max at the same time.
|
||||
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
|
||||
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
|
||||
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||
|
||||
Raises:
|
||||
ValueError: If the input_param is None or 0.
|
||||
|
@ -101,8 +100,7 @@ class CheckpointConfig:
|
|||
save_checkpoint_seconds=0,
|
||||
keep_checkpoint_max=5,
|
||||
keep_checkpoint_per_n_minutes=0,
|
||||
integrated_save=True,
|
||||
model_type="normal"):
|
||||
integrated_save=True):
|
||||
|
||||
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
||||
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
||||
|
@ -116,8 +114,6 @@ class CheckpointConfig:
|
|||
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
|
||||
if keep_checkpoint_per_n_minutes:
|
||||
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
|
||||
if model_type:
|
||||
model_type = check_string(model_type, ["normal", "fusion", "quant"])
|
||||
|
||||
self._save_checkpoint_steps = save_checkpoint_steps
|
||||
self._save_checkpoint_seconds = save_checkpoint_seconds
|
||||
|
@ -132,7 +128,6 @@ class CheckpointConfig:
|
|||
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
|
||||
self._keep_checkpoint_max = 1
|
||||
|
||||
self._model_type = model_type
|
||||
self._integrated_save = check_bool(integrated_save)
|
||||
|
||||
@property
|
||||
|
@ -160,18 +155,12 @@ class CheckpointConfig:
|
|||
"""Get the value of _integrated_save."""
|
||||
return self._integrated_save
|
||||
|
||||
@property
|
||||
def model_type(self):
|
||||
"""Get the value of model_type."""
|
||||
return self._model_type
|
||||
|
||||
def get_checkpoint_policy(self):
|
||||
"""Get the policy of checkpoint."""
|
||||
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
|
||||
'save_checkpoint_seconds': self._save_checkpoint_seconds,
|
||||
'keep_checkpoint_max': self._keep_checkpoint_max,
|
||||
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
|
||||
'model_type': self._model_type}
|
||||
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
|
||||
|
||||
return checkpoint_policy
|
||||
|
||||
|
@ -236,7 +225,7 @@ class ModelCheckpoint(Callback):
|
|||
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
|
||||
_save_graph(cb_params.train_network, graph_file_name)
|
||||
self._graph_saved = True
|
||||
self._save_ckpt(cb_params, self._config.model_type)
|
||||
self._save_ckpt(cb_params)
|
||||
|
||||
def end(self, run_context):
|
||||
"""
|
||||
|
@ -247,7 +236,7 @@ class ModelCheckpoint(Callback):
|
|||
"""
|
||||
cb_params = run_context.original_args()
|
||||
_to_save_last_ckpt = True
|
||||
self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
|
||||
self._save_ckpt(cb_params, _to_save_last_ckpt)
|
||||
|
||||
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
|
||||
destroy_allgather_cell()
|
||||
|
@ -266,7 +255,7 @@ class ModelCheckpoint(Callback):
|
|||
|
||||
return False
|
||||
|
||||
def _save_ckpt(self, cb_params, model_type, force_to_save=False):
|
||||
def _save_ckpt(self, cb_params, force_to_save=False):
|
||||
"""Save checkpoint files."""
|
||||
if cb_params.cur_step_num == self._last_triggered_step:
|
||||
return
|
||||
|
@ -302,7 +291,7 @@ class ModelCheckpoint(Callback):
|
|||
set_cur_net(cb_params.train_network)
|
||||
cb_params.train_network.exec_checkpoint_graph()
|
||||
|
||||
_exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
|
||||
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
|
||||
|
||||
if os.path.exists(gen_file):
|
||||
shutil.move(gen_file, cur_file)
|
||||
|
|
|
@ -86,7 +86,7 @@ class LossMonitor(Callback):
|
|||
|
||||
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
|
||||
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
|
||||
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
|
||||
"loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
|
||||
cb_params.cur_epoch_num, cb_params.epoch_num,
|
||||
cur_step_in_epoch, int(cb_params.batch_num),
|
||||
step_loss, np.mean(self.losses),
|
||||
|
|
|
@ -42,15 +42,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
|
|||
|
||||
class _AddFakeQuantInput(nn.Cell):
|
||||
"""
|
||||
Add FakeQuant at input and output of the Network. Only support one input and one output case.
|
||||
Add FakeQuant OP at input of the network. Only support one input case.
|
||||
"""
|
||||
|
||||
def __init__(self, network, quant_delay=0):
|
||||
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
|
||||
self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_input.update_parameters_name('fake_quant_input.')
|
||||
self.network = network
|
||||
self.fake_quant_input = quant.FakeQuantWithMinMax(
|
||||
min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
|
||||
self.fake_quant_input.update_parameters_name('fake_quant_input')
|
||||
|
||||
def construct(self, data):
|
||||
data = self.fake_quant_input(data)
|
||||
|
@ -60,7 +59,7 @@ class _AddFakeQuantInput(nn.Cell):
|
|||
|
||||
class _AddFakeQuantAfterSubCell(nn.Cell):
|
||||
"""
|
||||
Add FakeQuant after of the sub Cell.
|
||||
Add FakeQuant OP after of the sub Cell.
|
||||
"""
|
||||
|
||||
def __init__(self, subcell, **kwargs):
|
||||
|
@ -115,11 +114,12 @@ class ConvertToQuantNetwork:
|
|||
self.network.update_cell_prefix()
|
||||
network = self._convert_subcells2quant(self.network)
|
||||
network = _AddFakeQuantInput(network)
|
||||
self.network.update_cell_type("quant")
|
||||
return network
|
||||
|
||||
def _convert_subcells2quant(self, network):
|
||||
"""
|
||||
convet sub cell to quant cell
|
||||
convert sub cell like `Conv2dBnAct` and `DenseBnAct` to quant cell
|
||||
"""
|
||||
cells = network.name_cells()
|
||||
change = False
|
||||
|
@ -138,13 +138,13 @@ class ConvertToQuantNetwork:
|
|||
if isinstance(network, nn.SequentialCell) and change:
|
||||
network.cell_list = list(network.cells())
|
||||
|
||||
# tensoradd to tensoradd quant
|
||||
# add FakeQuant OP after OP in while list
|
||||
add_list = []
|
||||
for name in network.__dict__:
|
||||
if name[0] == '_':
|
||||
continue
|
||||
attr = network.__dict__[name]
|
||||
if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
|
||||
if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
|
||||
add_list.append((name, attr))
|
||||
for name, prim_op in add_list:
|
||||
prefix = name
|
||||
|
@ -164,11 +164,11 @@ class ConvertToQuantNetwork:
|
|||
|
||||
def _convert_conv(self, subcell):
|
||||
"""
|
||||
convet conv cell to quant cell
|
||||
convert Conv2d cell to quant cell
|
||||
"""
|
||||
conv_inner = subcell.conv
|
||||
bn_inner = subcell.batchnorm
|
||||
if subcell.batchnorm is not None and self.bn_fold:
|
||||
if subcell.has_bn and self.bn_fold:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
|
@ -178,7 +178,7 @@ class ConvertToQuantNetwork:
|
|||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
momentum=bn_inner.momentum,
|
||||
momentum=1 - bn_inner.momentum,
|
||||
quant_delay=self.weight_qdelay,
|
||||
freeze_bn=self.freeze_bn,
|
||||
per_channel=self.weight_channel,
|
||||
|
@ -186,6 +186,11 @@ class ConvertToQuantNetwork:
|
|||
fake=True,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.beta = subcell.batchnorm.beta
|
||||
conv_inner.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
|
@ -204,6 +209,10 @@ class ConvertToQuantNetwork:
|
|||
num_bits=self.weight_bits,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
# change original network Conv2D OP parameters to quant network
|
||||
conv_inner.weight = subcell.conv.weight
|
||||
if subcell.conv.has_bias:
|
||||
conv_inner.bias = subcell.conv.bias
|
||||
subcell.conv = conv_inner
|
||||
if subcell.has_act and subcell.activation is not None:
|
||||
subcell.activation = self._convert_activation(subcell.activation)
|
||||
|
@ -230,6 +239,10 @@ class ConvertToQuantNetwork:
|
|||
per_channel=self.weight_channel,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
# change original network Dense OP parameters to quant network
|
||||
dense_inner.weight = subcell.dense.weight
|
||||
if subcell.dense.has_bias:
|
||||
dense_inner.bias = subcell.dense.bias
|
||||
subcell.dense = dense_inner
|
||||
if subcell.has_act and subcell.activation is not None:
|
||||
subcell.activation = self._convert_activation(subcell.activation)
|
||||
|
@ -247,12 +260,12 @@ class ConvertToQuantNetwork:
|
|||
act_class = activation.__class__
|
||||
if act_class not in _ACTIVATION_MAP:
|
||||
raise ValueError(
|
||||
"Unsupported activation in auto Quant: ", act_class)
|
||||
"Unsupported activation in auto quant: ", act_class)
|
||||
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
symmetric=self.act_symmetric,
|
||||
narrow_range=self.act_range)
|
||||
|
||||
|
||||
class ExportQuantNetworkDeploy:
|
||||
|
|
|
@ -40,8 +40,6 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
|||
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
||||
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
||||
|
||||
ModelType = ["normal", "fusion", "quant"]
|
||||
|
||||
|
||||
def _special_process_par(par, new_par):
|
||||
"""
|
||||
|
@ -103,7 +101,7 @@ def _update_param(param, new_param):
|
|||
param.set_parameter_data(type(param.data)(new_param.data))
|
||||
|
||||
|
||||
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
|
||||
def save_checkpoint(parameter_list, ckpt_file_name):
|
||||
"""
|
||||
Saves checkpoint info to a specified file.
|
||||
|
||||
|
@ -111,14 +109,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
|
|||
parameter_list (list): Parameters list, each element is a dict
|
||||
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
||||
ckpt_file_name (str): Checkpoint file name.
|
||||
model_type (str): The name of model type. Default: "normal".
|
||||
|
||||
Raises:
|
||||
RuntimeError: Failed to save the Checkpoint file.
|
||||
"""
|
||||
logger.info("Execute save checkpoint process.")
|
||||
checkpoint_list = Checkpoint()
|
||||
checkpoint_list.model_type = model_type
|
||||
|
||||
try:
|
||||
for param in parameter_list:
|
||||
|
@ -147,13 +143,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
|
|||
logger.info("Save checkpoint process finish.")
|
||||
|
||||
|
||||
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
|
||||
def load_checkpoint(ckpt_file_name, net=None):
|
||||
"""
|
||||
Loads checkpoint info from a specified file.
|
||||
|
||||
Args:
|
||||
ckpt_file_name (str): Checkpoint file name.
|
||||
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||
net (Cell): Cell network. Default: None
|
||||
|
||||
Returns:
|
||||
|
@ -165,9 +160,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
|
|||
if not isinstance(ckpt_file_name, str):
|
||||
raise ValueError("The ckpt_file_name must be string.")
|
||||
|
||||
if model_type not in ModelType:
|
||||
raise ValueError(f"The model_type is not in {ModelType}.")
|
||||
|
||||
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
|
||||
raise ValueError("Please input the correct checkpoint file name.")
|
||||
|
||||
|
@ -186,10 +178,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
|
|||
raise ValueError(e.__str__())
|
||||
|
||||
parameter_dict = {}
|
||||
if checkpoint_list.model_type:
|
||||
if model_type != checkpoint_list.model_type:
|
||||
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
|
||||
checkpoint_list.model_type, model_type))
|
||||
try:
|
||||
for element in checkpoint_list.value:
|
||||
data = element.tensor.tensor_content
|
||||
|
@ -314,14 +302,13 @@ def _save_graph(network, file_name):
|
|||
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
||||
|
||||
|
||||
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
|
||||
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
||||
"""
|
||||
Saves checkpoint for 'ms' backend.
|
||||
|
||||
Args:
|
||||
train_network (Network): The train network for training.
|
||||
ckpt_file_name (str): The name of checkpoint file.
|
||||
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
||||
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
|
||||
"""
|
||||
|
||||
|
@ -346,7 +333,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", in
|
|||
each_param["data"] = param_data
|
||||
param_list.append(each_param)
|
||||
|
||||
save_checkpoint(param_list, ckpt_file_name, model_type)
|
||||
save_checkpoint(param_list, ckpt_file_name)
|
||||
|
||||
|
||||
def _get_merged_param_data(net, param_name, param_data):
|
||||
|
|
|
@ -33,7 +33,7 @@ Then you will get the following display
|
|||
```bash
|
||||
>>> Found existing installation: mindspore-ascend
|
||||
>>> Uninstalling mindspore-ascend:
|
||||
>>> Successfully uninstalled mindspore-ascend.
|
||||
>>> Successfully uninstalled mindspore-ascend.
|
||||
```
|
||||
|
||||
### Prepare Dataset
|
||||
|
@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
|||
|
||||
### train quantization aware model
|
||||
|
||||
Also, you can just run this command instread.
|
||||
Also, you can just run this command instead.
|
||||
|
||||
```python
|
||||
python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
|
||||
|
@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
|
|||
Here are some optional parameters:
|
||||
|
||||
```bash
|
||||
--device_target {Ascend,GPU,CPU}
|
||||
--device_target {Ascend,GPU}
|
||||
device where the code will be implemented (default: Ascend)
|
||||
--data_path DATA_PATH
|
||||
path where the dataset is saved
|
||||
|
|
|
@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
|||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU', 'CPU'],
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
|
||||
help='path where the dataset is saved')
|
||||
|
|
|
@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
|||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU', 'CPU'],
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
|
||||
help='path where the dataset is saved')
|
||||
|
@ -61,7 +61,7 @@ if __name__ == "__main__":
|
|||
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
|
||||
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
|
||||
param_dict = load_checkpoint(args.ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
print("============== Starting Testing ==============")
|
||||
|
|
|
@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
|||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU', 'CPU'],
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
|
||||
help='path where the dataset is saved')
|
||||
|
@ -56,8 +56,7 @@ if __name__ == "__main__":
|
|||
# call back and monitor
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max,
|
||||
model_type=network.type)
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||
|
||||
# define model
|
||||
|
|
|
@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
|
|||
|
||||
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend', 'GPU', 'CPU'],
|
||||
choices=['Ascend', 'GPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
|
||||
help='path where the dataset is saved')
|
||||
|
@ -50,11 +50,13 @@ if __name__ == "__main__":
|
|||
|
||||
# define fusion network
|
||||
network = LeNet5Fusion(cfg.num_classes)
|
||||
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
# load quantization aware network checkpoint
|
||||
param_dict = load_checkpoint(args.ckpt_path, network.type)
|
||||
load_param_into_net(network, param_dict)
|
||||
# convert fusion network to quantization aware network
|
||||
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
|
||||
|
||||
# define network loss
|
||||
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
|
||||
|
@ -64,8 +66,7 @@ if __name__ == "__main__":
|
|||
# call back and monitor
|
||||
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max,
|
||||
model_type="quant")
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
|
||||
|
||||
# define model
|
||||
|
|
|
@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
|||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
if [ -d "eval" ];
|
||||
if [ -d "../eval" ];
|
||||
then
|
||||
rm -rf ../eval
|
||||
fi
|
||||
|
|
|
@ -62,7 +62,7 @@ run_gpu()
|
|||
|
||||
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
||||
if [ -d "train" ];
|
||||
if [ -d "../train" ];
|
||||
then
|
||||
rm -rf ../train
|
||||
fi
|
||||
|
|
|
@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
|||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
if [ -d "eval" ];
|
||||
if [ -d "../eval" ];
|
||||
then
|
||||
rm -rf ../eval
|
||||
fi
|
||||
|
|
|
@ -60,7 +60,7 @@ run_gpu()
|
|||
|
||||
BASEPATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
|
||||
if [ -d "train" ];
|
||||
if [ -d "../train" ];
|
||||
then
|
||||
rm -rf ../train
|
||||
fi
|
||||
|
|
|
@ -31,7 +31,7 @@ def _conv_bn(in_channel,
|
|||
out_channel,
|
||||
kernel_size=ksize,
|
||||
stride=stride,
|
||||
batchnorm=True)])
|
||||
has_bn=True)])
|
||||
|
||||
|
||||
class InvertedResidual(nn.Cell):
|
||||
|
@ -49,25 +49,25 @@ class InvertedResidual(nn.Cell):
|
|||
3,
|
||||
stride,
|
||||
group=hidden_dim,
|
||||
batchnorm=True,
|
||||
has_bn=True,
|
||||
activation='relu6'),
|
||||
nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
|
||||
batchnorm=True)
|
||||
has_bn=True)
|
||||
])
|
||||
else:
|
||||
self.conv = nn.SequentialCell([
|
||||
nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
|
||||
batchnorm=True,
|
||||
has_bn=True,
|
||||
activation='relu6'),
|
||||
nn.Conv2dBnAct(hidden_dim,
|
||||
hidden_dim,
|
||||
3,
|
||||
stride,
|
||||
group=hidden_dim,
|
||||
batchnorm=True,
|
||||
has_bn=True,
|
||||
activation='relu6'),
|
||||
nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
|
||||
batchnorm=True)
|
||||
has_bn=True)
|
||||
])
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
|
|||
def __init__(self, num_class=10):
|
||||
super(LeNet5, self).__init__()
|
||||
self.num_class = num_class
|
||||
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
|
||||
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
|
||||
self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
|
||||
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
|
||||
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
|
||||
|
|
Loading…
Reference in New Issue