forked from mindspore-Ecosystem/mindspore
!6034 [VM][Quant]Add fake quant to ir when run on ascend
Merge pull request !6034 from chenfei_mindspore/add-quant-delay-in-ascend
This commit is contained in:
commit
a26fdb83ee
|
@ -372,7 +372,8 @@ class FakeQuantWithMinMax(Cell):
|
||||||
if self.is_ascend:
|
if self.is_ascend:
|
||||||
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
|
self.fake_quant_train = quant_fun(num_bits=self.num_bits,
|
||||||
symmetric=self.symmetric,
|
symmetric=self.symmetric,
|
||||||
narrow_range=self.narrow_range)
|
narrow_range=self.narrow_range,
|
||||||
|
quant_delay=self.quant_delay)
|
||||||
self.fake_quant_infer = self.fake_quant_train
|
self.fake_quant_infer = self.fake_quant_train
|
||||||
else:
|
else:
|
||||||
quant_fun = partial(quant_fun,
|
quant_fun = partial(quant_fun,
|
||||||
|
@ -679,15 +680,24 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
||||||
self.group = group
|
self.group = group
|
||||||
self.quant_delay = quant_delay
|
self.quant_delay = quant_delay
|
||||||
|
|
||||||
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
|
||||||
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
|
||||||
|
|
||||||
self.bias_add = P.BiasAdd()
|
self.bias_add = P.BiasAdd()
|
||||||
if check_bool(has_bias):
|
if check_bool(has_bias):
|
||||||
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
# initialize convolution op and Parameter
|
||||||
|
if context.get_context('device_target') == "Ascend" and group > 1:
|
||||||
|
validator.check_integer('group', group, in_channels, Rel.EQ)
|
||||||
|
validator.check_integer('group', group, out_channels, Rel.EQ)
|
||||||
|
self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
pad=padding,
|
||||||
|
stride=self.stride,
|
||||||
|
dilation=self.dilation)
|
||||||
|
weight_shape = [1, in_channels, *self.kernel_size]
|
||||||
|
channel_axis = 1
|
||||||
|
else:
|
||||||
self.conv = P.Conv2D(out_channel=self.out_channels,
|
self.conv = P.Conv2D(out_channel=self.out_channels,
|
||||||
kernel_size=self.kernel_size,
|
kernel_size=self.kernel_size,
|
||||||
mode=1,
|
mode=1,
|
||||||
|
@ -696,11 +706,14 @@ class Conv2dBnWithoutFoldQuant(Cell):
|
||||||
stride=self.stride,
|
stride=self.stride,
|
||||||
dilation=self.dilation,
|
dilation=self.dilation,
|
||||||
group=self.group)
|
group=self.group)
|
||||||
|
weight_shape = [out_channels, in_channels // group, *self.kernel_size]
|
||||||
|
channel_axis = 0
|
||||||
|
self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
|
||||||
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
|
||||||
max_init=6,
|
max_init=6,
|
||||||
ema=False,
|
ema=False,
|
||||||
per_channel=per_channel,
|
per_channel=per_channel,
|
||||||
channel_axis=0,
|
channel_axis=channel_axis,
|
||||||
num_channels=out_channels,
|
num_channels=out_channels,
|
||||||
num_bits=num_bits,
|
num_bits=num_bits,
|
||||||
symmetric=symmetric,
|
symmetric=symmetric,
|
||||||
|
@ -1009,6 +1022,7 @@ class ActQuant(_QuantActivation):
|
||||||
def get_origin(self):
|
def get_origin(self):
|
||||||
return self.act
|
return self.act
|
||||||
|
|
||||||
|
|
||||||
class LeakyReLUQuant(_QuantActivation):
|
class LeakyReLUQuant(_QuantActivation):
|
||||||
r"""
|
r"""
|
||||||
LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP.
|
LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP.
|
||||||
|
@ -1078,7 +1092,6 @@ class LeakyReLUQuant(_QuantActivation):
|
||||||
return self.act
|
return self.act
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HSwishQuant(_QuantActivation):
|
class HSwishQuant(_QuantActivation):
|
||||||
r"""
|
r"""
|
||||||
HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
|
HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
|
||||||
|
|
Loading…
Reference in New Issue