forked from mindspore-Ecosystem/mindspore
!16264 mobilenetV3_small_x1_0 master
Merge pull request !16264 from Gogery/mo3m
commit 728224dd9e
@ -0,0 +1,185 @@
# Contents

- [Contents](#contents)
- [MobileNetV3 Description](#mobilenetv3-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Launch](#launch)
        - [Result](#result)
    - [Evaluation Process](#evaluation-process)
        - [Launch](#launch-1)
        - [Result](#result-1)
- [Model Description](#model-description)
    - [Training Performance](#training-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->

# MobileNetV3 Description

MobileNetV3 combines hardware-aware network architecture search (NAS) with the NetAdapt algorithm to tune the network for mobile phone CPUs, and is then further improved through novel architecture advances. (November 20, 2019)

[Paper](https://arxiv.org/pdf/1905.02244): Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for MobileNetV3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019.

# Model Architecture

The overall network architecture of MobileNetV3 is shown at the following link:

[Link](https://arxiv.org/pdf/1905.02244)

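The network code in this commit builds its blocks on two piecewise-linear activations, `hswish` and `hsigmoid`, both expressed through ReLU6. As a rough illustration only, the scalar sketch below reproduces those two formulas in plain Python. The functions `relu6`, `hsigmoid` and `hswish` here are stand-ins written for this example, not the MindSpore cells themselves.

```python
def relu6(x: float) -> float:
    """Clamp x to the range [0, 6]."""
    return min(max(x, 0.0), 6.0)


def hsigmoid(x: float) -> float:
    """Hard sigmoid: ReLU6(x + 3) / 6, a piecewise-linear stand-in for sigmoid."""
    return relu6(x + 3.0) / 6.0


def hswish(x: float) -> float:
    """Hard swish: x * ReLU6(x + 3) / 6."""
    return x * relu6(x + 3.0) / 6.0


if __name__ == "__main__":
    for value in (-4.0, -1.0, 0.0, 1.0, 4.0):
        print(f"x={value:+.1f}  hsigmoid={hsigmoid(value):.4f}  hswish={hswish(value):.4f}")
```

Both functions saturate outside the interval [-3, 3]: `hsigmoid` is 0 below it and 1 above it, so `hswish` is 0 below it and equal to its input above it.
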
# Dataset

Dataset used: [ImageNet](http://www.image-net.org/)

- Dataset size: 146 GB, 1330k color images in 1000 classes
    - Train: 140 GB, 1280k images
    - Test: 6 GB, 50k images
- Data format: RGB
    - Note: data is processed in src/dataset.py.

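The dataset code in `src/dataset.py` normalizes each RGB channel with a fixed mean and standard deviation, scaled to the 0-255 pixel range. A minimal sketch of that per-pixel arithmetic in plain Python follows. The helper name `normalize_pixel` is hypothetical, and the real pipeline applies the same constants through MindSpore dataset operators rather than a Python loop.

```python
# Channel statistics as written in src/dataset.py, scaled to 0-255 pixel values.
MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]


def normalize_pixel(rgb):
    """Return (value - mean) / std for each of the three channels."""
    return [(value - mean) / std for value, mean, std in zip(rgb, MEAN, STD)]


if __name__ == "__main__":
    print(normalize_pixel([0, 0, 0]))        # darkest pixel, all channels negative
    print(normalize_pixel([255, 255, 255]))  # brightest pixel, all channels positive
```
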
# Environment Requirements

- Hardware (Ascend)
    - Prepare the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

# Script Description

## Scripts and Sample Code

```text
├── MobileNetV3
  ├── README_CN.md                 # MobileNetV3 description
  ├── scripts
  │   ├──run_standalone_train.sh   # shell script for single-device training
  │   ├──run_distribute_train.sh   # shell script for 8-device training
  │   └──run_eval.sh               # shell script for evaluation
  ├── src
  │   ├──config.py                 # parameter configuration
  │   ├──dataset.py                # dataset creation
  │   ├──loss.py                   # loss function
  │   ├──lr_generator.py           # learning rate schedule
  │   ├──mobilenetv3.py            # MobileNetV3 architecture
  │   └──monitor.py                # monitors network loss and other data
  ├── eval.py                      # evaluation script
  ├── export.py                    # model format conversion script
  └── train.py                     # training script
```

## Script Parameters

The parameters used for training and evaluation can be set in config.py:

```python
'num_classes': 1000,                       # number of classes in the dataset
'image_height': 224,                       # input image height
'image_width': 224,                        # input image width
'batch_size': 256,                         # batch size
'epoch_size': 370,                         # number of training epochs
'warmup_epochs': 4,                        # number of warmup epochs
'lr': 0.05,                                # learning rate
'momentum': 0.9,                           # momentum
'weight_decay': 4e-5,                      # weight decay
'label_smooth': 0.1,                       # label smoothing factor
'loss_scale': 1024,                        # loss scale
'save_checkpoint': True,                   # whether to save ckpt files
'save_checkpoint_epochs': 1,               # save a ckpt file every this many epochs
'keep_checkpoint_max': 5,                  # maximum number of ckpt files to keep
'save_checkpoint_path': "./checkpoint",    # path for saving ckpt files
'export_file': "mobilenetv3_small",        # export file name
'export_format': "MINDIR",                 # export format
```

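The `lr` and `warmup_epochs` entries feed a per-step schedule built in `src/lr_generator.py`: a linear warmup followed by cosine decay. The sketch below restates that schedule in plain Python so its shape is easy to inspect. The function name `warmup_cosine_lr`, the start and end rates of 0.0, and the small step counts in the usage lines are assumptions chosen for illustration, not values taken from `train.py`.

```python
import math


def warmup_cosine_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """Return one learning rate per training step."""
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    schedule = []
    for step in range(total_steps):
        if step < warmup_steps:
            # Linear ramp from lr_init towards lr_max.
            lr = lr_init + (lr_max - lr_init) * step / warmup_steps
        else:
            # Cosine decay from lr_max towards lr_end.
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            lr = lr_end + (lr_max - lr_end) * (1.0 + math.cos(math.pi * progress)) / 2.0
        schedule.append(max(lr, 0.0))
    return schedule


if __name__ == "__main__":
    # Illustrative sizes only: 2 warmup epochs, 10 epochs, 5 steps per epoch.
    lrs = warmup_cosine_lr(0.0, 0.0, 0.05, warmup_epochs=2, total_epochs=10, steps_per_epoch=5)
    print(len(lrs), lrs[0], max(lrs), lrs[-1])
```

The peak rate is reached on the first step after warmup, and the rate then falls smoothly towards `lr_end` over the remaining steps.
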
## Training Process

### Launch

You can start training with python or shell scripts.

```shell
# training examples
python:
    Ascend single-device training: python train.py --device_id [DEVICE_ID] --dataset_path [DATA_DIR]

shell:
    Ascend single-device training: sh ./scripts/run_standalone_train.sh [DEVICE_ID] [DATA_DIR]
    Ascend 8-device parallel training:
        cd ./scripts/
        sh ./run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR]
```

### Result

Checkpoint files are stored under `./checkpoint`, and the training log is written to `log.txt`. A sample of the training log:

```shell
epoch 1: epoch time: 553262.126, per step time: 518.521, avg loss: 5.270
epoch 2: epoch time: 151033.049, per step time: 141.549, avg loss: 4.529
epoch 3: epoch time: 150605.300, per step time: 141.148, avg loss: 4.101
epoch 4: epoch time: 150638.805, per step time: 141.180, avg loss: 4.014
epoch 5: epoch time: 150594.088, per step time: 141.138, avg loss: 3.607
```

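Each training log line reports both the epoch time and the per-step time, so their ratio gives the number of steps per epoch. A small sketch, assuming the log format of the sample lines in this section; the regular expression and the helper name `steps_per_epoch` are illustrative and not part of the repository.

```python
import re

LOG_LINES = [
    "epoch 1: epoch time: 553262.126, per step time: 518.521, avg loss: 5.270",
    "epoch 2: epoch time: 151033.049, per step time: 141.549, avg loss: 4.529",
    "epoch 3: epoch time: 150605.300, per step time: 141.148, avg loss: 4.101",
]

PATTERN = re.compile(
    r"epoch (\d+): epoch time: ([\d.]+), per step time: ([\d.]+), avg loss: ([\d.]+)"
)


def steps_per_epoch(line):
    """Return (epoch, steps), where steps = epoch time / per-step time, rounded."""
    match = PATTERN.match(line)
    if match is None:
        raise ValueError(f"unrecognized log line: {line!r}")
    epoch, epoch_time, step_time, _loss = match.groups()
    return int(epoch), round(float(epoch_time) / float(step_time))


if __name__ == "__main__":
    for line in LOG_LINES:
        print(steps_per_epoch(line))
```

For these sample lines the ratio rounds to 1067 steps per epoch, which agrees with the step count in the checkpoint name `mobilenetV3-370_1067.ckpt` that `eval.py` loads when running on ModelArts.
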
## Evaluation Process

### Launch

You can run evaluation with python or shell scripts.

```shell
# evaluation examples
python:
    python eval.py --device_id [DEVICE_ID] --dataset_path [DATA_DIR] --checkpoint_path [PATH_CHECKPOINT]

shell:
    sh ./scripts/run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]
```

> Checkpoint files are generated during training.

### Result

When evaluation is run through `run_eval.sh`, the result is written to `eval.log`.

```shell
result: {'Loss': 2.3101649037352554, 'Top_1_Acc': 0.6746546546546547, 'Top_5_Acc': 0.8722122122122122} ckpt= ./checkpoint/model_0/mobilenetV3-370_625.ckpt
```

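`Top_1_Acc` and `Top_5_Acc` report how often the true label is the highest-scoring class, or among the five highest-scoring classes. The sketch below computes the same kind of metric on made-up scores in plain Python. The function name `top_k_accuracy` and the toy data are hypothetical; the evaluation script itself relies on MindSpore's built-in metric classes.

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores in this row.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)


if __name__ == "__main__":
    # Three samples, four classes (toy values).
    scores = [
        [0.10, 0.60, 0.20, 0.10],
        [0.50, 0.10, 0.30, 0.10],
        [0.30, 0.15, 0.05, 0.50],
    ]
    labels = [1, 2, 0]
    print(top_k_accuracy(scores, labels, k=1))
    print(top_k_accuracy(scores, labels, k=2))
```

In the toy data only the first sample is classified correctly at top-1, while all three true labels fall within the top two scores, which is why a top-k figure is never lower than the top-1 figure.
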
# Model Description

## Training Performance

| Parameter                  | Ascend                                |
| -------------------------- | ------------------------------------- |
| Model name                 | mobilenetV3                           |
| Model version              | small                                 |
| Environment                | HUAWEI CLOUD ModelArts                |
| Upload date                | 2021-3-25                             |
| Dataset                    | imagenet                              |
| Training parameters        | src/config.py                         |
| Optimizer                  | RMSProp                               |
| Loss function              | CrossEntropyWithLabelSmooth           |
| Final loss                 | 2.31                                  |
| Accuracy (8p)              | Top1[67.5%], Top5[87.2%]              |
| Total training time (8p)   | 16.4h                                 |
| Total evaluation time      | 1min                                  |
| Parameters (M)             | 36M                                   |
| Script                     | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/mobilenetV3_small_x1_0) |

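`CrossEntropyWithLabelSmooth` softens the one-hot target before computing softmax cross-entropy: the true class gets `1 - smooth_factor` and every other class gets `smooth_factor / (num_classes - 1)`, as in `src/loss.py`. A plain-Python sketch of that computation for a single sample follows. The helper names and the four-class logits are made up for illustration, and the real loss runs on MindSpore tensors with a mean reduction over the batch.

```python
import math


def smooth_one_hot(label, num_classes, smooth_factor):
    """Soft target: 1 - smooth_factor on the true class, the rest shared equally."""
    off_value = smooth_factor / (num_classes - 1)
    target = [off_value] * num_classes
    target[label] = 1.0 - smooth_factor
    return target


def softmax_cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and a (soft) target distribution."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return -sum(t * (v - log_sum) for t, v in zip(target, logits))


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0, 0.1]
    target = smooth_one_hot(label=0, num_classes=4, smooth_factor=0.1)
    print(target)
    print(softmax_cross_entropy(logits, target))
```

With the smoothed target, a confident correct prediction is penalized slightly more than with a hard one-hot target, which discourages the network from driving its logits to extremes.
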
# Description of Random Situation

Random seeds are set in the `dataset.py` and `train.py` scripts.

# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
eval.
"""
import os
import argparse
import ast
from mindspore import context
from mindspore import nn
from mindspore.train.model import Model
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.config import config
from src.loss import CrossEntropyWithLabelSmooth
from src.mobilenetv3 import mobilenet_v3_small

parser = argparse.ArgumentParser(description='Image classification')
# modelarts parameter
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
# Ascend parameter
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--device_id', type=int, default=0, help='Device id')

parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run on ModelArts')
args_opt = parser.parse_args()

set_seed(1)

if __name__ == '__main__':
    if args_opt.run_modelarts:
        import moxing as mox

        device_id = int(os.getenv('DEVICE_ID'))
        device_num = int(os.getenv('RANK_SIZE'))
        context.set_context(device_id=device_id)
        # copy the dataset and checkpoints from remote storage to the local cache
        local_data_url = '/cache/data/'
        local_train_url = '/cache/ckpt/'
        mox.file.copy_parallel(args_opt.data_url, local_data_url)
        mox.file.copy_parallel(args_opt.train_url, local_train_url)
    else:
        context.set_context(device_id=args_opt.device_id)

    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
    net = mobilenet_v3_small(num_classes=config.num_classes, multiplier=1.)

    if args_opt.run_modelarts:
        dataset = create_dataset(dataset_path=local_data_url,
                                 do_train=False,
                                 batch_size=config.batch_size)
        ckpt_path = local_train_url + 'mobilenetV3-370_1067.ckpt'
        param_dict = load_checkpoint(ckpt_path)
    else:
        dataset = create_dataset(dataset_path=args_opt.dataset_path,
                                 do_train=False,
                                 batch_size=config.batch_size)
        param_dict = load_checkpoint(args_opt.checkpoint_path)

    step_size = dataset.get_dataset_size()

    load_param_into_net(net, param_dict)
    net.set_train(False)

    # define loss
    loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth, num_classes=config.num_classes)

    # define metrics and model
    eval_metrics = {'Loss': nn.Loss(),
                    'Top_1_Acc': nn.Top1CategoricalAccuracy(),
                    'Top_5_Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)

    # eval model
    res = model.eval(dataset)
    if args_opt.run_modelarts:
        # report the checkpoint that was loaded, not the dataset directory
        print("result:", res, "ckpt=", ckpt_path)
    else:
        print("result:", res, "ckpt=", args_opt.checkpoint_path)
@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
mobilenetv3_small export.
"""
import argparse
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config
from src.mobilenetv3 import mobilenet_v3_small


parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, required=True, help='Checkpoint file path')
args_opt = parser.parse_args()

if __name__ == '__main__':
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    net = mobilenet_v3_small(num_classes=config.num_classes, multiplier=1.)

    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    # a random NCHW input with batch size 1 fixes the shape of the exported graph
    input_shp = [1, 3, config.image_height, config.image_width]
    input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
    export(net, input_array, file_name=config.export_file, file_format=config.export_format)
@ -0,0 +1,61 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 2 ]; then
  echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
  exit 1
fi

# resolve a relative path against the current working directory
get_real_path() {
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    echo "$(realpath -m "$PWD/$1")"
  fi
}

PATH1=$(get_real_path "$1")
PATH2=$(get_real_path "$2")

if [ ! -f "$PATH1" ]; then
  echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
  exit 1
fi

if [ ! -d "$PATH2" ]; then
  echo "error: DATASET_PATH=$PATH2 is not a directory"
  exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

# launch one background training process per device, each in its own directory
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
  export DEVICE_ID=$i
  export RANK_ID=$i
  rm -rf ./train_parallel$i
  mkdir ./train_parallel$i
  cp ../*.py ./train_parallel$i
  cp *.sh ./train_parallel$i
  cp -r ../src ./train_parallel$i
  cd ./train_parallel$i || exit
  echo "start training for rank $RANK_ID, device $DEVICE_ID"
  env >env.log
  python train.py --dataset_path="$PATH2" --run_distribute=True > log.txt 2>&1 &
  cd ..
done
@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3

# eval.py does not define --device_target (it always targets Ascend), so it is not passed here
python ./eval.py \
    --device_id=$DEVICE_ID \
    --checkpoint_path=$PATH_CHECKPOINT \
    --dataset_path=$DATA_DIR > eval.log 2>&1 &
@ -0,0 +1,22 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=$1
DATA_DIR=$2
python ./train.py \
    --device_id=$DEVICE_ID \
    --dataset_path=$DATA_DIR > log.txt 2>&1 &
@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

config = ed({
    "num_classes": 1000,
    "image_height": 224,
    "image_width": 224,
    "batch_size": 150,
    "epoch_size": 370,
    "warmup_epochs": 4,
    "lr": 0.05,
    "momentum": 0.9,
    "weight_decay": 4e-5,
    "label_smooth": 0.1,
    "loss_scale": 1024,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 5,
    "save_checkpoint_path": "./checkpoint",
    "export_file": "mobilenetv3_small",
    "export_format": "MINDIR",
})
@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2


def create_dataset(dataset_path, do_train, batch_size=16, device_num=1, rank=0):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        batch_size(int): the batch size of dataset. Default: 16.
        device_num (int): Number of shards that the dataset should be divided into (default=1).
        rank (int): The shard ID within num_shards (default=0).

    Returns:
        dataset
    """
    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=64, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=64, shuffle=True,
                                   num_shards=device_num, shard_id=rank)
    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(224),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(224)
        ]
    # normalization constants are the per-channel mean and std scaled to 0-255
    trans += [
        C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255]),
        C.HWC2CHW(),
        C2.TypeCast(mstype.float32)
    ]

    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropyWithLabelSmooth(_Loss):
    """CrossEntropyWithLabelSmooth"""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropyWithLabelSmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        # the true class gets 1 - smooth_factor, the others share smooth_factor equally
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
@ -0,0 +1,54 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np


def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """
    generate learning rate array

    Args:
       global_step(int): total steps of the training
       lr_init(float): init learning rate
       lr_end(float): end learning rate
       lr_max(float): max learning rate
       warmup_epochs(int): number of warmup epochs
       total_epochs(int): total epoch of training
       steps_per_epoch(int): steps of one epoch

    Returns:
       np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    for i in range(total_steps):
        if i < warmup_steps:
            # linear warmup from lr_init to lr_max
            lr = lr_init + (lr_max - lr_init) * i / warmup_steps
        else:
            # cosine decay from lr_max to lr_end
            lr = lr_end + \
                 (lr_max - lr_end) * \
                 (1. + math.cos(math.pi * (i - warmup_steps) / (total_steps - warmup_steps))) / 2.
        if lr < 0.0:
            lr = 0.0
        lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
@ -0,0 +1,406 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""MobileNetV3 model define"""
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
__all__ = ['mobilenet_v3_large',
|
||||
'mobilenet_v3_small']
|
||||
|
||||
|
||||
class hswish(nn.Cell):
|
||||
"""hswish"""
|
||||
def construct(self, x):
|
||||
out = x * nn.ReLU6()(x + 3) / 6
|
||||
return out
|
||||
|
||||
|
||||
class hsigmoid(nn.Cell):
|
||||
"""hsigmoid"""
|
||||
def construct(self, x):
|
||||
out = nn.ReLU6()(x + 3) / 6
|
||||
return out
|
||||
|
||||
def _make_divisible(x, divisor=8):
|
||||
"""_make_divisible"""
|
||||
return int(np.ceil(x * 1. / divisor) * divisor)
|
||||
|
||||
|
||||
class Activation(nn.Cell):
|
||||
"""
|
||||
Activation definition.
|
||||
|
||||
Args:
|
||||
act_func(string): activation name.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, act_func):
|
||||
super(Activation, self).__init__()
|
||||
if act_func == 'relu':
|
||||
self.act = nn.ReLU()
|
||||
elif act_func == 'relu6':
|
||||
self.act = nn.ReLU6()
|
||||
elif act_func in ('hsigmoid', 'hard_sigmoid'):
|
||||
self.act = hsigmoid()
|
||||
elif act_func in ('hswish', 'hard_swish'):
|
||||
self.act = hswish()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def construct(self, x):
|
||||
return self.act(x)
|
||||
|
||||
|
||||
class GlobalAvgPooling(nn.Cell):
|
||||
"""
|
||||
Global avg pooling definition.
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> GlobalAvgPooling()
|
||||
"""
|
||||
|
||||
def __init__(self, keep_dims=False):
|
||||
super(GlobalAvgPooling, self).__init__()
|
||||
self.mean = P.ReduceMean(keep_dims=keep_dims)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.mean(x, (2, 3))
|
||||
return x
|
||||
|
||||
|
||||
class SE(nn.Cell):
|
||||
"""
|
||||
SE warpper definition.
|
||||
|
||||
Args:
|
||||
num_out (int): Numbers of output channels.
|
||||
ratio (int): middle output ratio.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> SE(4)
|
||||
"""
|
||||
|
||||
def __init__(self, num_out, ratio=4):
|
||||
super(SE, self).__init__()
|
||||
num_mid = _make_divisible(num_out // ratio)
|
||||
self.pool = GlobalAvgPooling(keep_dims=True)
|
||||
self.conv1 = nn.Conv2d(in_channels=num_out, out_channels=num_mid,
|
||||
kernel_size=1, has_bias=True, pad_mode='pad')
|
||||
self.act1 = Activation('relu')
|
||||
self.conv2 = nn.Conv2d(in_channels=num_mid, out_channels=num_out,
|
||||
kernel_size=1, has_bias=True, pad_mode='pad')
|
||||
self.act2 = Activation('hsigmoid')
|
||||
self.mul = P.Mul()
|
||||
|
||||
def construct(self, x):
|
||||
out = self.pool(x)
|
||||
out = self.conv1(out)
|
||||
out = self.act1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.act2(out)
|
||||
out = self.mul(x, out)
|
||||
return out
|
||||
|
||||
|
||||
class Unit(nn.Cell):
|
||||
"""
|
||||
Unit warpper definition.
|
||||
|
||||
Args:
|
||||
num_in (int): Input channel.
|
||||
num_out (int): Output channel.
|
||||
kernel_size (int): Input kernel size.
|
||||
stride (int): Stride size.
|
||||
padding (int): Padding number.
|
||||
num_groups (int): Output num group.
|
||||
use_act (bool): Used activation or not.
|
||||
act_type (string): Activation type.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> Unit(3, 3)
|
||||
"""
|
||||
|
||||
def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1,
|
||||
use_act=True, act_type='relu'):
|
||||
super(Unit, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=num_in,
|
||||
out_channels=num_out,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
group=num_groups,
|
||||
has_bias=False,
|
||||
pad_mode='pad')
|
||||
self.bn = nn.BatchNorm2d(num_out)
|
||||
self.use_act = use_act
|
||||
self.act = Activation(act_type) if use_act else None
|
||||
|
||||
def construct(self, x):
|
||||
out = self.conv(x)
|
||||
out = self.bn(out)
|
||||
if self.use_act:
|
||||
out = self.act(out)
|
||||
return out
|
||||
|
||||
|
||||
class ResUnit(nn.Cell):
|
||||
"""
|
||||
ResUnit warpper definition.
|
||||
|
||||
Args:
|
||||
num_in (int): Input channel.
|
||||
num_mid (int): Middle channel.
|
||||
num_out (int): Output channel.
|
||||
kernel_size (int): Input kernel size.
|
||||
stride (int): Stride size.
|
||||
act_type (str): Activation type.
|
||||
use_se (bool): Use SE warpper or not.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> ResUnit(16, 3, 1, 1)
|
||||
"""
|
||||
def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False):
|
||||
super(ResUnit, self).__init__()
|
||||
self.use_se = use_se
|
||||
self.first_conv = (num_out != num_mid)
|
||||
self.use_short_cut_conv = True
|
||||
|
||||
if self.first_conv:
|
||||
self.expand = Unit(num_in, num_mid, kernel_size=1,
|
||||
stride=1, padding=0, act_type=act_type)
|
||||
else:
|
||||
self.expand = None
|
||||
self.conv1 = Unit(num_mid, num_mid, kernel_size=kernel_size, stride=stride,
|
||||
padding=self._get_pad(kernel_size), act_type=act_type, num_groups=num_mid)
|
||||
if use_se:
|
||||
self.se = SE(num_mid)
|
||||
self.conv2 = Unit(num_mid, num_out, kernel_size=1, stride=1,
|
||||
padding=0, act_type=act_type, use_act=False)
|
||||
if num_in != num_out or stride != 1:
|
||||
self.use_short_cut_conv = False
|
||||
self.add = P.Add() if self.use_short_cut_conv else None
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
if self.first_conv:
|
||||
out = self.expand(x)
|
||||
else:
|
||||
out = x
|
||||
out = self.conv1(out)
|
||||
if self.use_se:
|
||||
out = self.se(out)
|
||||
out = self.conv2(out)
|
||||
if self.use_short_cut_conv:
|
||||
out = self.add(x, out)
|
||||
return out
|
||||
|
||||
def _get_pad(self, kernel_size):
|
||||
"""set the padding number"""
|
||||
pad = 0
|
||||
if kernel_size == 1:
|
||||
pad = 0
|
||||
elif kernel_size == 3:
|
||||
pad = 1
|
||||
elif kernel_size == 5:
|
||||
pad = 2
|
||||
elif kernel_size == 7:
|
||||
pad = 3
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return pad
|
||||
|
||||
|
||||
class MobileNetV3(nn.Cell):
|
||||
"""
|
||||
MobileNetV3 architecture.
|
||||
|
||||
Args:
|
||||
model_cfgs (Cell): number of classes.
|
||||
num_classes (int): Output number classes.
|
||||
multiplier (int): Channels multiplier for round to 8/16 and others. Default is 1.
|
||||
final_drop (float): Dropout number.
|
||||
round_nearest (list): Channel round to . Default is 8.
|
||||
Returns:
|
||||
Tensor, output tensor.
|
||||
|
||||
Examples:
|
||||
>>> MobileNetV3(num_classes=1000)
|
||||
"""
|
||||
|
||||
def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8):
|
||||
super(MobileNetV3, self).__init__()
|
||||
self.cfgs = model_cfgs['cfg']
|
||||
self.inplanes = 16
|
||||
self.features = []
|
||||
first_conv_in_channel = 3
|
||||
first_conv_out_channel = _make_divisible(multiplier * self.inplanes)
|
||||
|
||||
self.features.append(nn.Conv2d(in_channels=first_conv_in_channel,
|
||||
out_channels=first_conv_out_channel,
|
||||
kernel_size=3, padding=1, stride=2,
|
||||
has_bias=False, pad_mode='pad'))
|
||||
self.features.append(nn.BatchNorm2d(first_conv_out_channel))
|
||||
self.features.append(Activation('hswish'))
|
||||
for layer_cfg in self.cfgs:
|
||||
self.features.append(self._make_layer(kernel_size=layer_cfg[0],
|
||||
exp_ch=_make_divisible(multiplier * layer_cfg[1]),
|
||||
out_channel=_make_divisible(multiplier * layer_cfg[2]),
|
||||
use_se=layer_cfg[3],
|
||||
act_func=layer_cfg[4],
|
||||
stride=layer_cfg[5]))
|
||||
output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"])
|
||||
self.features.append(nn.Conv2d(in_channels=_make_divisible(multiplier * self.cfgs[-1][2]),
|
||||
out_channels=output_channel,
|
||||
kernel_size=1, padding=0, stride=1,
|
||||
has_bias=False, pad_mode='pad'))
|
||||
self.features.append(nn.BatchNorm2d(output_channel))
|
||||
self.features.append(Activation('hswish'))
|
||||
self.features.append(GlobalAvgPooling(keep_dims=True))
|
||||
self.features.append(nn.Conv2d(in_channels=output_channel,
|
||||
out_channels=model_cfgs['cls_ch_expand'],
|
||||
kernel_size=1, padding=0, stride=1,
|
||||
has_bias=False, pad_mode='pad'))
|
||||
self.features.append(Activation('hswish'))
|
||||
if final_drop > 0:
|
||||
self.features.append(nn.Dropout(final_drop))
|
||||
|
||||
# wrap the layer list in nn.SequentialCell
|
||||
self.features = nn.SequentialCell(self.features)
|
||||
self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'],
|
||||
out_channels=num_classes,
|
||||
kernel_size=1, has_bias=True, pad_mode='pad')
|
||||
self.squeeze = P.Squeeze(axis=(2, 3))
|
||||
|
||||
self._initialize_weights()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.features(x)
|
||||
x = self.output(x)
|
||||
x = self.squeeze(x)
|
||||
return x
|
||||
|
||||
def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1):
|
||||
mid_planes = exp_ch
|
||||
out_planes = out_channel
|
||||
|
||||
layer = ResUnit(self.inplanes, mid_planes, out_planes,
|
||||
kernel_size, stride=stride, act_type=act_func, use_se=use_se)
|
||||
self.inplanes = out_planes
|
||||
return layer
|
||||
|
||||
def _initialize_weights(self):
|
||||
"""
|
||||
Initialize weights.
|
||||
|
||||
Args:
|
||||
|
||||
Returns:
|
||||
None.
|
||||
|
||||
Examples:
|
||||
>>> _initialize_weights()
|
||||
"""
|
||||
self.init_parameters_data()
|
||||
for _, m in self.cells_and_names():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
|
||||
m.weight.data.shape).astype("float32")))
|
||||
if m.bias is not None:
|
||||
m.bias.set_data(
|
||||
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.gamma.set_data(
|
||||
Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
|
||||
m.beta.set_data(
|
||||
Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
|
||||
elif isinstance(m, nn.Dense):
|
||||
m.weight.set_data(Tensor(np.random.normal(
|
||||
0, 0.01, m.weight.data.shape).astype("float32")))
|
||||
if m.bias is not None:
|
||||
m.bias.set_data(
|
||||
Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
|
||||
|
||||
|
||||
def mobilenet_v3(model_name, **kwargs):
|
||||
"""
|
||||
Constructs a MobileNet V3 model ("large" or "small", selected by model_name).
|
||||
"""
|
||||
model_cfgs = {
|
||||
"large": {
|
||||
"cfg": [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, False, 'relu', 1],
|
||||
[3, 64, 24, False, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 1],
|
||||
[5, 72, 40, True, 'relu', 2],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[3, 240, 80, False, 'hswish', 2],
|
||||
[3, 200, 80, False, 'hswish', 1],
|
||||
[3, 184, 80, False, 'hswish', 1],
|
||||
[3, 184, 80, False, 'hswish', 1],
|
||||
[3, 480, 112, True, 'hswish', 1],
|
||||
[3, 672, 112, True, 'hswish', 1],
|
||||
[5, 672, 160, True, 'hswish', 2],
|
||||
[5, 960, 160, True, 'hswish', 1],
|
||||
[5, 960, 160, True, 'hswish', 1]],
|
||||
"cls_ch_squeeze": 960,
|
||||
"cls_ch_expand": 1280,
|
||||
},
|
||||
"small": {
|
||||
"cfg": [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, True, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 2],
|
||||
[3, 88, 24, False, 'relu', 1],
|
||||
[5, 96, 40, True, 'hswish', 2],
|
||||
[5, 240, 40, True, 'hswish', 1],
|
||||
[5, 240, 40, True, 'hswish', 1],
|
||||
[5, 120, 48, True, 'hswish', 1],
|
||||
[5, 144, 48, True, 'hswish', 1],
|
||||
[5, 288, 96, True, 'hswish', 2],
|
||||
[5, 576, 96, True, 'hswish', 1],
|
||||
[5, 576, 96, True, 'hswish', 1]],
|
||||
"cls_ch_squeeze": 576,
|
||||
"cls_ch_expand": 1280,
|
||||
}
|
||||
}
|
||||
return MobileNetV3(model_cfgs[model_name], **kwargs)
|
||||
|
||||
|
||||
mobilenet_v3_large = partial(mobilenet_v3, model_name="large")
|
||||
mobilenet_v3_small = partial(mobilenet_v3, model_name="small")
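
To sanity-check the `small` configuration above, the following sketch traces the feature-map size through the stem and each block for a 224×224 input. It assumes that each block's depthwise convolution applies the block stride and uses the padding returned by `_get_pad`, which equals `(k - 1) // 2` for the supported kernel sizes; the part of `ResUnit` that would confirm this is not shown here. `get_pad`, `conv_out_size` and `trace_resolution` are hypothetical helper names, not part of the commit.

```python
def get_pad(kernel_size):
    """Same mapping as _get_pad: 1 -> 0, 3 -> 1, 5 -> 2, 7 -> 3."""
    if kernel_size not in (1, 3, 5, 7):
        raise NotImplementedError
    return (kernel_size - 1) // 2


def conv_out_size(size, kernel_size, stride, pad):
    """Standard convolution output-size formula for one spatial dimension."""
    return (size + 2 * pad - kernel_size) // stride + 1


# (kernel, stride) pairs copied from the rows of the "small" cfg.
SMALL_KERNEL_STRIDE = [(3, 2), (3, 2), (3, 1), (5, 2), (5, 1), (5, 1),
                       (5, 1), (5, 1), (5, 2), (5, 1), (5, 1)]


def trace_resolution(input_size=224):
    """Return the spatial size after the stem and after every block."""
    sizes = [conv_out_size(input_size, 3, 2, 1)]  # stem: 3x3, stride 2, pad 1
    for kernel_size, stride in SMALL_KERNEL_STRIDE:
        pad = get_pad(kernel_size)
        sizes.append(conv_out_size(sizes[-1], kernel_size, stride, pad))
    return sizes


sizes = trace_resolution()
print(sizes)
```

Under these assumptions the stem and the four stride-2 blocks give an overall stride of 32, so the feature map entering `GlobalAvgPooling` is 7×7 for a 224×224 input.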
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Monitor loss and time"""
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.callback import Callback
|
||||
class Monitor(Callback):
|
||||
"""
|
||||
Monitor loss and time.
|
||||
|
||||
Args:
|
||||
lr_init (numpy array): learning rate for each training step, indexed by global step.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> Monitor(lr_init=Tensor([0.05]*100).asnumpy())
|
||||
"""
|
||||
|
||||
def __init__(self, lr_init=None):
|
||||
super(Monitor, self).__init__()
|
||||
self.lr_init = lr_init
|
||||
self.lr_init_len = len(lr_init)
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.losses = []
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / cb_params.batch_num
|
||||
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
|
||||
per_step_mseconds,
|
||||
np.mean(self.losses)))
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end"""
|
||||
cb_params = run_context.original_args()
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
step_loss = cb_params.net_outputs
|
||||
|
||||
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
|
||||
step_loss = step_loss[0]
|
||||
if isinstance(step_loss, Tensor):
|
||||
step_loss = np.mean(step_loss.asnumpy())
|
||||
|
||||
self.losses.append(step_loss)
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
|
||||
|
||||
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
|
||||
cb_params.cur_epoch_num -
|
||||
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
|
||||
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
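
The index arithmetic in `step_end` can be checked without MindSpore. The sketch below mirrors it in plain Python: the 1-based global step is mapped to a 0-based position within the epoch, and the learning rate is read from a per-step schedule. `step_summary` and the toy values are hypothetical and serve only as illustration.

```python
from statistics import mean


def step_summary(cur_step_num, batch_num, lr_schedule, losses):
    """Mirror the bookkeeping in Monitor.step_end for one step.

    cur_step_num is the 1-based global step, as in cb_params.cur_step_num.
    """
    cur_step_in_epoch = (cur_step_num - 1) % batch_num  # 0-based position in the epoch
    lr = lr_schedule[cur_step_num - 1]                   # one lr value per global step
    return cur_step_in_epoch, mean(losses), lr


# Toy values (not from the source): 2 epochs of 3 steps, linearly decaying lr.
lr_schedule = [0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
summary = step_summary(cur_step_num=5, batch_num=3, lr_schedule=lr_schedule,
                       losses=[2.0, 1.0])
print(summary)
```

Because the schedule is indexed by global step, it needs at least `epoch_size * batch_num` entries; a shorter array raises `IndexError` partway through training.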
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_imagenet."""
|
||||
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
|
||||
from mindspore.nn import RMSProp
|
||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.communication.management import init
|
||||
|
||||
from src.dataset import create_dataset
|
||||
from src.lr_generator import get_lr
|
||||
from src.config import config
|
||||
from src.loss import CrossEntropyWithLabelSmooth
|
||||
from src.monitor import Monitor
|
||||
from src.mobilenetv3 import mobilenet_v3_small
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
# modelarts parameter
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
# Ascend parameter
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='Device id')
|
||||
|
||||
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run mode')
|
||||
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
# init distributed
|
||||
if args_opt.run_modelarts:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id)
|
||||
local_data_url = '/cache/data'
|
||||
local_train_url = '/cache/ckpt'
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode='data_parallel', gradients_mean=True)
|
||||
local_data_url = os.path.join(local_data_url, str(device_id))
|
||||
mox.file.copy_parallel(args_opt.data_url, local_data_url)
|
||||
else:
|
||||
if args_opt.run_distribute:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id)
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
device_num = 1
|
||||
device_id = 0
|
||||
# define net
|
||||
net = mobilenet_v3_small(num_classes=config.num_classes, multiplier=1.)
|
||||
# define loss
|
||||
if config.label_smooth > 0:
|
||||
loss = CrossEntropyWithLabelSmooth(
|
||||
smooth_factor=config.label_smooth, num_classes=config.num_classes)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
# define dataset
|
||||
if args_opt.run_modelarts:
|
||||
dataset = create_dataset(dataset_path=local_data_url,
|
||||
do_train=True,
|
||||
batch_size=config.batch_size,
|
||||
device_num=device_num, rank=device_id)
|
||||
else:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=True,
|
||||
batch_size=config.batch_size,
|
||||
device_num=device_num, rank=device_id)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# resume
|
||||
if args_opt.pre_trained:
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
load_param_into_net(net, param_dict)
|
||||
# define optimizer
|
||||
loss_scale = FixedLossScaleManager(
|
||||
config.loss_scale, drop_overflow_update=False)
|
||||
lr = Tensor(get_lr(global_step=0,
|
||||
lr_init=0,
|
||||
lr_end=0,
|
||||
lr_max=config.lr,
|
||||
warmup_epochs=config.warmup_epochs,
|
||||
total_epochs=config.epoch_size,
|
||||
steps_per_epoch=step_size))
|
||||
opt = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=config.weight_decay,
|
||||
momentum=config.momentum, epsilon=0.001, loss_scale=config.loss_scale)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, optimizer=opt,
|
||||
loss_scale_manager=loss_scale, amp_level='O3')
|
||||
|
||||
cb = [Monitor(lr_init=lr.asnumpy())]
|
||||
|
||||
if config.save_checkpoint and (device_num == 1 or device_id == 0):
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
if args_opt.run_modelarts:
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=local_train_url, config=config_ck)
|
||||
else:
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(device_id) + '/')
|
||||
ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=save_ckpt_path, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
# begin training
|
||||
model.train(config.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
if args_opt.run_modelarts and config.save_checkpoint and (device_num == 1 or device_id == 0):
|
||||
mox.file.copy_parallel(local_train_url, args_opt.train_url)
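
The script parses its boolean flags with `type=ast.literal_eval` rather than `type=bool`. A minimal sketch of why, using only the standard library and two of the arguments defined above: `bool('False')` is `True` because any non-empty string is truthy, whereas `ast.literal_eval('False')` yields the Python literal `False`.

```python
import argparse
import ast

parser = argparse.ArgumentParser(description='Flag parsing sketch')
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False,
                    help='Run distribute')
parser.add_argument('--device_id', type=int, default=0, help='Device id')

# Parse explicit argument lists instead of sys.argv so the sketch is repeatable.
enabled = parser.parse_args(['--run_distribute', 'True', '--device_id', '3'])
disabled = parser.parse_args(['--run_distribute', 'False'])
defaults = parser.parse_args([])

print(enabled.run_distribute, disabled.run_distribute, defaults.run_distribute)
print(bool('False'))  # a non-empty string is truthy, so type=bool would misread it
```

With this approach the flag value has to be a valid Python literal such as `True` or `False`; lower-case `true` is not one, so argparse rejects it with a usage error.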
|