forked from mindspore-Ecosystem/mindspore
!3457 add LeakyReLUQuant OP for bug fix.
Merge pull request !3457 from chenzhongming/master
This commit is contained in:
commit
521f1e5938
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Quantization aware."""
|
||||
"""Quantization aware training."""
|
||||
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
@ -43,6 +43,7 @@ __all__ = [
|
|||
'Conv2dQuant',
|
||||
'DenseQuant',
|
||||
'ActQuant',
|
||||
'LeakyReLUQuant',
|
||||
'HSwishQuant',
|
||||
'HSigmoidQuant',
|
||||
'TensorAddQuant',
|
||||
|
@ -349,7 +350,7 @@ class FakeQuantWithMinMax(Cell):
|
|||
self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
|
||||
|
||||
# init fake quant relative op
|
||||
if per_channel:
|
||||
if self.per_channel:
|
||||
quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
|
||||
ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
|
||||
else:
|
||||
|
@ -369,7 +370,7 @@ class FakeQuantWithMinMax(Cell):
|
|||
num_bits=self.num_bits,
|
||||
symmetric=self.symmetric,
|
||||
narrow_range=self.narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
quant_delay=self.quant_delay)
|
||||
self.fake_quant_train = quant_fun(training=True)
|
||||
self.fake_quant_infer = quant_fun(training=False)
|
||||
|
||||
|
@ -832,7 +833,7 @@ class ActQuant(_QuantActivation):
|
|||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Examples:
|
||||
>>> act_quant = nn.ActQuant(4, 1)
|
||||
>>> act_quant = nn.ActQuant(nn.ReLU)
|
||||
>>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> result = act_quant(input_x)
|
||||
"""
|
||||
|
@ -855,7 +856,7 @@ class ActQuant(_QuantActivation):
|
|||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
self.act = activation()
|
||||
self.act = activation
|
||||
|
||||
def construct(self, x):
|
||||
x = self.act(x)
|
||||
|
@ -865,6 +866,75 @@ class ActQuant(_QuantActivation):
|
|||
def get_origin(self):
|
||||
return self.act
|
||||
|
||||
class LeakyReLUQuant(_QuantActivation):
    r"""
    LeakyReLUQuant activation function. Add Fake Quant OP before and after LeakyReLU OP.

    The input is fake-quantized, passed through the wrapped LeakyReLU cell, and
    the activation output is fake-quantized again.

    Args:
        activation (Cell): Activation cell instance. Must be an instance of `nn.LeakyReLU`.
        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
        quant_delay (int): Quantization delay parameters according by global step. Default: 0.

    Inputs:
        - **x** (Tensor) - The input of LeakyReLUQuant.

    Outputs:
        Tensor, with the same type and shape as the `x`.

    Raises:
        ValueError: If `activation` is not an instance of `nn.LeakyReLU`.

    Examples:
        >>> activation = nn.LeakyReLUQuant(nn.LeakyReLU())
        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
        >>> result = activation(input_x)
    """

    def __init__(self,
                 activation,
                 ema_decay=0.999,
                 per_channel=False,
                 num_bits=8,
                 symmetric=False,
                 narrow_range=False,
                 quant_delay=0):
        super(LeakyReLUQuant, self).__init__()
        # Validate first so an unsupported activation fails before any
        # fake-quant cells are constructed. A class object (rather than an
        # instance) is rejected here as well, exactly as before.
        if not isinstance(activation, nn.LeakyReLU):
            raise ValueError("Activation should be `nn.LeakyReLU`")

        # Both fake-quant cells share one configuration; the observed range
        # is initialised to [-6, 6] and tracked with EMA.
        fake_quant_args = dict(min_init=-6,
                               max_init=6,
                               ema=True,
                               ema_decay=ema_decay,
                               per_channel=per_channel,
                               num_bits=num_bits,
                               symmetric=symmetric,
                               narrow_range=narrow_range,
                               quant_delay=quant_delay)
        self.fake_quant_act_before = FakeQuantWithMinMax(**fake_quant_args)
        self.fake_quant_act_after = FakeQuantWithMinMax(**fake_quant_args)
        self.act = activation

    def construct(self, x):
        # Quantize the input, apply LeakyReLU, then quantize the result.
        x = self.fake_quant_act_before(x)
        x = self.act(x)
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        """Return the wrapped (unquantized) LeakyReLU cell."""
        return self.act
