forked from mindspore-Ecosystem/mindspore
!2455 add perchannel quant train
Merge pull request !2455 from chenzupeng/r0.3
This commit is contained in:
commit e368d0524b
@@ -47,7 +47,6 @@ Dataset used: imagenet
├── eval.py
```

Notation: the current hyperparameters have only been tested with 4-card training; to train with 8 cards, adjust parameters such as the learning rate in 'src/config.py'.

## Training process
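
As a side note on the 4-card vs 8-card remark above: a common (but not PR-specified) way to pick the new learning rate is linear scaling with the device count. A minimal sketch, assuming that heuristic:

```python
# Hypothetical helper: linear learning-rate scaling when the device count changes.
# The exact values used upstream may differ; this only illustrates the heuristic.
def scale_lr(base_lr: float, base_devices: int, new_devices: int) -> float:
    return base_lr * new_devices / base_devices

print(scale_lr(0.3, base_devices=4, new_devices=8))  # 0.3 is the 4-card lr below -> 0.6
```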
@@ -22,10 +22,10 @@ config_ascend = ed({
    "image_height": 224,
    "image_width": 224,
    "batch_size": 192,
    "epoch_size": 40,
    "epoch_size": 60,
    "start_epoch": 200,
    "warmup_epochs": 1,
    "lr": 0.15,
    "lr": 0.3,
    "momentum": 0.9,
    "weight_decay": 4e-5,
    "label_smooth": 0.1,
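
For orientation, a sketch of how hyperparameters like these are typically consumed, assuming a linear warmup over warmup_epochs followed by cosine decay (the repo's actual schedule lives in the training scripts and may differ):

```python
import math

def lr_at_epoch(epoch, lr=0.3, warmup_epochs=1, epoch_size=60):
    """Illustrative schedule: linear warmup, then cosine decay to zero."""
    if epoch < warmup_epochs:
        return lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, epoch_size - warmup_epochs)
    return 0.5 * lr * (1.0 + math.cos(math.pi * progress))

print([round(lr_at_epoch(e), 4) for e in (0, 1, 30, 59)])
```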
@@ -20,7 +20,8 @@ from mindspore.ops.operations import TensorAdd
__all__ = ['mobilenet_v2_quant']

_ema_decay = 0.999
_symmetric = False
_symmetric = True
_per_channel = True


def _make_divisible(v, divisor, min_value=None):

@@ -77,10 +78,10 @@ class ConvBNReLU(nn.Cell):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        conv = nn.Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
                                       group=groups)
                                       group=groups, per_channel=_per_channel, symmetric=_symmetric)
        layers = [conv, nn.ReLU()]
        self.features = nn.SequentialCell(layers)
        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric, min_init=0)
        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, min_init=0)

    def construct(self, x):
        output = self.features(x)
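
The two new flags change how the fake-quantization range is tracked: per_channel keeps one (min, max) pair per output channel of the weight instead of a single pair for the whole tensor, and symmetric centres that range on zero. A minimal NumPy sketch of the forward ("fake quantize") step, illustrative only; the real behaviour (EMA range updates, gradients) is implemented by the MindSpore layers and the TBE kernels further down:

```python
import numpy as np

def fake_quant(w, num_bits=8, per_channel=True, symmetric=True, channel_axis=0):
    """Quantize-then-dequantize a weight tensor to simulate int8 inference."""
    axes = tuple(a for a in range(w.ndim) if a != channel_axis) if per_channel else None
    w_min = w.min(axis=axes, keepdims=True)
    w_max = w.max(axis=axes, keepdims=True)
    if symmetric:
        bound = np.maximum(np.abs(w_min), np.abs(w_max))
        w_min, w_max = -bound, bound
    quant_max = 2 ** num_bits - 1
    scale = np.maximum((w_max - w_min) / quant_max, 1e-12)
    q = np.clip(np.round((w - w_min) / scale), 0, quant_max)
    return q * scale + w_min

w = np.random.randn(32, 16, 3, 3).astype(np.float32)  # conv weight [co, ci, kh, kw]
print(np.abs(w - fake_quant(w)).max())                # small per-channel rounding error
```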
@@ -119,12 +120,13 @@ class InvertedResidual(nn.Cell):
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1),
            nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
            nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1,
                                    per_channel=_per_channel, symmetric=_symmetric),
            nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
        ])
        self.conv = nn.SequentialCell(layers)
        self.add = TensorAdd()
        self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
        self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)

    def construct(self, x):
        identity = x
@@ -175,7 +177,7 @@ class MobileNetV2Quant(nn.Cell):
        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
        self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in self.cfgs:
@@ -189,8 +191,12 @@ class MobileNetV2Quant(nn.Cell):
        # make it nn.CellList
        self.features = nn.SequentialCell(features)
        # mobilenet head
        head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else
                [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
        head = ([GlobalAvgPooling(),
                 nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
                               symmetric=_symmetric)] if not has_dropout else
                [GlobalAvgPooling(), nn.Dropout(0.2),
                 nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
                               symmetric=_symmetric)])
        self.head = nn.SequentialCell(head)

    def construct(self, x):
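
Swapping nn.Dense for nn.DenseQuant means the 2-D classifier weight is fake-quantized as well. With per_channel=True its statistics are kept per output channel, which is exactly the case the channel_axis adjustment in the TBE kernels further down handles. A small illustration (shapes are assumptions for illustration):

```python
import numpy as np

# For a Dense weight of shape [co, ci], per-channel statistics reduce over axis 1,
# leaving one (min, max) pair per output channel.
w = np.random.randn(1000, 1280).astype(np.float32)  # [num_classes, out_channels], illustrative
w_min, w_max = w.min(axis=1), w.max(axis=1)
print(w_min.shape, w_max.shape)  # (1000,) (1000,)
```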
@@ -51,7 +51,7 @@ Parameters for both training and inference can be set in config.py.
"loss_scale": 1024,              # loss scale
"momentum": 0.9,                 # momentum optimizer
"weight_decay": 1e-4,            # weight decay
"epoch_size": 110,               # only valid for training; always 1 for inference
"epoch_size": 120,               # only valid for training; always 1 for inference
"pretrained_epoch_size": 90,     # epochs the model was already trained for before the pretrained checkpoint was loaded
"buffer_size": 1000,             # size of the queue used in data preprocessing
"image_height": 224,             # image height
@@ -65,7 +65,7 @@ Parameters for both training and inference can be set in config.py.
"label_smooth": True,            # label smoothing
"label_smooth_factor": 0.1,      # label smoothing factor
"lr_init": 0,                    # initial learning rate
"lr_max": 0.1,                   # maximum learning rate
"lr_max": 0.005,                 # maximum learning rate
```

## Running the example
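
Reading lr_max together with pretrained_epoch_size suggests this config fine-tunes from a checkpoint that was already trained for 90 epochs, with a much smaller peak learning rate. A hedged sketch of that interpretation (not code from this PR):

```python
# Assumed reading of the config above: resume from a 90-epoch checkpoint and
# fine-tune for the remaining epochs at a reduced peak learning rate.
epoch_size = 120
pretrained_epoch_size = 90
lr_init, lr_max = 0.0, 0.005

remaining_epochs = epoch_size - pretrained_epoch_size
print(f"fine-tune for {remaining_epochs} epochs, lr from {lr_init} up to {lr_max}")
```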
@@ -22,6 +22,7 @@ from mindspore.nn import FakeQuantWithMinMax, Conv2dBatchNormQuant
_ema_decay = 0.999
_symmetric = False
_fake = True
_per_channel = True


def _weight_variable(shape, factor=0.01):
    init_value = np.random.randn(*shape).astype(np.float32) * factor
@@ -85,7 +86,7 @@ class ConvBNReLU(nn.Cell):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
                                    group=groups, fake=_fake)
                                    group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
        layers = [conv, nn.ReLUQuant()] if _fake else [conv, nn.ReLU()]
        self.features = nn.SequentialCell(layers)
@@ -119,10 +120,13 @@ class ResidualBlock(nn.Cell):
        channel = out_channel // self.expansion
        self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
        self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
        self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
        self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
                                                             symmetric=_symmetric,
                                                             kernel_size=1, stride=1, pad_mode='same', padding=0),
                                        FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
                                        ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
                                                                              per_channel=_per_channel,
                                                                              symmetric=_symmetric,
                                                                              kernel_size=1, stride=1,
                                                                              pad_mode='same', padding=0)
@@ -134,18 +138,22 @@ class ResidualBlock(nn.Cell):

        if self.down_sample:
            self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
                                                                             per_channel=_per_channel,
                                                                             symmetric=_symmetric,
                                                                             kernel_size=1, stride=stride,
                                                                             pad_mode='same', padding=0),
                                                        FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
                                                                            symmetric=False)
                                                        ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
                                                                                              fake=_fake,
                                                                                              per_channel=_per_channel,
                                                                                              symmetric=_symmetric,
                                                                                              kernel_size=1,
                                                                                              stride=stride,
                                                                                              pad_mode='same',
                                                                                              padding=0)
        self.add = P.TensorAdd()
        self.fake = FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
        self.relu = nn.ReLUQuant() if _fake else P.ReLU()

    def construct(self, x):
        identity = x
@@ -157,9 +165,7 @@ class ResidualBlock(nn.Cell):
            identity = self.down_sample_layer(identity)

        out = self.add(out, identity)
        out = P.ReLU()(out)
        if _fake:
            out = self.fake(out)
        out = self.relu(out)

        return out
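
The reordering above adds a fake-quantization node on the residual sum before the activation. A NumPy sketch of the resulting data flow, illustrative only (it mirrors the construct() shown, not the MindSpore kernels):

```python
import numpy as np

def fake_quant_tensor(t, num_bits=8):
    """Quantize-then-dequantize a whole tensor with a single (min, max) pair."""
    t_min, t_max = float(t.min()), float(t.max())
    scale = max((t_max - t_min) / (2 ** num_bits - 1), 1e-12)
    return np.round((t - t_min) / scale) * scale + t_min

out = np.random.randn(1, 8, 4, 4).astype(np.float32)
identity = np.random.randn(1, 8, 4, 4).astype(np.float32)
out = out + identity            # self.add(out, identity)
out = fake_quant_tensor(out)    # self.fake(out), applied when _fake is True
out = np.maximum(out, 0.0)      # self.relu(out)
print(out.shape)
```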
@@ -23,7 +23,7 @@ config = ed({
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 110,
    "epoch_size": 120,
    "pretrained_epoch_size": 90,
    "buffer_size": 1000,
    "image_height": 224,

@@ -37,6 +37,6 @@ config = ed({
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,
    "lr_init": 0,
    "lr_max": 0.1
    "lr_max": 0.005

})
@@ -91,11 +91,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    # for Dense weight quant, 2d [co, ci] -> 4d [1, co, ci, 1]; channel_axis_ needs to change to 1.
    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
    util.check_tensor_shape_size(x_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)
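
The new branch exists for DenseQuant weights: a 2-D weight [co, ci] reaches this kernel reshaped to 4-D [1, co, ci, 1], so a requested channel_axis of 0 really points at axis 1. A standalone sketch of that remapping (it mirrors the condition above; illustrative, not the TBE kernel itself):

```python
def adjust_channel_axis(x_shape, min_shape, channel_axis):
    """Remap channel_axis for Dense weights reshaped from [co, ci] to [1, co, ci, 1]."""
    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
        return 1
    return channel_axis

# Conv weight [co, ci, kh, kw]: axis 0 already holds the co output channels.
print(adjust_channel_axis((32, 16, 3, 3), (32,), 0))        # -> 0
# Dense weight [co, ci] padded to [1, co, ci, 1]: co now sits on axis 1.
print(adjust_channel_axis((1, 1000, 1280, 1), (1000,), 0))  # -> 1
```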
@@ -122,7 +126,7 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
    max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
    res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
                                                             ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
                                                             ema, ema_decay, quant_min, quant_max, training, channel_axis_, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)
@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    # for Dense weight quant, 2d [co, ci] -> 4d [1, co, ci, 1]; channel_axis_ needs to change to 1.
    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
    util.check_tensor_shape_size(x_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)
@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
        quant_min = quant_min + 1

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis == 1:
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
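
shape_c is the broadcast shape that lets the per-channel min/max vectors line up with the input: every axis is 1 except the (possibly remapped) channel axis. A NumPy illustration of the same idea outside of TVM:

```python
import numpy as np

def channel_broadcast_shape(x_shape, channel_count, channel_axis):
    """All-ones shape with only the channel axis kept, so min/max broadcast against x."""
    shape_c = [1] * len(x_shape)
    shape_c[channel_axis] = channel_count
    return shape_c

x = np.random.randn(1, 1000, 1280, 1).astype(np.float32)      # Dense weight as 4-D, illustrative
shape_c = channel_broadcast_shape(x.shape, 1000, channel_axis=1)
min_val = x.min(axis=(0, 2, 3)).reshape(shape_c)               # per-channel minima
print(shape_c, (x - min_val).shape)                            # [1, 1000, 1, 1] (1, 1000, 1280, 1)
```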
@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    # for Dense weight quant, 2d [co, ci] -> 4d [1, co, ci, 1]; channel_axis_ needs to change to 1.
    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
        channel_axis_ = 1
    else:
        channel_axis_ = channel_axis
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(x_shape)
    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
    util.check_tensor_shape_size(x_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)
@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
        quant_min = quant_min + 1

    shape_c = [1] * len(x_shape)
    shape_c[channel_axis] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis == 1:
    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
    if x_format == "NC1HWC0" and channel_axis_ == 1:
        shape_c = min_val.get("shape")
    dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
    input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
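
For context, the backward pass of fake quantization is commonly a straight-through estimator: the incoming gradient passes through where the input lies inside the per-channel [min, max] range and is zeroed outside. A NumPy sketch of that behaviour (an assumption about the kernel's semantics, not code extracted from this file):

```python
import numpy as np

def fake_quant_perchannel_grad_ref(dout, x, min_val, max_val, channel_axis=1):
    """Straight-through estimator: pass dout inside [min, max] per channel, zero elsewhere."""
    shape_c = [1] * x.ndim
    shape_c[channel_axis] = x.shape[channel_axis]
    inside = (x >= min_val.reshape(shape_c)) & (x <= max_val.reshape(shape_c))
    return dout * inside

x = np.random.randn(1, 8, 4, 4).astype(np.float32)
dout = np.ones_like(x)
min_val = np.full(8, -1.0, dtype=np.float32)
max_val = np.full(8, 1.0, dtype=np.float32)
print(fake_quant_perchannel_grad_ref(dout, x, min_val, max_val).mean())
```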