From 298393b66b4e5ebf21425c5b8d0603821b4073b8 Mon Sep 17 00:00:00 2001 From: chenfei Date: Sat, 29 Aug 2020 15:46:27 +0800 Subject: [PATCH] add manual quantative network of resnet50 --- mindspore/train/quant/quant_utils.py | 9 +- .../official/cv/mobilenetv2_quant/Readme.md | 2 +- model_zoo/official/cv/resnet50_quant/eval.py | 3 +- .../cv/resnet50_quant/models/resnet_quant.py | 2 +- .../models/resnet_quant_manual.py | 325 ++++++++++++++++++ model_zoo/official/cv/resnet50_quant/train.py | 5 +- 6 files changed, 338 insertions(+), 8 deletions(-) create mode 100644 model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py index 12827981d19..e7120d35be2 100644 --- a/mindspore/train/quant/quant_utils.py +++ b/mindspore/train/quant/quant_utils.py @@ -252,13 +252,14 @@ def without_fold_batchnorm(weight, cell_quant): return weight, bias -def load_nonquant_param_into_quant_net(quant_model, params_dict): +def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None): """ load fp32 model parameters to quantization model. Args: - quant_model: quantization model - params_dict: f32 param + quant_model: quantization model. + params_dict: f32 param. + quant_new_params:parameters that exist in quantative network but not in unquantative network. Returns: None @@ -277,6 +278,8 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict): for name, param in quant_model.parameters_and_names(): key_name = name.split(".")[-1] if key_name not in iterable_dict.keys(): + if quant_new_params is not None and key_name in quant_new_params: + continue raise ValueError(f"Can't find match parameter in ckpt,param name = {name}") value_param = next(iterable_dict[key_name], None) if value_param is not None: diff --git a/model_zoo/official/cv/mobilenetv2_quant/Readme.md b/model_zoo/official/cv/mobilenetv2_quant/Readme.md index c58d0b24c31..354fd5703ff 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/Readme.md +++ b/model_zoo/official/cv/mobilenetv2_quant/Readme.md @@ -91,7 +91,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil You can start training using python or shell scripts. The usage of shell scripts as follows: -- Ascend: sh run_train_quant.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] +- Ascend: sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] ### Launch diff --git a/model_zoo/official/cv/resnet50_quant/eval.py b/model_zoo/official/cv/resnet50_quant/eval.py index 9eb3ce3520a..76cfa345239 100755 --- a/model_zoo/official/cv/resnet50_quant/eval.py +++ b/model_zoo/official/cv/resnet50_quant/eval.py @@ -20,7 +20,8 @@ import argparse from src.config import config_quant from src.dataset import create_dataset from src.crossentropy import CrossEntropy -from models.resnet_quant import resnet50_quant +#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50 +from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50 from mindspore import context from mindspore.train.model import Model diff --git a/model_zoo/official/cv/resnet50_quant/models/resnet_quant.py b/model_zoo/official/cv/resnet50_quant/models/resnet_quant.py index 63fa32222dd..82bbac02c32 100755 --- a/model_zoo/official/cv/resnet50_quant/models/resnet_quant.py +++ b/model_zoo/official/cv/resnet50_quant/models/resnet_quant.py @@ -209,7 +209,7 @@ class ResNet(nn.Cell): return out -def resnet50_quant(class_num=10001): +def resnet50_quant(class_num=10): """ Get ResNet50 neural network. diff --git a/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py b/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py new file mode 100644 index 00000000000..8957ca9322b --- /dev/null +++ b/model_zoo/official/cv/resnet50_quant/models/resnet_quant_manual.py @@ -0,0 +1,325 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ResNet.""" +import numpy as np +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore import Tensor +from mindspore.nn import FakeQuantWithMinMax, Conv2dBnFoldQuant as Conv2dBatchNormQuant + +_ema_decay = 0.999 +_symmetric = True +_fake = True +_per_channel = True + + +def _weight_variable(shape, factor=0.01): + init_value = np.random.randn(*shape).astype(np.float32) * factor + return Tensor(init_value) + + +def _conv3x3(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 3, 3) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv1x1(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 1, 1) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _conv7x7(in_channel, out_channel, stride=1): + weight_shape = (out_channel, in_channel, 7, 7) + weight = _weight_variable(weight_shape) + return nn.Conv2d(in_channel, out_channel, + kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight) + + +def _bn(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _bn_last(channel): + return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, + gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1) + + +def _fc(in_channel, out_channel): + weight_shape = (out_channel, in_channel) + weight = _weight_variable(weight_shape) + return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0) + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, + group=groups, fake=_fake, per_channel=_per_channel, symmetric=_symmetric) + layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + +class ResidualBlock(nn.Cell): + """ + ResNet V1 residual block definition. + + Args: + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, stride=2) + """ + expansion = 4 + + def __init__(self, + in_channel, + out_channel, + stride=1): + super(ResidualBlock, self).__init__() + + channel = out_channel // self.expansion + self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) + self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) + self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, per_channel=_per_channel, + symmetric=_symmetric, + kernel_size=1, stride=1, pad_mode='same', padding=0), + FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, symmetric=False) + ]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, + per_channel=_per_channel, + symmetric=_symmetric, + kernel_size=1, stride=1, + pad_mode='same', padding=0) + + self.down_sample = False + + if stride != 1 or in_channel != out_channel: + self.down_sample = True + self.down_sample_layer = None + + if self.down_sample: + self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, + per_channel=_per_channel, + symmetric=_symmetric, + kernel_size=1, stride=stride, + pad_mode='same', padding=0), + FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay, + symmetric=False) + ]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, + fake=_fake, + per_channel=_per_channel, + symmetric=_symmetric, + kernel_size=1, + stride=stride, + pad_mode='same', + padding=0) + self.add = nn.TensorAddQuant() + self.relu = P.ReLU() + + def construct(self, x): + identity = x + out = self.conv1(x) + out = self.conv2(out) + out = self.conv3(out) + + if self.down_sample: + identity = self.down_sample_layer(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResNet(nn.Cell): + """ + ResNet architecture. + + Args: + block (Cell): Block for network. + layer_nums (list): Numbers of block in different layers. + in_channels (list): Input channel in each layer. + out_channels (list): Output channel in each layer. + strides (list): Stride size in each layer. + num_classes (int): The number of classes that the training images are belonging to. + Returns: + Tensor, output tensor. + + Examples: + >>> ResNet(ResidualBlock, + >>> [3, 4, 6, 3], + >>> [64, 256, 512, 1024], + >>> [256, 512, 1024, 2048], + >>> [1, 2, 2, 2], + >>> 10) + """ + + def __init__(self, + block, + layer_nums, + in_channels, + out_channels, + strides, + num_classes): + super(ResNet, self).__init__() + + if not len(layer_nums) == len(in_channels) == len(out_channels) == 4: + raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!") + + self.conv1 = ConvBNReLU(3, 64, kernel_size=7, stride=2) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + + self.layer1 = self._make_layer(block, + layer_nums[0], + in_channel=in_channels[0], + out_channel=out_channels[0], + stride=strides[0]) + self.layer2 = self._make_layer(block, + layer_nums[1], + in_channel=in_channels[1], + out_channel=out_channels[1], + stride=strides[1]) + self.layer3 = self._make_layer(block, + layer_nums[2], + in_channel=in_channels[2], + out_channel=out_channels[2], + stride=strides[2]) + self.layer4 = self._make_layer(block, + layer_nums[3], + in_channel=in_channels[3], + out_channel=out_channels[3], + stride=strides[3]) + + self.mean = P.ReduceMean(keep_dims=True) + self.flatten = nn.Flatten() + self.end_point = nn.DenseQuant(out_channels[3], num_classes, has_bias=True, per_channel=_per_channel, + symmetric=_symmetric) + self.output_fake = nn.FakeQuantWithMinMax(ema=True, ema_decay=_ema_decay) + + def _make_layer(self, block, layer_num, in_channel, out_channel, stride): + """ + Make stage network of ResNet. + + Args: + block (Cell): Resnet block. + layer_num (int): Layer number. + in_channel (int): Input channel. + out_channel (int): Output channel. + stride (int): Stride size for the first convolutional layer. + + Returns: + SequentialCell, the output layer. + + Examples: + >>> _make_layer(ResidualBlock, 3, 128, 256, 2) + """ + layers = [] + + resnet_block = block(in_channel, out_channel, stride=stride) + layers.append(resnet_block) + + for _ in range(1, layer_num): + resnet_block = block(out_channel, out_channel, stride=1) + layers.append(resnet_block) + + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv1(x) + c1 = self.maxpool(x) + + c2 = self.layer1(c1) + c3 = self.layer2(c2) + c4 = self.layer3(c3) + c5 = self.layer4(c4) + + out = self.mean(c5, (2, 3)) + out = self.flatten(out) + out = self.end_point(out) + out = self.output_fake(out) + return out + + +def resnet50_quant(class_num=10): + """ + Get ResNet50 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet50 neural network. + + Examples: + >>> net = resnet50_quant(10) + """ + return ResNet(ResidualBlock, + [3, 4, 6, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) + + +def resnet101_quant(class_num=1001): + """ + Get ResNet101 neural network. + + Args: + class_num (int): Class number. + + Returns: + Cell, cell instance of ResNet101 neural network. + + Examples: + >>> net = resnet101(1001) + """ + return ResNet(ResidualBlock, + [3, 4, 23, 3], + [64, 256, 512, 1024], + [256, 512, 1024, 2048], + [1, 2, 2, 2], + class_num) diff --git a/model_zoo/official/cv/resnet50_quant/train.py b/model_zoo/official/cv/resnet50_quant/train.py index 2e13ec37ff6..870da29ee01 100755 --- a/model_zoo/official/cv/resnet50_quant/train.py +++ b/model_zoo/official/cv/resnet50_quant/train.py @@ -31,7 +31,8 @@ from mindspore.communication.management import init import mindspore.nn as nn import mindspore.common.initializer as weight_init -from models.resnet_quant import resnet50_quant +#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50 +from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50 from src.dataset import create_dataset from src.lr_generator import get_lr from src.config import config_quant @@ -85,7 +86,7 @@ if __name__ == '__main__': # weight init and load checkpoint file if args_opt.pre_trained: param_dict = load_checkpoint(args_opt.pre_trained) - load_nonquant_param_into_quant_net(net, param_dict) + load_nonquant_param_into_quant_net(net, param_dict, ['step']) epoch_size = config.epoch_size - config.pretrained_epoch_size else: for _, cell in net.cells_and_names():