forked from mindspore-Ecosystem/mindspore

!3457 add LeakyReLUQuant OP for bug fix.
Merge pull request !3457 from chenzhongming/master
commit 521f1e5938
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Quantization aware."""
+"""Quantization aware training."""
 
 from functools import partial
 import numpy as np
@@ -43,6 +43,7 @@ __all__ = [
     'Conv2dQuant',
     'DenseQuant',
     'ActQuant',
+    'LeakyReLUQuant',
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
@@ -349,7 +350,7 @@ class FakeQuantWithMinMax(Cell):
         self.maxq = Parameter(Tensor(max_array), name='quant_max', requires_grad=False)
 
         # init fake quant relative op
-        if per_channel:
+        if self.per_channel:
             quant_fun = partial(Q.FakeQuantPerChannel, channel_axis=self.channel_axis)
             ema_fun = partial(Q.MinMaxUpdatePerChannel, channel_axis=self.channel_axis)
         else:
@@ -369,7 +370,7 @@ class FakeQuantWithMinMax(Cell):
                             num_bits=self.num_bits,
                             symmetric=self.symmetric,
                             narrow_range=self.narrow_range,
-                            quant_delay=quant_delay)
+                            quant_delay=self.quant_delay)
         self.fake_quant_train = quant_fun(training=True)
         self.fake_quant_infer = quant_fun(training=False)
 
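Both fixes above are internal to FakeQuantWithMinMax: the cell now branches on its own validated attributes (self.per_channel, self.quant_delay) rather than on the raw constructor arguments. For readers unfamiliar with what the resulting op does, here is a minimal NumPy sketch of the per-layer quantize-dequantize round trip it simulates during training; the function name and the affine scheme are illustrative, not the exact kernel.

import numpy as np

def fake_quant_per_layer(x, min_val, max_val, num_bits=8, narrow_range=False):
    """Illustrative quantize-dequantize round trip with a single min/max pair."""
    quant_min = 1 if narrow_range else 0
    quant_max = (1 << num_bits) - 1
    scale = (max_val - min_val) / (quant_max - quant_min)
    zero_point = np.round(quant_min - min_val / scale)
    q = np.clip(np.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale   # dequantized value carries the quantization error

x = np.array([[1.0, 2.0, -1.0], [-2.0, 0.0, -1.0]], dtype=np.float32)
print(fake_quant_per_layer(x, min_val=-6.0, max_val=6.0))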
@@ -832,7 +833,7 @@ class ActQuant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.
 
     Examples:
-        >>> act_quant = nn.ActQuant(4, 1)
+        >>> act_quant = nn.ActQuant(nn.ReLU())
         >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
         >>> result = act_quant(input_x)
     """
@@ -855,7 +856,7 @@ class ActQuant(_QuantActivation):
                                                  symmetric=symmetric,
                                                  narrow_range=narrow_range,
                                                  quant_delay=quant_delay)
-        self.act = activation()
+        self.act = activation
 
     def construct(self, x):
        x = self.act(x)
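ActQuant now keeps the activation cell it is given instead of calling it as a constructor, so callers pass an instance. A short usage sketch consistent with the updated docstring, assuming a backend where the FakeQuant kernels are available (e.g. GPU or Ascend):

import numpy as np
import mindspore
from mindspore import Tensor, nn

act_quant = nn.ActQuant(nn.ReLU())   # pass an activation instance, not the class
input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
result = act_quant(input_x)          # ReLU output followed by fake quantization
print(act_quant.get_origin())        # the wrapped nn.ReLU cell is still accessible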
@@ -865,6 +866,75 @@ class ActQuant(_QuantActivation):
     def get_origin(self):
         return self.act
 
 
+class LeakyReLUQuant(_QuantActivation):
+    r"""
+    LeakyReLUQuant activation function. Add Fake Quant OP before and after the LeakyReLU OP.
+
+    For a more detailed overview of the LeakyReLU op, see `nn.LeakyReLU`.
+
+    Args:
+        activation (Cell): Activation cell instance.
+        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
+        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
+        num_bits (int): Number of bits used for quantization, supports 4 and 8 bits. Default: 8.
+        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
+        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.
+        quant_delay (int): Number of global steps after which quantization is applied. Default: 0.
+
+    Inputs:
+        - **x** (Tensor) - The input of LeakyReLUQuant.
+
+    Outputs:
+        Tensor, with the same type and shape as the `x`.
+
+    Examples:
+        >>> activation = nn.LeakyReLUQuant(nn.LeakyReLU())
+        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = activation(input)
+    """
+
+    def __init__(self,
+                 activation,
+                 ema_decay=0.999,
+                 per_channel=False,
+                 num_bits=8,
+                 symmetric=False,
+                 narrow_range=False,
+                 quant_delay=0):
+        super(LeakyReLUQuant, self).__init__()
+        self.fake_quant_act_before = FakeQuantWithMinMax(min_init=-6,
+                                                         max_init=6,
+                                                         ema=True,
+                                                         ema_decay=ema_decay,
+                                                         per_channel=per_channel,
+                                                         num_bits=num_bits,
+                                                         symmetric=symmetric,
+                                                         narrow_range=narrow_range,
+                                                         quant_delay=quant_delay)
+        self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
+                                                        max_init=6,
+                                                        ema=True,
+                                                        ema_decay=ema_decay,
+                                                        per_channel=per_channel,
+                                                        num_bits=num_bits,
+                                                        symmetric=symmetric,
+                                                        narrow_range=narrow_range,
+                                                        quant_delay=quant_delay)
+        if issubclass(activation.__class__, nn.LeakyReLU):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.LeakyReLU`")
+
+    def construct(self, x):
+        x = self.fake_quant_act_before(x)
+        x = self.act(x)
+        x = self.fake_quant_act_after(x)
+        return x
+
+    def get_origin(self):
+        return self.act
+
+
 class HSwishQuant(_QuantActivation):
     r"""
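A usage sketch for the new cell, expanded from the docstring example above (the alpha value is illustrative; run on a backend that provides the FakeQuant kernels):

import numpy as np
import mindspore
from mindspore import Tensor, nn

# Fake quant is applied both before and after the wrapped LeakyReLU (see construct above).
activation = nn.LeakyReLUQuant(nn.LeakyReLU(alpha=0.2), num_bits=8)
input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
result = activation(input_x)
print(activation.get_origin())   # returns the original nn.LeakyReLU cell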
@@ -888,9 +958,9 @@ class HSwishQuant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.
 
     Examples:
-        >>> hswish_quant = nn.HSwishQuant(4, 1)
-        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = hswish_quant(input_x)
+        >>> activation = nn.HSwishQuant(nn.HSwish())
+        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = activation(input)
     """
 
     def __init__(self,
@@ -920,8 +990,8 @@ class HSwishQuant(_QuantActivation):
                                                        symmetric=symmetric,
                                                        narrow_range=narrow_range,
                                                        quant_delay=quant_delay)
-        if issubclass(activation, nn.HSwish):
-            self.act = activation()
+        if issubclass(activation.__class__, nn.HSwish):
+            self.act = activation
         else:
             raise ValueError("Activation should be `nn.HSwish`")
 
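The issubclass change is the actual bug being fixed: the converter hands these wrappers an activation instance, and issubclass() raises a TypeError when its first argument is not a class. A plain-Python illustration with a stand-in class:

class HSwish:          # stand-in for nn.HSwish, only to show the type check
    pass

act = HSwish()

try:
    issubclass(act, HSwish)          # old check: first argument must be a class
except TypeError as err:
    print("old check raises:", err)

print(issubclass(act.__class__, HSwish))   # new check: True for instances and subclasses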
@@ -957,9 +1027,9 @@ class HSigmoidQuant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.
 
     Examples:
-        >>> hsigmoid_quant = nn.HSigmoidQuant(4, 1)
-        >>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = hsigmoid_quant(input_x)
+        >>> activation = nn.HSigmoidQuant(nn.HSigmoid())
+        >>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
+        >>> result = activation(input)
     """
 
     def __init__(self,
@@ -989,8 +1059,8 @@ class HSigmoidQuant(_QuantActivation):
                                                        symmetric=symmetric,
                                                        narrow_range=narrow_range,
                                                        quant_delay=quant_delay)
-        if issubclass(activation, nn.HSigmoid):
-            self.act = activation()
+        if issubclass(activation.__class__, nn.HSigmoid):
+            self.act = activation
         else:
             raise ValueError("Activation should be `nn.HSigmoid`")
 
@@ -386,6 +386,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
             raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
         if not self.is_ascend:
             validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        if len(x_shape) == 1:
+            self.channel_axis = 0
         validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
         validator.check_integer(
             "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
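The infer-shape change lets per-channel quantization accept rank-1 inputs by falling back to axis 0 as the channel axis, so the min/max vectors must then carry one entry per element of that axis. A NumPy sketch of what the per-channel round trip means for a 1-D tensor (illustrative only, not the kernel implementation):

import numpy as np

def fake_quant_per_channel_1d(x, min_vals, max_vals, num_bits=8):
    """Per-channel fake quant for a rank-1 tensor: the channel axis is 0."""
    assert x.shape[0] == min_vals.shape[0] == max_vals.shape[0]
    quant_max = (1 << num_bits) - 1
    scale = (max_vals - min_vals) / quant_max          # one scale per channel
    zero_point = np.round(-min_vals / scale)           # one zero point per channel
    q = np.clip(np.round(x / scale) + zero_point, 0, quant_max)
    return (q - zero_point) * scale

w = np.array([0.5, -1.2, 3.0], dtype=np.float32)       # rank-1 input: channel_axis falls back to 0
print(fake_quant_per_channel_1d(w, np.full(3, -4.0), np.full(3, 4.0)))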
@@ -35,8 +35,8 @@ from . import quant_utils
 
 _ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
                    nn.ReLU6: quant.ActQuant,
-                   nn.LeakyReLU: quant.ActQuant,
                    nn.Sigmoid: quant.ActQuant,
+                   nn.LeakyReLU: quant.LeakyReLUQuant,
                    nn.HSigmoid: quant.HSigmoidQuant,
                    nn.HSwish: quant.HSwishQuant}
 
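Routing nn.LeakyReLU to the dedicated wrapper means its input as well as its output gets a fake-quant node, whereas the generic ActQuant only quantizes the output. The converter's choice is a plain lookup on the activation's class, roughly as follows (the import path of the nn quant layer is an assumption):

import mindspore.nn as nn
from mindspore.nn.layer import quant   # assumed module path for the quant cells

# Simplified stand-in for the converter's lookup table after this change.
_activation_map = {nn.ReLU: quant.ActQuant,
                   nn.ReLU6: quant.ActQuant,
                   nn.Sigmoid: quant.ActQuant,
                   nn.LeakyReLU: quant.LeakyReLUQuant,
                   nn.HSigmoid: quant.HSigmoidQuant,
                   nn.HSwish: quant.HSwishQuant}

act = nn.LeakyReLU()
wrapper_cls = _activation_map[act.__class__]   # -> quant.LeakyReLUQuant after this change
quant_act = wrapper_cls(activation=act)        # the instance is passed through unchanged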
@@ -167,32 +167,35 @@ class ConvertToQuantNetwork:
         convert Conv2d cell to quant cell
         """
         conv_inner = subcell.conv
-        if subcell.has_bn and self.bn_fold:
-            bn_inner = subcell.batchnorm
-            conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
-                                                    conv_inner.out_channels,
-                                                    kernel_size=conv_inner.kernel_size,
-                                                    stride=conv_inner.stride,
-                                                    pad_mode=conv_inner.pad_mode,
-                                                    padding=conv_inner.padding,
-                                                    dilation=conv_inner.dilation,
-                                                    group=conv_inner.group,
-                                                    eps=bn_inner.eps,
-                                                    quant_delay=self.weight_qdelay,
-                                                    freeze_bn=self.freeze_bn,
-                                                    per_channel=self.weight_channel,
-                                                    num_bits=self.weight_bits,
-                                                    fake=True,
-                                                    symmetric=self.weight_symmetric,
-                                                    narrow_range=self.weight_range)
-            # change original network BatchNormal OP parameters to quant network
-            conv_inner.gamma = subcell.batchnorm.gamma
-            conv_inner.beta = subcell.batchnorm.beta
-            conv_inner.moving_mean = subcell.batchnorm.moving_mean
-            conv_inner.moving_variance = subcell.batchnorm.moving_variance
-            del subcell.batchnorm
-            subcell.batchnorm = None
-            subcell.has_bn = False
+        if subcell.has_bn:
+            if self.bn_fold:
+                bn_inner = subcell.batchnorm
+                conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
+                                                        conv_inner.out_channels,
+                                                        kernel_size=conv_inner.kernel_size,
+                                                        stride=conv_inner.stride,
+                                                        pad_mode=conv_inner.pad_mode,
+                                                        padding=conv_inner.padding,
+                                                        dilation=conv_inner.dilation,
+                                                        group=conv_inner.group,
+                                                        eps=bn_inner.eps,
+                                                        quant_delay=self.weight_qdelay,
+                                                        freeze_bn=self.freeze_bn,
+                                                        per_channel=self.weight_channel,
+                                                        num_bits=self.weight_bits,
+                                                        fake=True,
+                                                        symmetric=self.weight_symmetric,
+                                                        narrow_range=self.weight_range)
+                # change original network BatchNormal OP parameters to quant network
+                conv_inner.gamma = subcell.batchnorm.gamma
+                conv_inner.beta = subcell.batchnorm.beta
+                conv_inner.moving_mean = subcell.batchnorm.moving_mean
+                conv_inner.moving_variance = subcell.batchnorm.moving_variance
+                del subcell.batchnorm
+                subcell.batchnorm = None
+                subcell.has_bn = False
+            else:
+                raise ValueError("Only support Batchnorm fold mode.")
         else:
             conv_inner = quant.Conv2dQuant(conv_inner.in_channels,
                                            conv_inner.out_channels,
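The restructured branch separates "the cell has a BatchNorm" from "fold it into the conv": a Conv2d+BN pair is now converted only when bn_fold is enabled, and the unfolded case fails loudly instead of silently producing a Conv2dQuant that drops the BatchNorm parameters. A condensed, runnable sketch of the new decision flow (the helper results are placeholders, not MindSpore API):

from types import SimpleNamespace

def convert_conv(subcell, bn_fold):
    """Hypothetical condensation of ConvertToQuantNetwork._convert_conv after this change."""
    if subcell.has_bn:
        if bn_fold:
            return "Conv2dBatchNormQuant"   # BN statistics folded into the quantized conv
        raise ValueError("Only support Batchnorm fold mode.")
    return "Conv2dQuant"                    # plain quantized conv, no BN attached

print(convert_conv(SimpleNamespace(has_bn=True), bn_fold=True))    # Conv2dBatchNormQuant
print(convert_conv(SimpleNamespace(has_bn=False), bn_fold=False))  # Conv2dQuant
# convert_conv(SimpleNamespace(has_bn=True), bn_fold=False) now raises ValueError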
@@ -259,7 +262,7 @@ class ConvertToQuantNetwork:
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
             raise ValueError("Unsupported activation in auto quant: ", act_class)
-        return _ACTIVATION_MAP[act_class](activation=act_class,
+        return _ACTIVATION_MAP[act_class](activation=activation,
                                           num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
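Passing the live activation instance rather than its class is what makes the new issubclass checks succeed, and it lets the wrapper hold on to the original cell and its per-instance settings, which get_origin() then returns unchanged. A short sketch (the alpha value is only an illustration):

import mindspore.nn as nn

act = nn.LeakyReLU(alpha=0.1)                    # per-instance configuration
quant_act = nn.LeakyReLUQuant(activation=act)    # the converter now passes `act`, not nn.LeakyReLU
assert quant_act.get_origin() is act             # the original cell, alpha included, is preserved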