forked from mindspore-Ecosystem/mindspore
!5097 Adding ShuffleNetV2 model to modelzoo
Merge pull request !5097 from alashkari/shufflenet
This commit is contained in:
commit
7962287f3e
|
@ -0,0 +1,119 @@
|
|||
# Contents
|
||||
|
||||
- [ShuffleNetV2 Description](#shufflenetv2-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Training Process](#training-process)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
|
||||
# [ShuffleNetV2 Description](#contents)
|
||||
|
||||
ShuffleNetV2 is a much faster and more accurate netowrk than the previous networks on different platforms such as Ascend or GPU.
|
||||
[Paper](https://arxiv.org/pdf/1807.11164.pdf) Ma, N., Zhang, X., Zheng, H. T., & Sun, J. (2018). Shufflenet v2: Practical guidelines for efficient cnn architecture design. In Proceedings of the European conference on computer vision (ECCV) (pp. 116-131).
|
||||
|
||||
# [Model architecture](#contents)
|
||||
|
||||
The overall network architecture of ShuffleNetV2 is show below:
|
||||
|
||||
[Link](https://arxiv.org/pdf/1807.11164.pdf)
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
Dataset used: [imagenet](http://www.image-net.org/)
|
||||
|
||||
- Dataset size: ~125G, 1.2W colorful images in 1000 classes
|
||||
- Train: 120G, 1.2W images
|
||||
- Test: 5G, 50000 images
|
||||
- Data format: RGB images.
|
||||
- Note: Data will be processed in src/dataset.py
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(GPU)
|
||||
- Prepare hardware environment with GPU processor.
|
||||
- Framework
|
||||
- [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
|
||||
# [Script description](#contents)
|
||||
|
||||
## [Script and sample code](#contents)
|
||||
|
||||
```python
|
||||
+-- ShuffleNetV2
|
||||
+-- Readme.md # descriptions about ShuffleNetV2
|
||||
+-- scripts
|
||||
¦ +--run_distribute_train_for_gpu.sh # shell script for distributed training
|
||||
¦ +--run_eval_for_multi_gpu.sh # shell script for evaluation
|
||||
¦ +--run_standalone_train_for_gpu.sh # shell script for standalone training
|
||||
+-- src
|
||||
¦ +--config.py # parameter configuration
|
||||
¦ +--dataset.py # creating dataset
|
||||
¦ +--loss.py # loss function for network
|
||||
¦ +--lr_generator.py # learning rate config
|
||||
+-- train.py # training script
|
||||
+-- eval.py # evaluation script
|
||||
+-- blocks.py # ShuffleNetV2 blocks
|
||||
+-- network.py # ShuffleNetV2 model network
|
||||
```
|
||||
|
||||
## [Training process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
|
||||
You can start training using python or shell scripts. The usage of shell scripts as follows:
|
||||
|
||||
- Ditributed training on GPU: sh run_distribute_train_for_gpu.sh [DATA_DIR]
|
||||
- Standalone training on GPU: sh run_standalone_train_for_gpu.sh [DEVICE_ID] [DATA_DIR]
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
# training example
|
||||
python:
|
||||
GPU: mpirun --allow-run-as-root -n 8 python train.py --is_distributed --platform 'GPU' --dataset_path '~/imagenet/train/' > train.log 2>&1 &
|
||||
|
||||
shell:
|
||||
GPU: sh run_distribute_train_for_gpu.sh ~/imagenet/train/
|
||||
```
|
||||
|
||||
### Result
|
||||
|
||||
Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log`.
|
||||
|
||||
## [Eval process](#contents)
|
||||
|
||||
### Usage
|
||||
|
||||
You can start evaluation using python or shell scripts. The usage of shell scripts as follows:
|
||||
|
||||
- GPU: sh run_eval_for_multi_gpu.sh [DEVICE_ID] [EPOCH]
|
||||
|
||||
### Launch
|
||||
|
||||
```
|
||||
# infer example
|
||||
python:
|
||||
GPU: CUDA_VISIBLE_DEVICES=0 python eval.py --platform 'GPU' --dataset_path '~/imagenet/val/' --epoch 250 > eval.log 2>&1 &
|
||||
|
||||
shell:
|
||||
GPU: sh run_eval_for_multi_gpu.sh 0 250
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
|
||||
### Result
|
||||
|
||||
Inference result will be stored in the example path, you can find result in `val.log`.
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
|
||||
class ShuffleV2Block(nn.Cell):
|
||||
def __init__(self, inp, oup, mid_channels, *, ksize, stride):
|
||||
super(ShuffleV2Block, self).__init__()
|
||||
self.stride = stride
|
||||
##assert stride in [1, 2]
|
||||
|
||||
self.mid_channels = mid_channels
|
||||
self.ksize = ksize
|
||||
pad = ksize // 2
|
||||
self.pad = pad
|
||||
self.inp = inp
|
||||
|
||||
outputs = oup - inp
|
||||
|
||||
branch_main = [
|
||||
# pw
|
||||
nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1,
|
||||
pad_mode='pad', padding=0, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
# dw
|
||||
nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride,
|
||||
pad_mode='pad', padding=pad, group=mid_channels, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=mid_channels, momentum=0.9),
|
||||
# pw-linear
|
||||
nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1,
|
||||
pad_mode='pad', padding=0, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=outputs, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
]
|
||||
self.branch_main = nn.SequentialCell(branch_main)
|
||||
|
||||
if stride == 2:
|
||||
branch_proj = [
|
||||
# dw
|
||||
nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride,
|
||||
pad_mode='pad', padding=pad, group=inp, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=inp, momentum=0.9),
|
||||
# pw-linear
|
||||
nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1,
|
||||
pad_mode='pad', padding=0, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=inp, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
]
|
||||
self.branch_proj = nn.SequentialCell(branch_proj)
|
||||
else:
|
||||
self.branch_proj = None
|
||||
|
||||
def construct(self, old_x):
|
||||
if self.stride == 1:
|
||||
x_proj, x = self.channel_shuffle(old_x)
|
||||
return P.Concat(1)((x_proj, self.branch_main(x)))
|
||||
if self.stride == 2:
|
||||
x_proj = old_x
|
||||
x = old_x
|
||||
return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x)))
|
||||
return None
|
||||
|
||||
def channel_shuffle(self, x):
|
||||
batchsize, num_channels, height, width = P.Shape()(x)
|
||||
##assert (num_channels % 4 == 0)
|
||||
x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,))
|
||||
x = P.Transpose()(x, (1, 0, 2,))
|
||||
x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,))
|
||||
return x[0], x[1]
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""evaluate_imagenet"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
from network import ShuffleNetV2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification evaluation')
|
||||
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of ShuffleNetV2 (Default: None)')
|
||||
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
|
||||
parser.add_argument('--epoch', type=str, default='')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, device_id=0)
|
||||
net = ShuffleNetV2(n_class=cfg.num_classes)
|
||||
ckpt = load_checkpoint(args_opt.checkpoint)
|
||||
load_param_into_net(net, ckpt)
|
||||
net.set_train(False)
|
||||
dataset = create_dataset(args_opt.dataset_path, cfg, False)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False,
|
||||
smooth_factor=0.1, num_classes=cfg.num_classes)
|
||||
eval_metrics = {'Loss': nn.Loss(),
|
||||
'Top1-Acc': nn.Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': nn.Top5CategoricalAccuracy()}
|
||||
model = Model(net, loss, optimizer=None, metrics=eval_metrics)
|
||||
metrics = model.eval(dataset)
|
||||
print("metric: ", metrics)
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
|
||||
from blocks import ShuffleV2Block
|
||||
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
|
||||
class ShuffleNetV2(nn.Cell):
|
||||
def __init__(self, input_size=224, n_class=1000, model_size='1.0x'):
|
||||
super(ShuffleNetV2, self).__init__()
|
||||
print('model size is ', model_size)
|
||||
|
||||
self.stage_repeats = [4, 8, 4]
|
||||
self.model_size = model_size
|
||||
if model_size == '0.5x':
|
||||
self.stage_out_channels = [-1, 24, 48, 96, 192, 1024]
|
||||
elif model_size == '1.0x':
|
||||
self.stage_out_channels = [-1, 24, 116, 232, 464, 1024]
|
||||
elif model_size == '1.5x':
|
||||
self.stage_out_channels = [-1, 24, 176, 352, 704, 1024]
|
||||
elif model_size == '2.0x':
|
||||
self.stage_out_channels = [-1, 24, 244, 488, 976, 2048]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# building first layer
|
||||
input_channel = self.stage_out_channels[1]
|
||||
self.first_conv = nn.SequentialCell([
|
||||
nn.Conv2d(in_channels=3, out_channels=input_channel, kernel_size=3, stride=2,
|
||||
pad_mode='pad', padding=1, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=input_channel, momentum=0.9),
|
||||
nn.ReLU(),
|
||||
])
|
||||
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
|
||||
|
||||
self.features = []
|
||||
for idxstage in range(len(self.stage_repeats)):
|
||||
numrepeat = self.stage_repeats[idxstage]
|
||||
output_channel = self.stage_out_channels[idxstage+2]
|
||||
|
||||
for i in range(numrepeat):
|
||||
if i == 0:
|
||||
self.features.append(ShuffleV2Block(input_channel, output_channel,
|
||||
mid_channels=output_channel // 2, ksize=3, stride=2))
|
||||
else:
|
||||
self.features.append(ShuffleV2Block(input_channel // 2, output_channel,
|
||||
mid_channels=output_channel // 2, ksize=3, stride=1))
|
||||
|
||||
input_channel = output_channel
|
||||
|
||||
self.features = nn.SequentialCell([*self.features])
|
||||
|
||||
self.conv_last = nn.SequentialCell([
|
||||
nn.Conv2d(in_channels=input_channel, out_channels=self.stage_out_channels[-1], kernel_size=1, stride=1,
|
||||
pad_mode='pad', padding=0, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=self.stage_out_channels[-1], momentum=0.9),
|
||||
nn.ReLU()
|
||||
])
|
||||
self.globalpool = nn.AvgPool2d(kernel_size=7, stride=7, pad_mode='valid')
|
||||
if self.model_size == '2.0x':
|
||||
self.dropout = nn.Dropout(keep_prob=0.8)
|
||||
self.classifier = nn.SequentialCell([nn.Dense(in_channels=self.stage_out_channels[-1],
|
||||
out_channels=n_class, has_bias=False)])
|
||||
##TODO init weights
|
||||
self._initialize_weights()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.first_conv(x)
|
||||
x = self.maxpool(x)
|
||||
x = self.features(x)
|
||||
x = self.conv_last(x)
|
||||
|
||||
x = self.globalpool(x)
|
||||
if self.model_size == '2.0x':
|
||||
x = self.dropout(x)
|
||||
x = P.Reshape()(x, (-1, self.stage_out_channels[-1],))
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
def _initialize_weights(self):
|
||||
for name, m in self.cells_and_names():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
if 'first' in name:
|
||||
m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01,
|
||||
m.weight.data.shape).astype("float32")))
|
||||
else:
|
||||
m.weight.set_parameter_data(Tensor(np.random.normal(0, 1.0/m.weight.data.shape[1],
|
||||
m.weight.data.shape).astype("float32")))
|
||||
|
||||
if isinstance(m, nn.Dense):
|
||||
m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32")))
|
|
@ -0,0 +1,17 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DATA_DIR=$1
|
||||
mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 &
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DEVICE_ID=$1
|
||||
EPOCH=$2
|
||||
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path '/home/data/ImageNet_Original/val/' --epoch $EPOCH > eval.log 2>&1 &
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 &
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in main.py
|
||||
"""
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
|
||||
config_gpu = edict({
|
||||
'random_seed': 1,
|
||||
'rank': 0,
|
||||
'group_size': 1,
|
||||
'work_nums': 8,
|
||||
'epoch_size': 250,
|
||||
'keep_checkpoint_max': 100,
|
||||
'ckpt_path': './checkpoint/',
|
||||
'is_save_on_master': 0,
|
||||
|
||||
### Dataset Config
|
||||
'batch_size': 128,
|
||||
'num_classes': 1000,
|
||||
|
||||
### Loss Config
|
||||
'label_smooth_factor': 0.1,
|
||||
'aux_factor': 0.4,
|
||||
|
||||
### Learning Rate Config
|
||||
'lr_init': 0.5,
|
||||
|
||||
### Optimization Config
|
||||
'weight_decay': 0.00004,
|
||||
'momentum': 0.9,
|
||||
'opt_eps': 1.0,
|
||||
'rmsprop_decay': 0.9,
|
||||
"loss_scale": 1,
|
||||
|
||||
})
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import numpy as np
|
||||
from src.config import config_gpu as cfg
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.dataset.transforms.vision.c_transforms as C
|
||||
|
||||
|
||||
class toBGR():
|
||||
def __call__(self, img):
|
||||
img = img[:, :, ::-1]
|
||||
img = np.ascontiguousarray(img)
|
||||
return img
|
||||
|
||||
def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
rank (int): The shard ID within num_shards (default=None).
|
||||
group_size (int): Number of shards that the dataset should be divided into (default=None).
|
||||
repeat_num(int): the repeat times of dataset. Default: 1.
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if group_size == 1:
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True)
|
||||
else:
|
||||
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True,
|
||||
num_shards=group_size, shard_id=rank)
|
||||
# define map operations
|
||||
if do_train:
|
||||
trans = [
|
||||
C.RandomCropDecodeResize(224),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize(256),
|
||||
C.CenterCrop(224)
|
||||
]
|
||||
trans += [
|
||||
toBGR(),
|
||||
C.Rescale(1.0 / 255.0, 0.0),
|
||||
# C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||
C.HWC2CHW(),
|
||||
C2.TypeCast(mstype.float32)
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=cfg.work_nums)
|
||||
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=cfg.work_nums)
|
||||
# apply batch operations
|
||||
ds = ds.batch(cfg.batch_size, drop_remainder=True)
|
||||
# apply dataset repeat operation
|
||||
ds = ds.repeat(repeat_num)
|
||||
|
||||
return ds
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network."""
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class CrossEntropy(_Loss):
|
||||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
|
||||
def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
|
||||
super(CrossEntropy, self).__init__()
|
||||
self.factor = factor
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logits, label):
|
||||
logit, aux = logits
|
||||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss_logit = self.ce(logit, one_hot_label)
|
||||
loss_logit = self.mean(loss_logit, 0)
|
||||
one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value)
|
||||
loss_aux = self.ce(aux, one_hot_label_aux)
|
||||
loss_aux = self.mean(loss_aux, 0)
|
||||
return loss_logit + self.factor*loss_aux
|
||||
|
||||
|
||||
class CrossEntropy_Val(_Loss):
|
||||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
|
||||
def __init__(self, smooth_factor=0, num_classes=1000):
|
||||
super(CrossEntropy_Val, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logits, label):
|
||||
one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
|
||||
loss_logit = self.ce(logits, one_hot_label)
|
||||
loss_logit = self.mean(loss_logit, 0)
|
||||
return loss_logit
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate exponential decay generator"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(lr_init, lr_decay_rate, num_epoch_per_decay, total_epochs, steps_per_epoch, is_stair=False):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
lr_decay_rate (float):
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
is_stair(bool): If `True` decay the learning rate at discrete intervals (default=False)
|
||||
|
||||
Returns:
|
||||
learning_rate, learning rate numpy array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
decay_steps = steps_per_epoch * num_epoch_per_decay
|
||||
for i in range(total_steps):
|
||||
p = i/decay_steps
|
||||
if is_stair:
|
||||
p = math.floor(p)
|
||||
lr_each_step.append(lr_init * math.pow(lr_decay_rate, p))
|
||||
learning_rate = np.array(lr_each_step).astype(np.float32)
|
||||
return learning_rate
|
||||
|
||||
def get_lr_basic(lr_init, total_epochs, steps_per_epoch, is_stair=False):
|
||||
"""
|
||||
generate basic learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
total_epochs(int): total epochs of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
is_stair(bool): If `True` decay the learning rate at discrete intervals (default=False)
|
||||
|
||||
Returns:
|
||||
learning_rate, learning rate numpy array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
for i in range(total_steps):
|
||||
lr = lr_init - lr_init * (i) / (total_steps)
|
||||
lr_each_step.append(lr)
|
||||
learning_rate = np.array(lr_each_step).astype(np.float32)
|
||||
return learning_rate
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
from network import ShuffleNetV2
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore import ParallelMode
|
||||
from mindspore import Tensor
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config_gpu as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.lr_generator import get_lr_basic
|
||||
|
||||
random.seed(cfg.random_seed)
|
||||
np.random.seed(cfg.random_seed)
|
||||
de.config.set_seed(cfg.random_seed)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification training')
|
||||
parser.add_argument('--dataset_path', type=str, default='/home/data/imagenet_jpeg/train/', help='Dataset path')
|
||||
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
|
||||
parser.add_argument('--is_distributed', action='store_true', default=False,
|
||||
help='distributed training')
|
||||
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform')
|
||||
parser.add_argument('--model_size', type=str, default='1.0x', help='ShuffleNetV2 model size parameter')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
||||
|
||||
# init distributed
|
||||
if args_opt.is_distributed:
|
||||
if args_opt.platform == "Ascend":
|
||||
init()
|
||||
else:
|
||||
init("nccl")
|
||||
cfg.rank = get_rank()
|
||||
cfg.group_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size,
|
||||
parameter_broadcast=True, mirror_mean=True)
|
||||
else:
|
||||
cfg.rank = 0
|
||||
cfg.group_size = 1
|
||||
|
||||
# dataloader
|
||||
dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size)
|
||||
batches_per_epoch = dataset.get_dataset_size()
|
||||
print("Batches Per Epoch: ", batches_per_epoch)
|
||||
# network
|
||||
net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size)
|
||||
|
||||
# loss
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False,
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
|
||||
# learning rate schedule
|
||||
lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size,
|
||||
steps_per_epoch=batches_per_epoch, is_stair=True)
|
||||
lr = Tensor(lr)
|
||||
|
||||
# optimizer
|
||||
decayed_params = []
|
||||
no_decayed_params = []
|
||||
for param in net.trainable_params():
|
||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
||||
decayed_params.append(param)
|
||||
else:
|
||||
no_decayed_params.append(param)
|
||||
|
||||
group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainable_params()}]
|
||||
optimizer = Momentum(params=net.trainable_params(), learning_rate=Tensor(lr), momentum=cfg.momentum,
|
||||
weight_decay=cfg.weight_decay)
|
||||
eval_metrics = {'Loss': nn.Loss(),
|
||||
'Top1-Acc': nn.Top1CategoricalAccuracy(),
|
||||
'Top5-Acc': nn.Top5CategoricalAccuracy()}
|
||||
|
||||
if args_opt.resume:
|
||||
ckpt = load_checkpoint(args_opt.resume)
|
||||
load_param_into_net(net, ckpt)
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={'acc'})
|
||||
|
||||
print("============== Starting Training ==============")
|
||||
loss_cb = LossMonitor(per_print_times=batches_per_epoch)
|
||||
time_cb = TimeMonitor(data_size=batches_per_epoch)
|
||||
callbacks = [loss_cb, time_cb]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix=f"shufflenet-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck)
|
||||
if args_opt.is_distributed & cfg.is_save_on_master:
|
||||
if cfg.rank == 0:
|
||||
callbacks.append(ckpoint_cb)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
|
||||
else:
|
||||
callbacks.append(ckpoint_cb)
|
||||
model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True)
|
||||
print("train success")
|
Loading…
Reference in New Issue