From b5ad736b75bbe500d70ea26843b26e30c09821d8 Mon Sep 17 00:00:00 2001
From: Amir Lashkari
Date: Mon, 24 Aug 2020 18:07:31 -0400
Subject: [PATCH] Adding ShuffleNetV2 to ModelZoo

Added Readme.md
Fixed PyLint Errors
Fixed PyLint Errors-2
Fixed PyLint Errors-3
Fixed PyLint Errors-4
Fixed PyLint Errors-5
Fixed PyLint Errors-6
Fixed PyLint Errors-7
Update eval.py
Updated ShuffleNetV2 model
Fixed PyLint Error
Fixed PyLint Error #2
Fixed PyLint Error #3
Applied Comments
Fixed PyLint
Fixed PyLint #2
---
 model_zoo/official/cv/shufflenetv2/Readme.md  | 119 +++++++++++++++++
 model_zoo/official/cv/shufflenetv2/blocks.py  |  83 ++++++++++++
 model_zoo/official/cv/shufflenetv2/eval.py    |  54 ++++++++
 model_zoo/official/cv/shufflenetv2/network.py | 108 +++++++++++++++
 .../scripts/run_distribute_train_for_gpu.sh   |  17 +++
 .../scripts/run_eval_for_multi_gpu.sh         |  18 +++
 .../scripts/run_standalone_train_for_gpu.sh   |  18 +++
 .../official/cv/shufflenetv2/src/config.py    |  49 +++++++
 .../official/cv/shufflenetv2/src/dataset.py   |  81 ++++++++++++
 .../official/cv/shufflenetv2/src/loss.py      |  60 +++++++++
 .../cv/shufflenetv2/src/lr_generator.py       |  64 +++++++++
 model_zoo/official/cv/shufflenetv2/train.py   | 124 ++++++++++++++++++
 12 files changed, 795 insertions(+)
 create mode 100644 model_zoo/official/cv/shufflenetv2/Readme.md
 create mode 100644 model_zoo/official/cv/shufflenetv2/blocks.py
 create mode 100644 model_zoo/official/cv/shufflenetv2/eval.py
 create mode 100644 model_zoo/official/cv/shufflenetv2/network.py
 create mode 100644 model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh
 create mode 100644 model_zoo/official/cv/shufflenetv2/scripts/run_eval_for_multi_gpu.sh
 create mode 100644 model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh
 create mode 100644 model_zoo/official/cv/shufflenetv2/src/config.py
 create mode 100644 model_zoo/official/cv/shufflenetv2/src/dataset.py
 create mode 100644 model_zoo/official/cv/shufflenetv2/src/loss.py
 create mode 100644 model_zoo/official/cv/shufflenetv2/src/lr_generator.py
 create mode 100644 model_zoo/official/cv/shufflenetv2/train.py

diff --git a/model_zoo/official/cv/shufflenetv2/Readme.md b/model_zoo/official/cv/shufflenetv2/Readme.md
new file mode 100644
index 00000000000..23291073d94
--- /dev/null
+++ b/model_zoo/official/cv/shufflenetv2/Readme.md
@@ -0,0 +1,119 @@
+# Contents
+
+- [ShuffleNetV2 Description](#shufflenetv2-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Training Process](#training-process)
+    - [Evaluation Process](#evaluation-process)
+        - [Evaluation](#evaluation)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+        - [Inference Performance](#inference-performance)
+
+# [ShuffleNetV2 Description](#contents)
+
+ShuffleNetV2 is a much faster and more accurate network than previous networks on different platforms, such as Ascend or GPU.
+[Paper](https://arxiv.org/pdf/1807.11164.pdf) Ma, N., Zhang, X., Zheng, H. T., & Sun, J. (2018). ShuffleNet V2: Practical guidelines for efficient CNN architecture design. In Proceedings of the European Conference on Computer Vision (ECCV) (pp. 116-131).
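+
+The characteristic operation of ShuffleNetV2 is the channel shuffle performed inside every block
+(implemented by `ShuffleV2Block.channel_shuffle` in `blocks.py` below). As a rough illustration only,
+the following NumPy sketch mirrors that reshape/transpose sequence; the tensor sizes are toy values.
+
+```python
+import numpy as np
+
+# toy sizes for illustration: 1 image, 8 channels, 2x2 spatial
+n, c, h, w = 1, 8, 2, 2
+x = np.arange(n * c * h * w).reshape(n, c, h, w)
+
+# same reshape/transpose steps as ShuffleV2Block.channel_shuffle
+y = x.reshape(n * c // 2, 2, h * w)
+y = y.transpose(1, 0, 2)
+y = y.reshape(2, n, c // 2, h, w)
+x_proj, x_main = y[0], y[1]        # even-indexed and odd-indexed channels
+print(x_proj.shape, x_main.shape)  # (1, 4, 2, 2) (1, 4, 2, 2)
+```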
+
+# [Model architecture](#contents)
+
+The overall network architecture of ShuffleNetV2 is shown below:
+
+[Link](https://arxiv.org/pdf/1807.11164.pdf)
+
+# [Dataset](#contents)
+
+Dataset used: [ImageNet](http://www.image-net.org/)
+
+- Dataset size: ~125G, over 1.3 million colorful images in 1000 classes
+    - Train: 120G, 1.28 million images
+    - Test: 5G, 50,000 images
+- Data format: RGB images.
+    - Note: Data will be processed in src/dataset.py
+
+# [Environment Requirements](#contents)
+
+- Hardware (GPU)
+    - Prepare hardware environment with a GPU processor.
+- Framework
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
+    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
+
+# [Script description](#contents)
+
+## [Script and sample code](#contents)
+
+```text
++-- ShuffleNetV2
+  +-- Readme.md                             # descriptions about ShuffleNetV2
+  +-- scripts
+  │   +-- run_distribute_train_for_gpu.sh   # shell script for distributed training
+  │   +-- run_eval_for_multi_gpu.sh         # shell script for evaluation
+  │   +-- run_standalone_train_for_gpu.sh   # shell script for standalone training
+  +-- src
+  │   +-- config.py                         # parameter configuration
+  │   +-- dataset.py                        # creating dataset
+  │   +-- loss.py                           # loss function for network
+  │   +-- lr_generator.py                   # learning rate config
+  +-- train.py                              # training script
+  +-- eval.py                               # evaluation script
+  +-- blocks.py                             # ShuffleNetV2 blocks
+  +-- network.py                            # ShuffleNetV2 model network
+```
+
+## [Training process](#contents)
+
+### Usage
+
+You can start training using python or shell scripts. The usage of the shell scripts is as follows:
+
+- Distributed training on GPU: sh run_distribute_train_for_gpu.sh [DATA_DIR]
+- Standalone training on GPU: sh run_standalone_train_for_gpu.sh [DEVICE_ID] [DATA_DIR]
+
+### Launch
+
+```
+# training example
+  python:
+      GPU: mpirun --allow-run-as-root -n 8 python train.py --is_distributed --platform 'GPU' --dataset_path '~/imagenet/train/' > train.log 2>&1 &
+
+  shell:
+      GPU: sh run_distribute_train_for_gpu.sh ~/imagenet/train/
+```
+
+### Result
+
+Training results will be stored in the example path. Checkpoints are stored in `./checkpoint` by default, and the training log is redirected to `./train.log`.
+
+## [Evaluation process](#contents)
+
+### Usage
+
+You can start evaluation using python or shell scripts. The usage of the shell script is as follows:
+
+- GPU: sh run_eval_for_multi_gpu.sh [DEVICE_ID] [EPOCH]
+
+### Launch
+
+```
+# infer example
+  python:
+      GPU: CUDA_VISIBLE_DEVICES=0 python eval.py --platform 'GPU' --dataset_path '~/imagenet/val/' --epoch 250 > eval.log 2>&1 &
+
+  shell:
+      GPU: sh run_eval_for_multi_gpu.sh 0 250
+```
+
+> The checkpoint is produced during the training process.
+
+### Result
+
+Inference results will be stored in the example path; you can find them in `eval.log`.
diff --git a/model_zoo/official/cv/shufflenetv2/blocks.py b/model_zoo/official/cv/shufflenetv2/blocks.py
new file mode 100644
index 00000000000..253abce0fc2
--- /dev/null
+++ b/model_zoo/official/cv/shufflenetv2/blocks.py
@@ -0,0 +1,83 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore.nn as nn +import mindspore.ops.operations as P + + +class ShuffleV2Block(nn.Cell): + def __init__(self, inp, oup, mid_channels, *, ksize, stride): + super(ShuffleV2Block, self).__init__() + self.stride = stride + ##assert stride in [1, 2] + + self.mid_channels = mid_channels + self.ksize = ksize + pad = ksize // 2 + self.pad = pad + self.inp = inp + + outputs = oup - inp + + branch_main = [ + # pw + nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1, + pad_mode='pad', padding=0, has_bias=False), + nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), + nn.ReLU(), + # dw + nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride, + pad_mode='pad', padding=pad, group=mid_channels, has_bias=False), + nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), + # pw-linear + nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1, + pad_mode='pad', padding=0, has_bias=False), + nn.BatchNorm2d(num_features=outputs, momentum=0.9), + nn.ReLU(), + ] + self.branch_main = nn.SequentialCell(branch_main) + + if stride == 2: + branch_proj = [ + # dw + nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride, + pad_mode='pad', padding=pad, group=inp, has_bias=False), + nn.BatchNorm2d(num_features=inp, momentum=0.9), + # pw-linear + nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1, + pad_mode='pad', padding=0, has_bias=False), + nn.BatchNorm2d(num_features=inp, momentum=0.9), + nn.ReLU(), + ] + self.branch_proj = nn.SequentialCell(branch_proj) + else: + self.branch_proj = None + + def construct(self, old_x): + if self.stride == 1: + x_proj, x = self.channel_shuffle(old_x) + return P.Concat(1)((x_proj, self.branch_main(x))) + if self.stride == 2: + x_proj = old_x + x = old_x + return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x))) + return None + + def channel_shuffle(self, x): + batchsize, num_channels, height, width = P.Shape()(x) + ##assert (num_channels % 4 == 0) + x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,)) + x = P.Transpose()(x, (1, 0, 2,)) + x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,)) + return x[0], x[1] diff --git a/model_zoo/official/cv/shufflenetv2/eval.py b/model_zoo/official/cv/shufflenetv2/eval.py new file mode 100644 index 00000000000..51a4ceea8ac --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/eval.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""evaluate_imagenet"""
+import argparse
+import os
+
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src.config import config_gpu as cfg
+from src.dataset import create_dataset
+from network import ShuffleNetV2
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='image classification evaluation')
+    parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of ShuffleNetV2 (Default: None)')
+    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
+    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
+    parser.add_argument('--epoch', type=str, default='')
+    args_opt = parser.parse_args()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
+    if args_opt.platform == 'Ascend':
+        device_id = int(os.getenv('DEVICE_ID'))
+        context.set_context(device_id=device_id)
+
+    net = ShuffleNetV2(n_class=cfg.num_classes)
+    ckpt = load_checkpoint(args_opt.checkpoint)
+    load_param_into_net(net, ckpt)
+    net.set_train(False)
+    # evaluation runs on a single shard of the validation set (do_train=False, rank 0, group_size 1)
+    dataset = create_dataset(args_opt.dataset_path, False, 0, 1)
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False,
+                                            smooth_factor=0.1, num_classes=cfg.num_classes)
+    eval_metrics = {'Loss': nn.Loss(),
+                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
+                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
+    model = Model(net, loss, optimizer=None, metrics=eval_metrics)
+    metrics = model.eval(dataset)
+    print("metric: ", metrics)
diff --git a/model_zoo/official/cv/shufflenetv2/network.py b/model_zoo/official/cv/shufflenetv2/network.py
new file mode 100644
index 00000000000..f2d2105cfbf
--- /dev/null
+++ b/model_zoo/official/cv/shufflenetv2/network.py
@@ -0,0 +1,108 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +import numpy as np + +from blocks import ShuffleV2Block + +from mindspore import Tensor +import mindspore.nn as nn +import mindspore.ops.operations as P + + +class ShuffleNetV2(nn.Cell): + def __init__(self, input_size=224, n_class=1000, model_size='1.0x'): + super(ShuffleNetV2, self).__init__() + print('model size is ', model_size) + + self.stage_repeats = [4, 8, 4] + self.model_size = model_size + if model_size == '0.5x': + self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] + elif model_size == '1.0x': + self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] + elif model_size == '1.5x': + self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] + elif model_size == '2.0x': + self.stage_out_channels = [-1, 24, 244, 488, 976, 2048] + else: + raise NotImplementedError + + # building first layer + input_channel = self.stage_out_channels[1] + self.first_conv = nn.SequentialCell([ + nn.Conv2d(in_channels=3, out_channels=input_channel, kernel_size=3, stride=2, + pad_mode='pad', padding=1, has_bias=False), + nn.BatchNorm2d(num_features=input_channel, momentum=0.9), + nn.ReLU(), + ]) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') + + self.features = [] + for idxstage in range(len(self.stage_repeats)): + numrepeat = self.stage_repeats[idxstage] + output_channel = self.stage_out_channels[idxstage+2] + + for i in range(numrepeat): + if i == 0: + self.features.append(ShuffleV2Block(input_channel, output_channel, + mid_channels=output_channel // 2, ksize=3, stride=2)) + else: + self.features.append(ShuffleV2Block(input_channel // 2, output_channel, + mid_channels=output_channel // 2, ksize=3, stride=1)) + + input_channel = output_channel + + self.features = nn.SequentialCell([*self.features]) + + self.conv_last = nn.SequentialCell([ + nn.Conv2d(in_channels=input_channel, out_channels=self.stage_out_channels[-1], kernel_size=1, stride=1, + pad_mode='pad', padding=0, has_bias=False), + nn.BatchNorm2d(num_features=self.stage_out_channels[-1], momentum=0.9), + nn.ReLU() + ]) + self.globalpool = nn.AvgPool2d(kernel_size=7, stride=7, pad_mode='valid') + if self.model_size == '2.0x': + self.dropout = nn.Dropout(keep_prob=0.8) + self.classifier = nn.SequentialCell([nn.Dense(in_channels=self.stage_out_channels[-1], + out_channels=n_class, has_bias=False)]) + ##TODO init weights + self._initialize_weights() + + def construct(self, x): + x = self.first_conv(x) + x = self.maxpool(x) + x = self.features(x) + x = self.conv_last(x) + + x = self.globalpool(x) + if self.model_size == '2.0x': + x = self.dropout(x) + x = P.Reshape()(x, (-1, self.stage_out_channels[-1],)) + x = self.classifier(x) + return x + + def _initialize_weights(self): + for name, m in self.cells_and_names(): + if isinstance(m, nn.Conv2d): + if 'first' in name: + m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, + m.weight.data.shape).astype("float32"))) + else: + m.weight.set_parameter_data(Tensor(np.random.normal(0, 1.0/m.weight.data.shape[1], + m.weight.data.shape).astype("float32"))) + + if isinstance(m, nn.Dense): + m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32"))) diff --git a/model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh b/model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh new file mode 100644 index 00000000000..305f1dcfff5 --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/scripts/run_distribute_train_for_gpu.sh 
@@ -0,0 +1,17 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DATA_DIR=$1 +mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & diff --git a/model_zoo/official/cv/shufflenetv2/scripts/run_eval_for_multi_gpu.sh b/model_zoo/official/cv/shufflenetv2/scripts/run_eval_for_multi_gpu.sh new file mode 100644 index 00000000000..3d5c42a72a0 --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/scripts/run_eval_for_multi_gpu.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DEVICE_ID=$1 +EPOCH=$2 +CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path '/home/data/ImageNet_Original/val/' --epoch $EPOCH > eval.log 2>&1 & diff --git a/model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh b/model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh new file mode 100644 index 00000000000..a007a96cb0a --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/scripts/run_standalone_train_for_gpu.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +DEVICE_ID=$1 +DATA_DIR=$2 +CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & diff --git a/model_zoo/official/cv/shufflenetv2/src/config.py b/model_zoo/official/cv/shufflenetv2/src/config.py new file mode 100644 index 00000000000..aca62968b6c --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/src/config.py @@ -0,0 +1,49 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in main.py +""" +from easydict import EasyDict as edict + + +config_gpu = edict({ + 'random_seed': 1, + 'rank': 0, + 'group_size': 1, + 'work_nums': 8, + 'epoch_size': 250, + 'keep_checkpoint_max': 100, + 'ckpt_path': './checkpoint/', + 'is_save_on_master': 0, + + ### Dataset Config + 'batch_size': 128, + 'num_classes': 1000, + + ### Loss Config + 'label_smooth_factor': 0.1, + 'aux_factor': 0.4, + + ### Learning Rate Config + 'lr_init': 0.5, + + ### Optimization Config + 'weight_decay': 0.00004, + 'momentum': 0.9, + 'opt_eps': 1.0, + 'rmsprop_decay': 0.9, + "loss_scale": 1, + +}) diff --git a/model_zoo/official/cv/shufflenetv2/src/dataset.py b/model_zoo/official/cv/shufflenetv2/src/dataset.py new file mode 100644 index 00000000000..26b37d78d5d --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/src/dataset.py @@ -0,0 +1,81 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in train.py and eval.py +""" +import numpy as np +from src.config import config_gpu as cfg + +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as C2 +import mindspore.dataset.transforms.vision.c_transforms as C + + +class toBGR(): + def __call__(self, img): + img = img[:, :, ::-1] + img = np.ascontiguousarray(img) + return img + +def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1): + """ + create a train or eval dataset + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + rank (int): The shard ID within num_shards (default=None). + group_size (int): Number of shards that the dataset should be divided into (default=None). 
+ repeat_num(int): the repeat times of dataset. Default: 1. + + Returns: + dataset + """ + if group_size == 1: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True) + else: + ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True, + num_shards=group_size, shard_id=rank) + # define map operations + if do_train: + trans = [ + C.RandomCropDecodeResize(224), + C.RandomHorizontalFlip(prob=0.5), + C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) + ] + else: + trans = [ + C.Decode(), + C.Resize(256), + C.CenterCrop(224) + ] + trans += [ + toBGR(), + C.Rescale(1.0 / 255.0, 0.0), + # C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + C.HWC2CHW(), + C2.TypeCast(mstype.float32) + ] + + type_cast_op = C2.TypeCast(mstype.int32) + ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=cfg.work_nums) + ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=cfg.work_nums) + # apply batch operations + ds = ds.batch(cfg.batch_size, drop_remainder=True) + # apply dataset repeat operation + ds = ds.repeat(repeat_num) + + return ds diff --git a/model_zoo/official/cv/shufflenetv2/src/loss.py b/model_zoo/official/cv/shufflenetv2/src/loss.py new file mode 100644 index 00000000000..01757501e7e --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/src/loss.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""define loss function for network.""" +from mindspore.common import dtype as mstype +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +import mindspore.nn as nn + + +class CrossEntropy(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits""" + def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4): + super(CrossEntropy, self).__init__() + self.factor = factor + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logits, label): + logit, aux = logits + one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss_logit = self.ce(logit, one_hot_label) + loss_logit = self.mean(loss_logit, 0) + one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value) + loss_aux = self.ce(aux, one_hot_label_aux) + loss_aux = self.mean(loss_aux, 0) + return loss_logit + self.factor*loss_aux + + +class CrossEntropy_Val(_Loss): + """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process""" + def __init__(self, smooth_factor=0, num_classes=1000): + super(CrossEntropy_Val, self).__init__() + self.onehot = P.OneHot() + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits() + self.mean = P.ReduceMean(False) + + def construct(self, logits, label): + one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value) + loss_logit = self.ce(logits, one_hot_label) + loss_logit = self.mean(loss_logit, 0) + return loss_logit diff --git a/model_zoo/official/cv/shufflenetv2/src/lr_generator.py b/model_zoo/official/cv/shufflenetv2/src/lr_generator.py new file mode 100644 index 00000000000..9fc121553c8 --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/src/lr_generator.py @@ -0,0 +1,64 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""learning rate exponential decay generator""" +import math +import numpy as np + + +def get_lr(lr_init, lr_decay_rate, num_epoch_per_decay, total_epochs, steps_per_epoch, is_stair=False): + """ + generate learning rate array + + Args: + lr_init(float): init learning rate + lr_decay_rate (float): + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + is_stair(bool): If `True` decay the learning rate at discrete intervals (default=False) + + Returns: + learning_rate, learning rate numpy array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + decay_steps = steps_per_epoch * num_epoch_per_decay + for i in range(total_steps): + p = i/decay_steps + if is_stair: + p = math.floor(p) + lr_each_step.append(lr_init * math.pow(lr_decay_rate, p)) + learning_rate = np.array(lr_each_step).astype(np.float32) + return learning_rate + +def get_lr_basic(lr_init, total_epochs, steps_per_epoch, is_stair=False): + """ + generate basic learning rate array + + Args: + lr_init(float): init learning rate + total_epochs(int): total epochs of training + steps_per_epoch(int): steps of one epoch + is_stair(bool): If `True` decay the learning rate at discrete intervals (default=False) + + Returns: + learning_rate, learning rate numpy array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + for i in range(total_steps): + lr = lr_init - lr_init * (i) / (total_steps) + lr_each_step.append(lr) + learning_rate = np.array(lr_each_step).astype(np.float32) + return learning_rate diff --git a/model_zoo/official/cv/shufflenetv2/train.py b/model_zoo/official/cv/shufflenetv2/train.py new file mode 100644 index 00000000000..ac97fe5a3d3 --- /dev/null +++ b/model_zoo/official/cv/shufflenetv2/train.py @@ -0,0 +1,124 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+"""train_imagenet."""
+import argparse
+import os
+import random
+import numpy as np
+
+from network import ShuffleNetV2
+
+import mindspore.nn as nn
+from mindspore import context
+from mindspore import dataset as de
+from mindspore import ParallelMode
+from mindspore import Tensor
+from mindspore.communication.management import init, get_rank, get_group_size
+from mindspore.nn.optim.momentum import Momentum
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+
+from src.config import config_gpu as cfg
+from src.dataset import create_dataset
+from src.lr_generator import get_lr_basic
+
+random.seed(cfg.random_seed)
+np.random.seed(cfg.random_seed)
+de.config.set_seed(cfg.random_seed)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='image classification training')
+    parser.add_argument('--dataset_path', type=str, default='/home/data/imagenet_jpeg/train/', help='Dataset path')
+    parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
+    parser.add_argument('--is_distributed', action='store_true', default=False,
+                        help='distributed training')
+    parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
+    parser.add_argument('--model_size', type=str, default='1.0x', help='ShuffleNetV2 model size parameter')
+    args_opt = parser.parse_args()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))
+
+    # init distributed
+    if args_opt.is_distributed:
+        if args_opt.platform == "Ascend":
+            init()
+        else:
+            init("nccl")
+        cfg.rank = get_rank()
+        cfg.group_size = get_group_size()
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
+                                          parameter_broadcast=True, mirror_mean=True)
+    else:
+        cfg.rank = 0
+        cfg.group_size = 1
+
+    # dataloader
+    dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size)
+    batches_per_epoch = dataset.get_dataset_size()
+    print("Batches Per Epoch: ", batches_per_epoch)
+    # network
+    net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size)
+
+    # loss
+    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
+                                            smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
+
+    # learning rate schedule
+    lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size,
+                      steps_per_epoch=batches_per_epoch, is_stair=True)
+    lr = Tensor(lr)
+
+    # optimizer: apply weight decay only to conv/dense weights, not to BatchNorm (beta/gamma) or bias
+    decayed_params = []
+    no_decayed_params = []
+    for param in net.trainable_params():
+        if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
+            decayed_params.append(param)
+        else:
+            no_decayed_params.append(param)
+
+    group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
+                    {'params': no_decayed_params},
+                    {'order_params': net.trainable_params()}]
+    optimizer = Momentum(params=group_params, learning_rate=lr, momentum=cfg.momentum,
+                         loss_scale=cfg.loss_scale)
+    eval_metrics = {'Loss': nn.Loss(),
+                    'Top1-Acc': nn.Top1CategoricalAccuracy(),
+                    'Top5-Acc': nn.Top5CategoricalAccuracy()}
+
+    if args_opt.resume:
+        ckpt = load_checkpoint(args_opt.resume)
+        load_param_into_net(net, ckpt)
+    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=eval_metrics)
+
+    print("============== Starting Training ==============")
+    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
+    time_cb = TimeMonitor(data_size=batches_per_epoch)
+    callbacks = [loss_cb, time_cb]
+    config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
+    ckpoint_cb = ModelCheckpoint(prefix=f"shufflenet-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
+    if args_opt.is_distributed and cfg.is_save_on_master:
+        if cfg.rank == 0:
+            callbacks.append(ckpoint_cb)
+        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
+    else:
+        callbacks.append(ckpoint_cb)
+        model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
+    print("train success")
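
A quick forward-pass check of the network added by this patch (a minimal sketch; it assumes MindSpore is
installed, that `network.py` and `blocks.py` are importable from the working directory, and that
`device_target` matches the local setup):

```python
import numpy as np
from mindspore import Tensor, context
from network import ShuffleNetV2

# build the default 1.0x model and push one dummy image through it
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
net = ShuffleNetV2(n_class=1000, model_size='1.0x')
net.set_train(False)
x = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
out = net(x)
print(out)  # expect a (1, 1000) Tensor of class logits
```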