|
||||
|
||||
|
||||
|
||||
class HSwishQuant(_QuantActivation):
|
||||
r"""
|
||||
|
@ -888,9 +958,9 @@ class HSwishQuant(_QuantActivation):
|
|||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Examples:
|
||||
>>> hswish_quant = nn.HSwishQuant(4, 1)
|
||||
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> result = hswish_quant(input_x)
|
||||
>>> activation = nn.HSwishQuant(nn.HSwish())
|
||||
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> result = activation(input)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -920,8 +990,8 @@ class HSwishQuant(_QuantActivation):
|
|||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
if issubclass(activation, nn.HSwish):
|
||||
self.act = activation()
|
||||
if issubclass(activation.__class__, nn.HSwish):
|
||||
self.act = activation
|
||||
else:
|
||||
raise ValueError("Activation should be `nn.HSwish`")
|
||||
|
||||
|
@ -957,9 +1027,9 @@ class HSigmoidQuant(_QuantActivation):
|
|||
Tensor, with the same type and shape as the `x`.
|
||||
|
||||
Examples:
|
||||
>>> hsigmoid_quant = nn.HSigmoidQuant(4, 1)
|
||||
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> result = hsigmoid_quant(input_x)
|
||||
>>> activation = nn.HSigmoidQuant(nn.HSigmoid())
|
||||
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
|
||||
>>> result = activation(input)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -989,8 +1059,8 @@ class HSigmoidQuant(_QuantActivation):
|
|||
symmetric=symmetric,
|
||||
narrow_range=narrow_range,
|
||||
quant_delay=quant_delay)
|
||||
if issubclass(activation, nn.HSigmoid):
|
||||
self.act = activation()
|
||||
if issubclass(activation.__class__, nn.HSigmoid):
|
||||
self.act = activation
|
||||
else:
|
||||
raise ValueError("Activation should be `nn.HSigmoid`")
|
||||
|
||||
|
|
|
@ -386,6 +386,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
|
|||
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
|
||||
if not self.is_ascend:
|
||||
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
|
||||
if len(x_shape) == 1:
|
||||
self.channel_axis = 0
|
||||
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
|
||||
validator.check_integer(
|
||||
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
|
||||
|
|
|
@ -35,8 +35,8 @@ from . import quant_utils
|
|||
|
||||
_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
|
||||
nn.ReLU6: quant.ActQuant,
|
||||
nn.LeakyReLU: quant.ActQuant,
|
||||
nn.Sigmoid: quant.ActQuant,
|
||||
nn.LeakyReLU: quant.LeakyReLUQuant,
|
||||
nn.HSigmoid: quant.HSigmoidQuant,
|
||||
nn.HSwish: quant.HSwishQuant}
|
||||
|
||||
|
@ -167,32 +167,35 @@ class ConvertToQuantNetwork:
|
|||
convert Conv2d cell to quant cell
|
||||
"""
|
||||
conv_inner = subcell.conv
|
||||
if subcell.has_bn and self.bn_fold:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
quant_delay=self.weight_qdelay,
|
||||
freeze_bn=self.freeze_bn,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
fake=True,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.beta = subcell.batchnorm.beta
|
||||
conv_inner.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
if subcell.has_bn:
|
||||
if self.bn_fold:
|
||||
bn_inner = subcell.batchnorm
|
||||
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
kernel_size=conv_inner.kernel_size,
|
||||
stride=conv_inner.stride,
|
||||
pad_mode=conv_inner.pad_mode,
|
||||
padding=conv_inner.padding,
|
||||
dilation=conv_inner.dilation,
|
||||
group=conv_inner.group,
|
||||
eps=bn_inner.eps,
|
||||
quant_delay=self.weight_qdelay,
|
||||
freeze_bn=self.freeze_bn,
|
||||
per_channel=self.weight_channel,
|
||||
num_bits=self.weight_bits,
|
||||
fake=True,
|
||||
symmetric=self.weight_symmetric,
|
||||
narrow_range=self.weight_range)
|
||||
# change original network BatchNormal OP parameters to quant network
|
||||
conv_inner.gamma = subcell.batchnorm.gamma
|
||||
conv_inner.beta = subcell.batchnorm.beta
|
||||
conv_inner.moving_mean = subcell.batchnorm.moving_mean
|
||||
conv_inner.moving_variance = subcell.batchnorm.moving_variance
|
||||
del subcell.batchnorm
|
||||
subcell.batchnorm = None
|
||||
subcell.has_bn = False
|
||||
else:
|
||||
raise ValueError("Only support Batchnorm fold mode.")
|
||||
else:
|
||||
conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
|
||||
conv_inner.out_channels,
|
||||
|
@ -259,7 +262,7 @@ class ConvertToQuantNetwork:
|
|||
act_class = activation.__class__
|
||||
if act_class not in _ACTIVATION_MAP:
|
||||
raise ValueError("Unsupported activation in auto quant: ", act_class)
|
||||
return _ACTIVATION_MAP[act_class](activation=act_class,
|
||||
return _ACTIVATION_MAP[act_class](activation=activation,
|
||||
num_bits=self.act_bits,
|
||||
quant_delay=self.act_qdelay,
|
||||
per_channel=self.act_channel,
|
||||
|
|
Loading…
Reference in New Issue