From e9ee59c7adc22270aed982b4ed471ea3bc847c33 Mon Sep 17 00:00:00 2001
From: chenzupeng
Date: Mon, 22 Jun 2020 18:20:53 +0800
Subject: [PATCH] add per-channel quant training

Enable per-channel weight quantization in the MobileNetV2 and ResNet-50
quantization-aware training examples, turn on symmetric weight quantization
for MobileNetV2, switch the MobileNetV2 head to nn.DenseQuant, adjust the
training hyperparameters (epoch_size, lr, lr_max), and make the per-channel
fake-quant custom ops remap channel_axis for Dense weights reshaped from
2d [co, ci] to 4d [1, co, ci, 1].

---
 example/mobilenetv2_quant/Readme.md           |  1 -
 example/mobilenetv2_quant/src/config.py       |  4 ++--
 .../src/mobilenetV2_quant.py                  | 24 ++++++++++++-------
 example/resnet50_quant/README.md              |  4 ++--
 example/resnet50_quant/models/resnet_quant.py | 18 +++++++++-----
 example/resnet50_quant/src/config.py          |  4 ++--
 .../fake_quant_minmax_perchannel_update.py    | 12 ++++++----
 .../_custom_op/fake_quant_perchannel.py       | 14 +++++++----
 .../_custom_op/fake_quant_perchannel_grad.py  | 14 +++++++----
 9 files changed, 59 insertions(+), 36 deletions(-)

diff --git a/example/mobilenetv2_quant/Readme.md b/example/mobilenetv2_quant/Readme.md
index f426302454a..ca5254f68b0 100644
--- a/example/mobilenetv2_quant/Readme.md
+++ b/example/mobilenetv2_quant/Readme.md
@@ -47,7 +47,6 @@ Dataset used: imagenet
     ├── eval.py
 ```
 
-Notation: Current hyperparameters only test on 4 cards while training, if want to use 8 cards for training, should change parameters like learning rate in 'src/config.py'.
 
 ## Training process
 
diff --git a/example/mobilenetv2_quant/src/config.py b/example/mobilenetv2_quant/src/config.py
index 61d02b24b1b..12411f4400c 100644
--- a/example/mobilenetv2_quant/src/config.py
+++ b/example/mobilenetv2_quant/src/config.py
@@ -22,10 +22,10 @@ config_ascend = ed({
     "image_height": 224,
     "image_width": 224,
     "batch_size": 192,
-    "epoch_size": 40,
+    "epoch_size": 60,
     "start_epoch": 200,
     "warmup_epochs": 1,
-    "lr": 0.15,
+    "lr": 0.3,
     "momentum": 0.9,
     "weight_decay": 4e-5,
     "label_smooth": 0.1,
diff --git a/example/mobilenetv2_quant/src/mobilenetV2_quant.py b/example/mobilenetv2_quant/src/mobilenetV2_quant.py
index 84679c96760..4138b01310d 100644
--- a/example/mobilenetv2_quant/src/mobilenetV2_quant.py
+++ b/example/mobilenetv2_quant/src/mobilenetV2_quant.py
@@ -20,7 +20,8 @@ from mindspore.ops.operations import TensorAdd
 __all__ = ['mobilenet_v2_quant']
 
 _ema_decay = 0.999
-_symmetric = False
+_symmetric = True
+_per_channel = True
 
 
 def _make_divisible(v, divisor, min_value=None):
@@ -77,10 +78,10 @@ class ConvBNReLU(nn.Cell):
         super(ConvBNReLU, self).__init__()
         padding = (kernel_size - 1) // 2
         conv = nn.Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
-                                       group=groups)
+                                       group=groups, per_channel=_per_channel, symmetric=_symmetric)
         layers = [conv, nn.ReLU()]
         self.features = nn.SequentialCell(layers)
-        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric, min_init=0)
+        self.fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, min_init=0)
 
     def construct(self, x):
         output = self.features(x)
@@ -119,12 +120,13 @@ class InvertedResidual(nn.Cell):
                 # dw
                 ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
                 # pw-linear
-                nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1),
-                nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+                nn.Conv2dBatchNormQuant(hidden_dim, oup, kernel_size=1, stride=1, pad_mode='pad', padding=0, group=1,
+                                        per_channel=_per_channel, symmetric=_symmetric),
+                nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
             ])
         self.conv = nn.SequentialCell(layers)
         self.add = TensorAdd()
-        self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+        self.add_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
 
     def construct(self, x):
         identity = x
@@ -175,7 +177,7 @@ class MobileNetV2Quant(nn.Cell):
         # building first layer
         input_channel = _make_divisible(input_channel * width_mult, round_nearest)
         self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
-        self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=_symmetric)
+        self.input_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay)
         features = [ConvBNReLU(3, input_channel, stride=2)]
         # building inverted residual blocks
         for t, c, n, s in self.cfgs:
@@ -189,8 +191,12 @@ class MobileNetV2Quant(nn.Cell):
         # make it nn.CellList
         self.features = nn.SequentialCell(features)
         # mobilenet head
-        head = ([GlobalAvgPooling(), nn.Dense(self.out_channels, num_classes, has_bias=True)] if not has_dropout else
-                [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(self.out_channels, num_classes, has_bias=True)])
+        head = ([GlobalAvgPooling(),
+                 nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
+                               symmetric=_symmetric)] if not has_dropout else
+                [GlobalAvgPooling(), nn.Dropout(0.2),
+                 nn.DenseQuant(self.out_channels, num_classes, has_bias=True, per_channel=_per_channel,
+                               symmetric=_symmetric)])
         self.head = nn.SequentialCell(head)
 
     def construct(self, x):
diff --git a/example/resnet50_quant/README.md b/example/resnet50_quant/README.md
index 948bd93ec6b..9e843b22238 100644
--- a/example/resnet50_quant/README.md
+++ b/example/resnet50_quant/README.md
@@ -51,7 +51,7 @@ Parameters for both training and inference can be set in config.py.
 "loss_scale": 1024, # loss scale
 "momentum": 0.9, # momentum optimizer
 "weight_decay": 1e-4, # weight decay
-"epoch_size": 110, # only valid for taining, which is always 1 for inference
+"epoch_size": 120, # only valid for training, which is always 1 for inference
 "pretrained_epoch_size": 90, # epoch size that model has been trained before load pretrained checkpoint
 "buffer_size": 1000, # number of queue size in data preprocessing
 "image_height": 224, # image height
@@ -65,7 +65,7 @@
 "label_smooth": True, # label smooth
 "label_smooth_factor": 0.1, # label smooth factor
 "lr_init": 0, # initial learning rate
-"lr_max": 0.1, # maximum learning rate
+"lr_max": 0.005, # maximum learning rate
 ```
 
 ## Running the example
diff --git a/example/resnet50_quant/models/resnet_quant.py b/example/resnet50_quant/models/resnet_quant.py
index 54fd58794fb..2d44ba3947b 100755
--- a/example/resnet50_quant/models/resnet_quant.py
+++ b/example/resnet50_quant/models/resnet_quant.py
@@ -22,6 +22,7 @@ from mindspore.nn import FakeQuantWithMinMax, Conv2dBatchNormQuant
 _ema_decay = 0.999
 _symmetric = False
 _fake = True
+_per_channel = True
 
 def _weight_variable(shape, factor=0.01):
     init_value = np.random.randn(*shape).astype(np.float32) * factor
@@ -85,7 +86,7 @@ class ConvBNReLU(nn.Cell):
         super(ConvBNReLU, self).__init__()
         padding = (kernel_size - 1) // 2
         conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
-                                    group=groups, fake=_fake)
+                                    group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric)
         layers = [conv, nn.ReLUQuant()] if _fake else [conv, nn.ReLU()]
         self.features = nn.SequentialCell(layers)
 
@@ -119,10 +120,13 @@ class ResidualBlock(nn.Cell):
         channel = out_channel // self.expansion
         self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
         self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
-        self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
+        self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel,
+                                                             symmetric=_symmetric,
                                                              kernel_size=1, stride=1,
                                                              pad_mode='same', padding=0),
                                         FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
                                         ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake,
+                                                                              per_channel=_per_channel,
+                                                                              symmetric=_symmetric,
                                                                               kernel_size=1, stride=1,
                                                                               pad_mode='same', padding=0)
@@ -134,18 +138,22 @@ class ResidualBlock(nn.Cell):
 
         if self.down_sample:
             self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel,
+                                                                             per_channel=_per_channel,
+                                                                             symmetric=_symmetric,
                                                                              kernel_size=1, stride=stride,
                                                                              pad_mode='same', padding=0),
                                                         FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay,
                                                                             symmetric=False)
                                                         ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel,
                                                                                               fake=_fake,
+                                                                                              per_channel=_per_channel,
+                                                                                              symmetric=_symmetric,
                                                                                               kernel_size=1,
                                                                                               stride=stride,
                                                                                               pad_mode='same',
                                                                                               padding=0)
         self.add = P.TensorAdd()
-        self.fake = FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False)
+        self.relu = nn.ReLUQuant() if _fake else P.ReLU()
 
     def construct(self, x):
         identity = x
@@ -157,9 +165,7 @@ class ResidualBlock(nn.Cell):
             identity = self.down_sample_layer(identity)
 
         out = self.add(out, identity)
-        out = P.ReLU()(out)
-        if _fake:
-            out = self.fake(out)
+        out = self.relu(out)
         return out
 
 
diff --git a/example/resnet50_quant/src/config.py b/example/resnet50_quant/src/config.py
index dadbe370dd0..d773f531f17 100755
--- a/example/resnet50_quant/src/config.py
+++ b/example/resnet50_quant/src/config.py
@@ -23,7 +23,7 @@ config = ed({
     "loss_scale": 1024,
     "momentum": 0.9,
     "weight_decay": 1e-4,
-    "epoch_size": 110,
+    "epoch_size": 120,
    "pretrained_epoch_size": 90,
     "buffer_size": 1000,
     "image_height": 224,
@@ -37,6 +37,6 @@ config = ed({
     "use_label_smooth": True,
     "label_smooth_factor": 0.1,
     "lr_init": 0,
-    "lr_max": 0.1
+    "lr_max": 0.005
 })
 
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
index 7694753d8f5..fee7f3ed1b6 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
@@ -91,11 +91,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
-
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], so channel_axis_ needs to change to 1.
+    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -122,7 +126,7 @@
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
     max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
     res_list = fake_quant_min_max_per_channel_update_compute(input_data, min_data, max_data,
-                                                             ema, ema_decay, quant_min, quant_max, training, channel_axis, kernel_name)
+                                                             ema, ema_decay, quant_min, quant_max, training, channel_axis_, kernel_name)
 
     with tvm.target.cce():
         sch = generic.auto_schedule(res_list)
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
index f6c133c8086..dae2d7058dd 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
-
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], so channel_axis_ needs to change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -126,8 +130,8 @@
     quant_min = quant_min + 1
 
     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
index 4e9053fcb14..795aab52a3d 100644
--- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
+++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
-
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], so channel_axis_ needs to change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -151,8 +155,8 @@
     quant_min = quant_min + 1
 
     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
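
The channel_axis remapping added to the three per-channel fake-quant kernels is easier to see outside the diff. Below is a minimal standalone sketch of the same rule; the helper name `resolve_channel_axis` and the example shapes are illustrative only, not part of MindSpore.

```python
def resolve_channel_axis(x_shape, min_shape, channel_axis):
    """Mirror of the check the patch adds: a Dense weight is 2d [co, ci] but
    reaches the kernel reshaped to 4d [1, co, ci, 1], so the per-channel
    dimension (whose length equals the min/max vector length) sits at axis 1."""
    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
        return 1
    return channel_axis

# Dense weight with co=1000, ci=1280, quantized per output channel:
assert resolve_channel_axis([1, 1000, 1280, 1], [1000], channel_axis=0) == 1
# Conv weight [co, ci, kh, kw] keeps its original channel axis:
assert resolve_channel_axis([32, 16, 3, 3], [32], channel_axis=0) == 0
```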