!16264 mobilenetV3_small_x1_0 master

Merge pull request !16264 from Gogery/mo3m
This commit is contained in:
i-robot 2021-08-11 02:55:26 +00:00 committed by Gitee
commit 728224dd9e
13 changed files with 1238 additions and 0 deletions

View File

@ -0,0 +1,185 @@
# 目录
- [目录](#目录)
- [MobileNetV3描述](#mobilenetv3描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [环境要求](#环境要求)
- [脚本说明](#脚本说明)
- [脚本和示例代码](#脚本和示例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [启动](#启动)
- [结果](#结果)
- [评估过程](#评估过程)
- [启动](#启动-1)
- [结果](#结果-1)
- [模型说明](#模型说明)
- [训练性能](#训练性能)
- [随机情况的描述](#随机情况的描述)
- [ModelZoo 主页](#modelzoo-主页)
<!-- /TOC -->
# MobileNetV3描述
MobileNetV3结合硬件感知神经网络架构搜索NAS和NetAdapt算法已经可以移植到手机CPU上运行后续随新架构进一步优化改进。2019年11月20日
[论文](https://arxiv.org/pdf/1905.02244)Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al."Searching for mobilenetv3."In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324.2019.
# 模型架构
MobileNetV3总体网络架构如下
[链接](https://arxiv.org/pdf/1905.02244)
# 数据集
使用的数据集:[imagenet](http://www.image-net.org/)
- 数据集大小: 146G, 1330k 1000类彩色图像
- 训练: 140G, 1280k张图片
- 测试: 6G, 50k张图片
- 数据格式RGB
- 注数据在src/dataset.py中处理。
# 环境要求
- 硬件Ascend
- 使用Ascend来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源:
- [MindSpore 教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# 脚本说明
## 脚本和样例代码
```python
├── MobileNetV3
├── README_CN.md # MobileNetV3相关描述
├── scripts
│ ├──run_standalone_train.sh # 用于单卡训练的shell脚本
│ ├──run_distribute_train.sh # 用于八卡训练的shell脚本
│ └──run_eval.sh # 用于评估的shell脚本
├── src
│ ├──config.py # 参数配置
│ ├──dataset.py # 创建数据集
│ ├──loss.py # 损失函数
│ ├──lr_generator.py # 配置学习率
│ ├──mobilenetV3.py # MobileNetV3架构
│ └──monitor.py # 监控网络损失和其他数据
├── eval.py # 评估脚本
├── export.py # 模型格式转换脚本
└── train.py # 训练脚本
```
## 脚本参数
模型训练和评估过程中使用的参数可以在config.py中设置:
```python
'num_classes': 1000, # 数据集类别数
'image_height': 224, # 输入图像高度
'image_width': 224, # 输入图像宽度
'batch_size': 256, # 数据批次大小
'epoch_size': 370, # 模型迭代次数
'warmup_epochs': 4, # warmup epoch数量
'lr': 0.05, # 学习率
'momentum': 0.9, # 动量参数
'weight_decay': 4e-5, # 权重衰减率
'label_smooth': 0.1, # 标签平滑因子
'loss_scale': 1024, # loss scale
'save_checkpoint': True, # 是否保存ckpt文件
'save_checkpoint_epochs': 1, # 每迭代相应次数保存一个ckpt文件
'keep_checkpoint_max': 5, # 保存ckpt文件的最大数量
'save_checkpoint_path': "./checkpoint", # 保存ckpt文件的路径
'export_file': "mobilenetv3_small", # export文件
'export_format': "MINDIR", # export格式
```
## 训练过程
### 启动
您可以使用python或shell脚本进行训练。
```shell
# 训练示例
python:
Ascend单卡训练示例python train.py --device_id [DEVICE_ID] --dataset_path [DATA_DIR]
shell:
Ascend单卡训练示例: sh ./scripts/run_standalone_train.sh [DEVICE_ID] [DATA_DIR]
Ascend八卡并行训练:
cd ./scripts/
sh ./run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR]
```
### 结果
ckpt文件将存储在 `./checkpoint` 路径下,训练日志将被记录到 `log.txt` 中。训练日志部分示例如下:
```shell
epoch 1: epoch time: 553262.126, per step time: 518.521, avg loss: 5.270
epoch 2: epoch time: 151033.049, per step time: 141.549, avg loss: 4.529
epoch 3: epoch time: 150605.300, per step time: 141.148, avg loss: 4.101
epoch 4: epoch time: 150638.805, per step time: 141.180, avg loss: 4.014
epoch 5: epoch time: 150594.088, per step time: 141.138, avg loss: 3.607
```
## 评估过程
### 启动
您可以使用python或shell脚本进行评估。
```shell
# 评估示例
python:
python eval.py --device_id [DEVICE_ID] --dataset_path [DATA_DIR] --checkpoint_path [PATH_CHECKPOINT]
shell:
sh ./scripts/run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]
```
> 训练过程中可以生成ckpt文件。
### 结果
可以在 `eval_log.txt` 查看评估结果。
```shell
result: {'Loss': 2.3101649037352554, 'Top_1_Acc': 0.6746546546546547, 'Top_5_Acc': 0.8722122122122122} ckpt= ./checkpoint/model_0/mobilenetV3-370_625.ckpt
```
# 模型说明
## 训练性能
| 参数 | Ascend |
| -------------------------- | ------------------------------------- |
| 模型名称 | mobilenetV3 |
| 模型版本 | 小版本 |
| 运行环境 | HUAWEI CLOUD Modelarts |
| 上传时间 | 2021-3-25 |
| 数据集 | imagenet |
| 训练参数 | src/config.py |
| 优化器 | RMSProp |
| 损失函数 | CrossEntropyWithLabelSmooth |
| 最终损失 | 2.31 |
| 精确度 (8p) | Top1[67.5%], Top5[87.2%] |
| 训练总时间 (8p) | 16.4h |
| 评估总时间 | 1min |
| 参数量 (M) | 36M |
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/mobilenetV3_small_x1_0) |
# 随机情况的描述
我们在 `dataset.py``train.py` 脚本中设置了随机种子。
# ModelZoo
请核对官方 [主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

View File

@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
import ast
from mindspore import context
from mindspore import nn
from mindspore.train.model import Model
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.config import config
from src.loss import CrossEntropyWithLabelSmooth
from src.mobilenetv3 import mobilenet_v3_small
parser = argparse.ArgumentParser(description='Image classification')
# modelarts parameter
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
# Ascend parameter
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--device_id', type=int, default=0, help='Device id')
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run distribute')
args_opt = parser.parse_args()
set_seed(1)
if __name__ == '__main__':
if args_opt.run_modelarts:
import moxing as mox
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
context.set_context(device_id=device_id)
local_data_url = '/cache/data/'
local_train_url = '/cache/ckpt/'
mox.file.copy_parallel(args_opt.data_url, local_data_url)
mox.file.copy_parallel(args_opt.train_url, local_train_url)
else:
context.set_context(device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
net = mobilenet_v3_small(num_classes=config.num_classes, multiplier=1.)
if args_opt.run_modelarts:
dataset = create_dataset(dataset_path=local_data_url,
do_train=False,
batch_size=config.batch_size)
ckpt_path = local_train_url + 'mobilenetV3-370_1067.ckpt'
param_dict = load_checkpoint(ckpt_path)
else:
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=False,
batch_size=config.batch_size)
param_dict = load_checkpoint(args_opt.checkpoint_path)
step_size = dataset.get_dataset_size()
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes)
# define model
eval_metrics = {'Loss': nn.Loss(),
'Top_1_Acc': nn.Top1CategoricalAccuracy(),
'Top_5_Acc': nn.Top5CategoricalAccuracy()}
model = Model(net, loss_fn=loss, metrics=eval_metrics)
# eval model
res = model.eval(dataset)
if args_opt.run_modelarts:
print("result:", res, "ckpt=", local_data_url)
else:
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
mobilenetv3_small export.
"""
import argparse
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config
from src.mobilenetv3 import mobilenet_v3_small
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, required=True, help='Checkpoint file path')
args_opt = parser.parse_args()
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = mobilenet_v3_small(num_classes=config.num_classes, multiplier=1.)
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
input_shp = [1, 3, config.image_height, config.image_width]
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
export(net, input_array, file_name=config.export_file, file_format=config.export_format)

View File

@ -0,0 +1,61 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -f $PATH1 ]; then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]; then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py --dataset_path=$PATH2 --run_distribute=True > log.txt 2>&1 &
cd ..
done

View File

@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
python ./eval.py \
--device_target=Ascend \
--device_id=$DEVICE_ID \
--checkpoint_path=$PATH_CHECKPOINT \
--dataset_path=$DATA_DIR > eval.log 2>&1 &

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
python ./train.py \
--device_id=$DEVICE_ID \
--dataset_path=$DATA_DIR > log.txt 2>&1 &

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
config = ed({
"num_classes": 1000,
"image_height": 224,
"image_width": 224,
"batch_size": 150,
"epoch_size": 370,
"warmup_epochs": 4,
"lr": 0.05,
"momentum": 0.9,
"weight_decay": 4e-5,
"label_smooth": 0.1,
"loss_scale": 1024,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 5,
"save_checkpoint_path": "./checkpoint",
"export_file": "mobilenetv3_small",
"export_format": "MINDIR",
})

View File

@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
def create_dataset(dataset_path, do_train, batch_size=16, device_num=1, rank=0):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
batch_size(int): the batch size of dataset. Default: 16.
device_num (int): Number of shards that the dataset should be divided into (default=1).
rank (int): The shard ID within num_shards (default=0).
Returns:
dataset
"""
if device_num == 1:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=64, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=64, shuffle=True,
num_shards=device_num, shard_id=rank)
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(224),
C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(224)
]
trans += [
C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255]),
C.HWC2CHW(),
C2.TypeCast(mstype.float32)
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
return ds

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropyWithLabelSmooth(_Loss):
"""CrossEntropyWithLabelSmooth"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropyWithLabelSmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

View File

@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""
generate learning rate array
Args:
global_step(int): total steps of the training
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_end + \
(lr_max - lr_end) * \
(1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate

View File

@ -0,0 +1,406 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MobileNetV3 model define"""
from functools import partial
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor
__all__ = ['mobilenet_v3_large',
'mobilenet_v3_small']
class hswish(nn.Cell):
"""hswish"""
def construct(self, x):
out = x * nn.ReLU6()(x + 3) / 6
return out
class hsigmoid(nn.Cell):
"""hsigmoid"""
def construct(self, x):
out = nn.ReLU6()(x + 3) / 6
return out
def _make_divisible(x, divisor=8):
"""_make_divisible"""
return int(np.ceil(x * 1. / divisor) * divisor)
class Activation(nn.Cell):
"""
Activation definition.
Args:
act_func(string): activation name.
Returns:
Tensor, output tensor.
"""
def __init__(self, act_func):
super(Activation, self).__init__()
if act_func == 'relu':
self.act = nn.ReLU()
elif act_func == 'relu6':
self.act = nn.ReLU6()
elif act_func in ('hsigmoid', 'hard_sigmoid'):
self.act = hsigmoid()
elif act_func in ('hswish', 'hard_swish'):
self.act = hswish()
else:
raise NotImplementedError
def construct(self, x):
return self.act(x)
class GlobalAvgPooling(nn.Cell):
"""
Global avg pooling definition.
Args:
Returns:
Tensor, output tensor.
Examples:
>>> GlobalAvgPooling()
"""
def __init__(self, keep_dims=False):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(keep_dims=keep_dims)
def construct(self, x):
x = self.mean(x, (2, 3))
return x
class SE(nn.Cell):
"""
SE warpper definition.
Args:
num_out (int): Numbers of output channels.
ratio (int): middle output ratio.
Returns:
Tensor, output tensor.
Examples:
>>> SE(4)
"""
def __init__(self, num_out, ratio=4):
super(SE, self).__init__()
num_mid = _make_divisible(num_out // ratio)
self.pool = GlobalAvgPooling(keep_dims=True)
self.conv1 = nn.Conv2d(in_channels=num_out, out_channels=num_mid,
kernel_size=1, has_bias=True, pad_mode='pad')
self.act1 = Activation('relu')
self.conv2 = nn.Conv2d(in_channels=num_mid, out_channels=num_out,
kernel_size=1, has_bias=True, pad_mode='pad')
self.act2 = Activation('hsigmoid')
self.mul = P.Mul()
def construct(self, x):
out = self.pool(x)
out = self.conv1(out)
out = self.act1(out)
out = self.conv2(out)
out = self.act2(out)
out = self.mul(x, out)
return out
class Unit(nn.Cell):
"""
Unit warpper definition.
Args:
num_in (int): Input channel.
num_out (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size.
padding (int): Padding number.
num_groups (int): Output num group.
use_act (bool): Used activation or not.
act_type (string): Activation type.
Returns:
Tensor, output tensor.
Examples:
>>> Unit(3, 3)
"""
def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1,
use_act=True, act_type='relu'):
super(Unit, self).__init__()
self.conv = nn.Conv2d(in_channels=num_in,
out_channels=num_out,
kernel_size=kernel_size,
stride=stride,
padding=padding,
group=num_groups,
has_bias=False,
pad_mode='pad')
self.bn = nn.BatchNorm2d(num_out)
self.use_act = use_act
self.act = Activation(act_type) if use_act else None
def construct(self, x):
out = self.conv(x)
out = self.bn(out)
if self.use_act:
out = self.act(out)
return out
class ResUnit(nn.Cell):
"""
ResUnit warpper definition.
Args:
num_in (int): Input channel.
num_mid (int): Middle channel.
num_out (int): Output channel.
kernel_size (int): Input kernel size.
stride (int): Stride size.
act_type (str): Activation type.
use_se (bool): Use SE warpper or not.
Returns:
Tensor, output tensor.
Examples:
>>> ResUnit(16, 3, 1, 1)
"""
def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False):
super(ResUnit, self).__init__()
self.use_se = use_se
self.first_conv = (num_out != num_mid)
self.use_short_cut_conv = True
if self.first_conv:
self.expand = Unit(num_in, num_mid, kernel_size=1,
stride=1, padding=0, act_type=act_type)
else:
self.expand = None
self.conv1 = Unit(num_mid, num_mid, kernel_size=kernel_size, stride=stride,
padding=self._get_pad(kernel_size), act_type=act_type, num_groups=num_mid)
if use_se:
self.se = SE(num_mid)
self.conv2 = Unit(num_mid, num_out, kernel_size=1, stride=1,
padding=0, act_type=act_type, use_act=False)
if num_in != num_out or stride != 1:
self.use_short_cut_conv = False
self.add = P.Add() if self.use_short_cut_conv else None
def construct(self, x):
"""construct"""
if self.first_conv:
out = self.expand(x)
else:
out = x
out = self.conv1(out)
if self.use_se:
out = self.se(out)
out = self.conv2(out)
if self.use_short_cut_conv:
out = self.add(x, out)
return out
def _get_pad(self, kernel_size):
"""set the padding number"""
pad = 0
if kernel_size == 1:
pad = 0
elif kernel_size == 3:
pad = 1
elif kernel_size == 5:
pad = 2
elif kernel_size == 7:
pad = 3
else:
raise NotImplementedError
return pad
class MobileNetV3(nn.Cell):
"""
MobileNetV3 architecture.
Args:
model_cfgs (Cell): number of classes.
num_classes (int): Output number classes.
multiplier (int): Channels multiplier for round to 8/16 and others. Default is 1.
final_drop (float): Dropout number.
round_nearest (list): Channel round to . Default is 8.
Returns:
Tensor, output tensor.
Examples:
>>> MobileNetV3(num_classes=1000)
"""
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8):
super(MobileNetV3, self).__init__()
self.cfgs = model_cfgs['cfg']
self.inplanes = 16
self.features = []
first_conv_in_channel = 3
first_conv_out_channel = _make_divisible(multiplier * self.inplanes)
self.features.append(nn.Conv2d(in_channels=first_conv_in_channel,
out_channels=first_conv_out_channel,
kernel_size=3, padding=1, stride=2,
has_bias=False, pad_mode='pad'))
self.features.append(nn.BatchNorm2d(first_conv_out_channel))
self.features.append(Activation('hswish'))
for layer_cfg in self.cfgs:
self.features.append(self._make_layer(kernel_size=layer_cfg[0],
exp_ch=_make_divisible(multiplier * layer_cfg[1]),
out_channel=_make_divisible(multiplier * layer_cfg[2]),
use_se=layer_cfg[3],
act_func=layer_cfg[4],
stride=layer_cfg[5]))
output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"])
self.features.append(nn.Conv2d(in_channels=_make_divisible(multiplier * self.cfgs[-1][2]),
out_channels=output_channel,
kernel_size=1, padding=0, stride=1,
has_bias=False, pad_mode='pad'))
self.features.append(nn.BatchNorm2d(output_channel))
self.features.append(Activation('hswish'))
self.features.append(GlobalAvgPooling(keep_dims=True))
self.features.append(nn.Conv2d(in_channels=output_channel,
out_channels=model_cfgs['cls_ch_expand'],
kernel_size=1, padding=0, stride=1,
has_bias=False, pad_mode='pad'))
self.features.append(Activation('hswish'))
if final_drop > 0:
self.features.append((nn.Dropout(final_drop)))
# make it nn.CellList
self.features = nn.SequentialCell(self.features)
self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'],
out_channels=num_classes,
kernel_size=1, has_bias=True, pad_mode='pad')
self.squeeze = P.Squeeze(axis=(2, 3))
self._initialize_weights()
def construct(self, x):
x = self.features(x)
x = self.output(x)
x = self.squeeze(x)
return x
def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1):
mid_planes = exp_ch
out_planes = out_channel
layer = ResUnit(self.inplanes, mid_planes, out_planes,
kernel_size, stride=stride, act_type=act_func, use_se=use_se)
self.inplanes = out_planes
return layer
def _initialize_weights(self):
"""
Initialize weights.
Args:
Returns:
None.
Examples:
>>> _initialize_weights()
"""
self.init_parameters_data()
for _, m in self.cells_and_names():
if isinstance(m, (nn.Conv2d)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_data(
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
elif isinstance(m, nn.BatchNorm2d):
m.gamma.set_data(
Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
m.beta.set_data(
Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
elif isinstance(m, nn.Dense):
m.weight.set_data(Tensor(np.random.normal(
0, 0.01, m.weight.data.shape).astype("float32")))
if m.bias is not None:
m.bias.set_data(
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
def mobilenet_v3(model_name, **kwargs):
"""
Constructs a MobileNet V2 model
"""
model_cfgs = {
"large": {
"cfg": [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hswish', 2],
[3, 200, 80, False, 'hswish', 1],
[3, 184, 80, False, 'hswish', 1],
[3, 184, 80, False, 'hswish', 1],
[3, 480, 112, True, 'hswish', 1],
[3, 672, 112, True, 'hswish', 1],
[5, 672, 160, True, 'hswish', 2],
[5, 960, 160, True, 'hswish', 1],
[5, 960, 160, True, 'hswish', 1]],
"cls_ch_squeeze": 960,
"cls_ch_expand": 1280,
},
"small": {
"cfg": [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hswish', 2],
[5, 240, 40, True, 'hswish', 1],
[5, 240, 40, True, 'hswish', 1],
[5, 120, 48, True, 'hswish', 1],
[5, 144, 48, True, 'hswish', 1],
[5, 288, 96, True, 'hswish', 2],
[5, 576, 96, True, 'hswish', 1],
[5, 576, 96, True, 'hswish', 1]],
"cls_ch_squeeze": 576,
"cls_ch_expand": 1280,
}
}
return MobileNetV3(model_cfgs[model_name], **kwargs)
mobilenet_v3_large = partial(mobilenet_v3, model_name="large")
mobilenet_v3_small = partial(mobilenet_v3, model_name="small")

View File

@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Monitor loss and time"""
import time
import numpy as np
from mindspore import Tensor
from mindspore.train.callback import Callback
class Monitor(Callback):
"""
Monitor loss and time.
Args:
lr_init (numpy array): train lr
Returns:
None
Examples:
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
"""
def __init__(self, lr_init=None):
super(Monitor, self).__init__()
self.lr_init = lr_init
self.lr_init_len = len(lr_init)
def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)))
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
"""step end"""
cb_params = run_context.original_args()
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
cb_params.cur_epoch_num -
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))

View File

@ -0,0 +1,139 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import os
import ast
import argparse
from mindspore import context
from mindspore import Tensor
from mindspore.nn import RMSProp
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from mindspore.communication.management import init
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import config
from src.loss import CrossEntropyWithLabelSmooth
from src.monitor import Monitor
from src.mobilenetv3 import mobilenet_v3_small
set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
# modelarts parameter
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
# Ascend parameter
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
parser.add_argument('--device_id', type=int, default=0, help='Device id')
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run mode')
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
if __name__ == '__main__':
# init distributed
if args_opt.run_modelarts:
import moxing as mox
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
context.set_context(device_id=device_id)
local_data_url = '/cache/data'
local_train_url = '/cache/ckpt'
if device_num > 1:
init()
context.set_auto_parallel_context(device_num=device_num, parallel_mode='data_parallel', gradients_mean=True)
local_data_url = os.path.join(local_data_url, str(device_id))
mox.file.copy_parallel(args_opt.data_url, local_data_url)
else:
if args_opt.run_distribute:
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
context.set_context(device_id=device_id)
init()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
context.set_context(device_id=args_opt.device_id)
device_num = 1
device_id = 0
# define net
net = mobilenet_v3_small(num_classes=config.num_classes, multiplier=1.)
# define loss
if config.label_smooth > 0:
loss = CrossEntropyWithLabelSmooth(
smooth_factor=config.label_smooth, num_classes=config.num_classes)
else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# define dataset
if args_opt.run_modelarts:
dataset = create_dataset(dataset_path=local_data_url,
do_train=True,
batch_size=config.batch_size,
device_num=device_num, rank=device_id)
else:
dataset = create_dataset(dataset_path=args_opt.dataset_path,
do_train=True,
batch_size=config.batch_size,
device_num=device_num, rank=device_id)
step_size = dataset.get_dataset_size()
# resume
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
# define optimizer
loss_scale = FixedLossScaleManager(
config.loss_scale, drop_overflow_update=False)
lr = Tensor(get_lr(global_step=0,
lr_init=0,
lr_end=0,
lr_max=config.lr,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size))
opt = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=config.weight_decay,
momentum=config.momentum, epsilon=0.001, loss_scale=config.loss_scale)
# define model
model = Model(net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, amp_level='O3')
cb = [Monitor(lr_init=lr.asnumpy())]
if config.save_checkpoint and (device_num == 1 or device_id == 0):
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
if args_opt.run_modelarts:
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=local_train_url, config=config_ck)
else:
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(device_id) + '/')
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=save_ckpt_path, config=config_ck)
cb += [ckpt_cb]
# begine train
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
if args_opt.run_modelarts and config.save_checkpoint and (device_num == 1 or device_id == 0):
mox.file.copy_parallel(local_train_url, args_opt.train_url)