forked from mindspore-Ecosystem/mindspore
!16259 efficientnet-b0 master
Merge pull request !16259 from Gogery/e3m
This commit is contained in:
commit
dd2b88a163
|
@ -0,0 +1,189 @@
|
|||
# 目录
|
||||
|
||||
- [目录](#目录)
|
||||
- [EfficientNet-B0描述](#EfficientNet-B0描述)
|
||||
- [模型架构](#模型架构)
|
||||
- [数据集](#数据集)
|
||||
- [环境要求](#环境要求)
|
||||
- [脚本说明](#脚本说明)
|
||||
- [脚本和示例代码](#脚本和示例代码)
|
||||
- [脚本参数](#脚本参数)
|
||||
- [训练过程](#训练过程)
|
||||
- [启动](#启动)
|
||||
- [结果](#结果)
|
||||
- [评估过程](#评估过程)
|
||||
- [启动](#启动-1)
|
||||
- [结果](#结果-1)
|
||||
- [模型说明](#模型说明)
|
||||
- [训练性能](#训练性能)
|
||||
- [随机情况的描述](#随机情况的描述)
|
||||
- [ModelZoo 主页](#modelzoo-主页)
|
||||
|
||||
<!-- /TOC -->
|
||||
|
||||
# EfficientNet-B0描述
|
||||
|
||||
EfficientNet是一种卷积神经网络架构和缩放方法,它使用复合系数统一缩放深度/宽度/分辨率的所有维度。与任意缩放这些因素的常规做法不同,EfficientNet缩放方法使用一组固定的缩放系数来均匀缩放网络宽度,深度和分辨率。(2019年)
|
||||
|
||||
[论文](https://arxiv.org/abs/1905.11946):Mingxing Tan, Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. 2019.
|
||||
|
||||
# 模型架构
|
||||
|
||||
EfficientNet总体网络架构如下:
|
||||
|
||||
[链接](https://arxiv.org/abs/1905.11946)
|
||||
|
||||
# 数据集
|
||||
|
||||
使用的数据集:[imagenet](http://www.image-net.org/)
|
||||
|
||||
- 数据集大小: 146G, 1330k 1000类彩色图像
|
||||
- 训练: 140G, 1280k张图片
|
||||
- 测试: 6G, 50k张图片
|
||||
- 数据格式:RGB
|
||||
- 注:数据在src/dataset.py中处理。
|
||||
|
||||
# 环境要求
|
||||
|
||||
- 硬件(Ascend)
|
||||
- 使用Ascend来搭建硬件环境。
|
||||
- 框架
|
||||
- [MindSpore](https://www.mindspore.cn/install)
|
||||
- 如需查看详情,请参见如下资源:
|
||||
- [MindSpore 教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
|
||||
|
||||
# 脚本说明
|
||||
|
||||
## 脚本和样例代码
|
||||
|
||||
```python
|
||||
├── EfficientNet-B0
|
||||
├── README_CN.md # EfficientNet-B0相关描述
|
||||
├── scripts
|
||||
│ ├──run_standalone_train.sh # 用于单卡训练的shell脚本
|
||||
│ ├──run_distribute_train.sh # 用于八卡训练的shell脚本
|
||||
│ └──run_eval.sh # 用于评估的shell脚本
|
||||
├── src
|
||||
│ ├──models # EfficientNet-B0架构
|
||||
│ │ ├──effnet.py
|
||||
│ │ └──layers.py
|
||||
│ ├──config.py # 参数配置
|
||||
│ ├──dataset.py # 创建数据集
|
||||
│ ├──loss.py # 损失函数
|
||||
│ ├──lr_generator.py # 配置学习率
|
||||
│ └──Monitor.py # 监控网络损失和其他数据
|
||||
├── eval.py # 评估脚本
|
||||
├── export.py # 模型格式转换脚本
|
||||
└── train.py # 训练脚本
|
||||
```
|
||||
|
||||
## 脚本参数
|
||||
|
||||
模型训练和评估过程中使用的参数可以在config.py中设置:
|
||||
|
||||
```python
|
||||
'class_num': 1000, # 数据集类别数
|
||||
'batch_size': 256, # 数据批次大小
|
||||
'loss_scale': 1024, # loss scale
|
||||
'momentum': 0.9, # 动量参数
|
||||
'weight_decay': 1e-5, # 权重衰减率
|
||||
'epoch_size': 350, # 模型迭代次数
|
||||
'save_checkpoint': True, # 是否保存ckpt文件
|
||||
'save_checkpoint_epochs': 1, # 每迭代相应次数保存一个ckpt文件
|
||||
'keep_checkpoint_max': 5, # 保存ckpt文件的最大数量
|
||||
'save_checkpoint_path': "./checkpoint", # 保存ckpt文件的路径
|
||||
'opt': 'rmsprop', # 优化器
|
||||
'opt_eps': 0.001, # 改善数值稳定性的优化器参数
|
||||
'warmup_epochs': 2, # warmup epoch数量
|
||||
'lr_decay_mode': 'liner', # 学习率下降方式
|
||||
'use_label_smooth': True, # 是否使用label smooth
|
||||
'label_smooth_factor': 0.1, # 标签平滑因子
|
||||
'lr_init': 0.0001, # 初始学习率
|
||||
'lr_max': 0.2, # 最大学习率
|
||||
'lr_end': 0.00001, # 最终学习率
|
||||
```
|
||||
|
||||
## 训练过程
|
||||
|
||||
### 启动
|
||||
|
||||
您可以使用python或shell脚本进行训练。
|
||||
|
||||
```shell
|
||||
# 训练示例
|
||||
python:
|
||||
Ascend单卡训练示例:python train.py --device_id [DEVICE_ID] --dataset_path [DATA_DIR]
|
||||
|
||||
shell:
|
||||
Ascend单卡训练示例: sh ./scripts/run_standalone_train.sh [DEVICE_ID] [DATA_DIR]
|
||||
Ascend八卡并行训练:
|
||||
cd ./scripts/
|
||||
sh ./run_distribute_train.sh [RANK_TABLE_FILE] [DATA_DIR]
|
||||
```
|
||||
|
||||
### 结果
|
||||
|
||||
ckpt文件将存储在 `./checkpoint` 路径下,训练日志将被记录到 `log.txt` 中。训练日志部分示例如下:
|
||||
|
||||
```shell
|
||||
epoch 1: epoch time: 665943.590, per step time: 1065.510, avg loss: 5.273
|
||||
epoch 2: epoch time: 297900.211, per step time: 476.640, avg loss: 4.286
|
||||
epoch 3: epoch time: 297218.029, per step time: 475.549, avg loss: 3.869
|
||||
epoch 4: epoch time: 297271.768, per step time: 475.635, avg loss: 3.648
|
||||
epoch 5: epoch time: 297314.768, per step time: 475.704, avg loss: 3.356
|
||||
```
|
||||
|
||||
## 评估过程
|
||||
|
||||
### 启动
|
||||
|
||||
您可以使用python或shell脚本进行评估。
|
||||
|
||||
```shell
|
||||
# 评估示例
|
||||
python:
|
||||
python eval.py --device_id [DEVICE_ID] --dataset_path [DATA_DIR] --checkpoint_path [PATH_CHECKPOINT]
|
||||
|
||||
shell:
|
||||
sh ./scripts/run_eval.sh [DEVICE_ID] [DATA_DIR] [PATH_CHECKPOINT]
|
||||
```
|
||||
|
||||
> 训练过程中可以生成ckpt文件。
|
||||
|
||||
### 结果
|
||||
|
||||
可以在 `eval_log.txt` 查看评估结果。
|
||||
|
||||
```shell
|
||||
result: {'Loss': 1.8745046273255959, 'Top_1_Acc': 0.7668870192307692, 'Top_5_Acc': 0.9318509615384616} ckpt= ./checkpoint/model_0/Efficientnet_b0-rank0-350_625.ckpt
|
||||
```
|
||||
|
||||
# 模型说明
|
||||
|
||||
## 训练性能
|
||||
|
||||
| 参数 | Ascend |
|
||||
| -------------------------- | ------------------------------------- |
|
||||
| 模型名称 | EfficientNet |
|
||||
| 模型版本 | B0 |
|
||||
| 运行环境 | HUAWEI CLOUD Modelarts |
|
||||
| 上传时间 | 2021-3-28 |
|
||||
| 数据集 | imagenet |
|
||||
| 训练参数 | src/config.py |
|
||||
| 优化器 | RMSProp |
|
||||
| 损失函数 | CrossEntropySmooth |
|
||||
| 最终损失 | 1.87 |
|
||||
| 精确度 (8p) | Top1[76.7%], Top5[93.2%] |
|
||||
| 训练总时间 (8p) | 29.5h |
|
||||
| 评估总时间 | 1min |
|
||||
| 参数量 (M) | 61M |
|
||||
| 脚本 | [链接](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/efficientnet-b0) |
|
||||
|
||||
# 随机情况的描述
|
||||
|
||||
我们在 `dataset.py` 和 `train.py` 脚本中设置了随机种子。
|
||||
|
||||
# ModelZoo
|
||||
|
||||
请核对官方 [主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train efficientnet."""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
from mindspore import context, nn
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.models.effnet import EfficientNet
|
||||
from src.config import config
|
||||
from src.dataset import create_dataset
|
||||
from src.loss import CrossEntropySmooth
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
# modelarts parameter
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
# Ascend parameter
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='Device id')
|
||||
|
||||
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run distribute')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
|
||||
|
||||
if args_opt.run_modelarts:
|
||||
import moxing as mox
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id)
|
||||
local_data_url = '/cache/data/'
|
||||
local_train_url = '/cache/ckpt/'
|
||||
mox.file.copy_parallel(args_opt.data_url, local_data_url)
|
||||
mox.file.copy_parallel(args_opt.train_url, local_train_url)
|
||||
else:
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
# create dataset
|
||||
if args_opt.run_modelarts:
|
||||
dataset = create_dataset(dataset_path=local_data_url,
|
||||
do_train=False,
|
||||
batch_size=config.batch_size)
|
||||
ckpt_path = local_train_url + 'Efficientnet_b0-rank0-350_625.ckpt'
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
else:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=False,
|
||||
batch_size=config.batch_size)
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# define net
|
||||
net = EfficientNet(1, 1)
|
||||
|
||||
# load checkpoint
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
# define loss
|
||||
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
|
||||
# define model
|
||||
eval_metrics = {'Loss': nn.Loss(),
|
||||
'Top_1_Acc': nn.Top1CategoricalAccuracy(),
|
||||
'Top_5_Acc': nn.Top5CategoricalAccuracy()}
|
||||
model = Model(net, loss_fn=loss, metrics=eval_metrics)
|
||||
|
||||
# eval model
|
||||
res = model.eval(dataset)
|
||||
if args_opt.run_modelarts:
|
||||
print("result:", res, "ckpt=", local_data_url)
|
||||
else:
|
||||
print("result:", res, "ckpt=", args_opt.checkpoint_path)
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
efficientnet export.
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
from src.models.effnet import EfficientNet
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--checkpoint_path', type=str, required=True, help='Checkpoint file path')
|
||||
parser.add_argument("--file_name", type=str, default="resnet", help="output file name.")
|
||||
parser.add_argument('--width', type=int, default=224, help='input width')
|
||||
parser.add_argument('--height', type=int, default=224, help='input height')
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if __name__ == '__main__':
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
net = EfficientNet(1, 1)
|
||||
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
input_shp = [1, 3, args_opt.height, args_opt.width]
|
||||
input_array = Tensor(np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
|
||||
export(net, input_array, file_name=args_opt.file_name, file_format=args_opt.file_format)
|
|
@ -0,0 +1,61 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ ! -f $PATH1 ]; then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]; then
|
||||
echo "error: DATASET_PATH=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --dataset_path=$PATH2 --run_distribute=True > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
PATH_CHECKPOINT=$3
|
||||
|
||||
python ./eval.py \
|
||||
--device_id=$DEVICE_ID \
|
||||
--checkpoint_path=$PATH_CHECKPOINT \
|
||||
--dataset_path=$DATA_DIR > eval.log 2>&1 &
|
|
@ -0,0 +1,22 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_DIR=$2
|
||||
python ./train.py \
|
||||
--device_id=$DEVICE_ID \
|
||||
--dataset_path=$DATA_DIR > log.txt 2>&1 &
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
# config for efficientnet, imagenet2012.
|
||||
config = ed({
|
||||
"class_num": 1000,
|
||||
"batch_size": 256,
|
||||
"loss_scale": 1024,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1e-5,
|
||||
"epoch_size": 350,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 5,
|
||||
"save_checkpoint_path": "./checkpoint",
|
||||
"opt": 'rmsprop',
|
||||
"opt_eps": 0.001,
|
||||
"warmup_epochs": 2,
|
||||
"lr_decay_mode": "liner",
|
||||
"use_label_smooth": True,
|
||||
"label_smooth_factor": 0.1,
|
||||
"lr_init": 0.0001,
|
||||
"lr_max": 0.2,
|
||||
"lr_end": 0.00001
|
||||
})
|
|
@ -0,0 +1,68 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Data operations, will be used in train.py and eval.py
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
|
||||
|
||||
def create_dataset(dataset_path, do_train, batch_size=16, device_num=1, rank=0):
|
||||
"""
|
||||
create a train or eval dataset
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
batch_size(int): the batch size of dataset. Default: 16.
|
||||
device_num (int): Number of shards that the dataset should be divided into (default=1).
|
||||
rank (int): The shard ID within num_shards (default=0).
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
if device_num == 1:
|
||||
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=64, shuffle=True)
|
||||
else:
|
||||
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=64, shuffle=True,
|
||||
num_shards=device_num, shard_id=rank)
|
||||
# define map operations
|
||||
if do_train:
|
||||
trans = [
|
||||
C.RandomCropDecodeResize(224),
|
||||
C.RandomHorizontalFlip(prob=0.5),
|
||||
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
|
||||
]
|
||||
else:
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize(255),
|
||||
C.CenterCrop(224)
|
||||
]
|
||||
trans += [
|
||||
C.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
|
||||
# C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], std=[0.229 * 255, 0.224 * 255, 0.225 * 255]),
|
||||
C.HWC2CHW(),
|
||||
C2.TypeCast(mstype.float32)
|
||||
]
|
||||
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
|
||||
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)
|
||||
# apply batch operations
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network"""
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class CrossEntropySmooth(_Loss):
|
||||
"""CrossEntropy"""
|
||||
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
|
||||
super(CrossEntropySmooth, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.sparse = sparse
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
|
||||
|
||||
def construct(self, logit, label):
|
||||
if self.sparse:
|
||||
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, label)
|
||||
return loss
|
|
@ -0,0 +1,87 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_epochs(int): number of warmup epochs
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
lr_decay_mode(string): learning rate decay mode, including steps, poly or default
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
if lr_decay_mode == 'steps':
|
||||
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
|
||||
for i in range(total_steps):
|
||||
if i < decay_epoch_index[0]:
|
||||
lr = lr_max
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr = lr_max * 0.1
|
||||
elif i < decay_epoch_index[2]:
|
||||
lr = lr_max * 0.01
|
||||
else:
|
||||
lr = lr_max * 0.001
|
||||
lr_each_step.append(lr)
|
||||
elif lr_decay_mode == 'poly':
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max) * base * base
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
elif lr_decay_mode == 'cosine':
|
||||
decay_steps = total_steps - warmup_steps
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
lr = float(lr_init) + lr_inc * (i + 1)
|
||||
else:
|
||||
linear_decay = (total_steps - i) / decay_steps
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
|
||||
decayed = linear_decay * cosine_decay + 0.00001
|
||||
lr = lr_max * decayed
|
||||
lr_each_step.append(lr)
|
||||
else:
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
|
||||
lr_each_step.append(lr)
|
||||
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
return lr_each_step
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""efficientnet model define"""
|
||||
import math
|
||||
import mindspore.nn as nn
|
||||
|
||||
from src.models.layers import conv_bn_act
|
||||
from src.models.layers import AdaptiveAvgPool2d
|
||||
from src.models.layers import Flatten
|
||||
from src.models.layers import SEModule
|
||||
from src.models.layers import DropConnect
|
||||
|
||||
|
||||
class MBConv(nn.Cell):
|
||||
"""MBConv"""
|
||||
def __init__(self, in_, out_, expand,
|
||||
kernel_size, stride, skip,
|
||||
se_ratio, dc_ratio=0.2):
|
||||
super().__init__()
|
||||
mid_ = in_ * expand
|
||||
self.expand = expand
|
||||
self.expand_conv = conv_bn_act(in_, mid_, kernel_size=1, bias=False)
|
||||
|
||||
self.depth_wise_conv = conv_bn_act(mid_, mid_,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
groups=mid_, bias=False)
|
||||
|
||||
self.se = SEModule(mid_, int(in_ * se_ratio))
|
||||
|
||||
self.project_conv = nn.SequentialCell([
|
||||
nn.Conv2d(mid_, out_, kernel_size=1, stride=1, has_bias=False),
|
||||
nn.BatchNorm2d(num_features=out_, eps=0.001, momentum=0.99)
|
||||
])
|
||||
self.skip = skip and (stride == 1) and (in_ == out_)
|
||||
|
||||
# DropConnect
|
||||
self.dropconnect = DropConnect(dc_ratio)
|
||||
|
||||
def construct(self, inputs):
|
||||
"""MBConv"""
|
||||
if self.expand != 1:
|
||||
expand = self.expand_conv(inputs)
|
||||
else:
|
||||
expand = inputs
|
||||
x = self.depth_wise_conv(expand)
|
||||
x = self.se(x)
|
||||
x = self.project_conv(x)
|
||||
if self.skip:
|
||||
x = x + inputs
|
||||
return x
|
||||
|
||||
|
||||
class MBBlock(nn.Cell):
|
||||
"""MBBlock"""
|
||||
def __init__(self, in_, out_, expand, kernel, stride, num_repeat, skip, se_ratio, drop_connect_ratio=0.2):
|
||||
super().__init__()
|
||||
layers = [MBConv(in_, out_, expand, kernel, stride, skip, se_ratio, drop_connect_ratio)]
|
||||
for _ in range(1, num_repeat):
|
||||
layers.append(MBConv(out_, out_, expand, kernel, 1, skip, se_ratio, drop_connect_ratio))
|
||||
self.layers = nn.SequentialCell([*layers])
|
||||
|
||||
def construct(self, x):
|
||||
return self.layers(x)
|
||||
|
||||
|
||||
class EfficientNet(nn.Cell):
|
||||
"""efficientnet model"""
|
||||
def __init__(self, width_coeff, depth_coeff,
|
||||
depth_div=8, min_depth=None,
|
||||
dropout_rate=0.2, drop_connect_rate=0.2,
|
||||
num_classes=1000):
|
||||
super().__init__()
|
||||
min_depth = min_depth or depth_div
|
||||
dropout_rate = 1 - dropout_rate
|
||||
|
||||
def renew_ch(x):
|
||||
if not width_coeff:
|
||||
return x
|
||||
|
||||
x *= width_coeff
|
||||
new_x = max(min_depth, int(x + depth_div / 2) // depth_div * depth_div)
|
||||
if new_x < 0.9 * x:
|
||||
new_x += depth_div
|
||||
return int(new_x)
|
||||
|
||||
def renew_repeat(x):
|
||||
return int(math.ceil(x * depth_coeff))
|
||||
|
||||
self.stem = conv_bn_act(3, renew_ch(32), kernel_size=3, stride=2, bias=False)
|
||||
|
||||
self.blocks = nn.SequentialCell([
|
||||
# input channel output expand k s skip se
|
||||
MBBlock(renew_ch(32), renew_ch(16), 1, 3, 1, renew_repeat(1), True, 0.25, drop_connect_rate),
|
||||
MBBlock(renew_ch(16), renew_ch(24), 6, 3, 2, renew_repeat(2), True, 0.25, drop_connect_rate),
|
||||
MBBlock(renew_ch(24), renew_ch(40), 6, 5, 2, renew_repeat(2), True, 0.25, drop_connect_rate),
|
||||
MBBlock(renew_ch(40), renew_ch(80), 6, 3, 2, renew_repeat(3), True, 0.25, drop_connect_rate),
|
||||
MBBlock(renew_ch(80), renew_ch(112), 6, 5, 1, renew_repeat(3), True, 0.25, drop_connect_rate),
|
||||
MBBlock(renew_ch(112), renew_ch(192), 6, 5, 2, renew_repeat(4), True, 0.25, drop_connect_rate),
|
||||
MBBlock(renew_ch(192), renew_ch(320), 6, 3, 1, renew_repeat(1), True, 0.25, drop_connect_rate)
|
||||
])
|
||||
|
||||
self.head = nn.SequentialCell([
|
||||
*conv_bn_act(renew_ch(320), renew_ch(1280), kernel_size=1, bias=False),
|
||||
AdaptiveAvgPool2d(),
|
||||
nn.Dropout(dropout_rate),
|
||||
Flatten(),
|
||||
nn.Dense(renew_ch(1280), num_classes)
|
||||
])
|
||||
|
||||
def construct(self, inputs):
|
||||
stem = self.stem(inputs)
|
||||
x = self.blocks(stem)
|
||||
x = self.head(x)
|
||||
return x
|
|
@ -0,0 +1,94 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""efficientnet model define"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
|
||||
|
||||
def conv_bn_act(in_, out_, kernel_size,
|
||||
stride=1, groups=1, bias=True,
|
||||
eps=1e-3, momentum=0.01):
|
||||
"""conv_bn_act"""
|
||||
return nn.SequentialCell([
|
||||
nn.Conv2d(in_, out_, kernel_size, stride, group=groups, has_bias=bias),
|
||||
nn.BatchNorm2d(num_features=out_, eps=eps, momentum=1.0 - momentum),
|
||||
Swish()
|
||||
])
|
||||
|
||||
|
||||
class Swish(nn.Cell):
|
||||
"""Swish"""
|
||||
def construct(self, x):
|
||||
sigmoid = P.Sigmoid()
|
||||
x = x * sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
class Flatten(nn.Cell):
|
||||
"""Flatten"""
|
||||
def construct(self, x):
|
||||
shape = P.Shape()
|
||||
reshape = P.Reshape()
|
||||
x = reshape(x, (shape(x)[0], -1))
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(nn.Cell):
|
||||
"""SEModule"""
|
||||
def __init__(self, in_, squeeze_ch):
|
||||
super().__init__()
|
||||
|
||||
self.se = nn.SequentialCell([
|
||||
AdaptiveAvgPool2d(),
|
||||
nn.Conv2d(in_, squeeze_ch, kernel_size=1, stride=1, pad_mode='pad', padding=0, has_bias=True),
|
||||
Swish(),
|
||||
nn.Conv2d(squeeze_ch, in_, kernel_size=1, stride=1, pad_mode='pad', padding=0, has_bias=True),
|
||||
])
|
||||
|
||||
def construct(self, x):
|
||||
sigmoid = P.Sigmoid()
|
||||
x = x * sigmoid(self.se(x))
|
||||
return x
|
||||
|
||||
|
||||
class AdaptiveAvgPool2d(nn.Cell):
|
||||
"""AdaptiveAvgPool2d"""
|
||||
def __init__(self):
|
||||
super(AdaptiveAvgPool2d, self).__init__()
|
||||
self.mean = P.ReduceMean(True)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.mean(x, (2, 3))
|
||||
return x
|
||||
|
||||
class DropConnect(nn.Cell):
|
||||
"""DropConnect"""
|
||||
def __init__(self, ratio):
|
||||
super().__init__()
|
||||
self.ratio = 1.0 - ratio
|
||||
|
||||
def construct(self, x):
|
||||
"""DropConnect"""
|
||||
if not self.training:
|
||||
return x
|
||||
|
||||
random_tensor = self.ratio
|
||||
shape = (random_tensor.shape[0], 1, 1, 1)
|
||||
stdnormal = P.StandardNormal(seed=2)
|
||||
random_tensor = stdnormal(shape)
|
||||
random_tensor.requires_grad = False
|
||||
floor = P.Floor()
|
||||
x = x / self.ratio * floor(random_tensor)
|
||||
return x
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Monitor loss and time"""
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
class Monitor(Callback):
|
||||
"""
|
||||
Monitor loss and time.
|
||||
|
||||
Args:
|
||||
lr_init (numpy array): train lr
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
|
||||
"""
|
||||
|
||||
def __init__(self, lr_init=None):
|
||||
super(Monitor, self).__init__()
|
||||
self.lr_init = lr_init
|
||||
self.lr_init_len = len(lr_init)
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.losses = []
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / cb_params.batch_num
|
||||
print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
|
||||
per_step_mseconds,
|
||||
np.mean(self.losses)))
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
"""step end"""
|
||||
cb_params = run_context.original_args()
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
step_loss = cb_params.net_outputs
|
||||
|
||||
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
|
||||
step_loss = step_loss[0]
|
||||
if isinstance(step_loss, Tensor):
|
||||
step_loss = np.mean(step_loss.asnumpy())
|
||||
|
||||
self.losses.append(step_loss)
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
|
||||
|
||||
print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
|
||||
cb_params.cur_epoch_num -
|
||||
1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss,
|
||||
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
|
|
@ -0,0 +1,149 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train efficientnet."""
|
||||
import os
|
||||
import ast
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import SGD, RMSProp
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.lr_generator import get_lr
|
||||
from src.models.effnet import EfficientNet
|
||||
from src.config import config
|
||||
from src.monitor import Monitor
|
||||
from src.dataset import create_dataset
|
||||
from src.loss import CrossEntropySmooth
|
||||
|
||||
set_seed(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='image classification training')
|
||||
# modelarts parameter
|
||||
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
|
||||
|
||||
# Ascend parameter
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Run distribute')
|
||||
parser.add_argument('--device_id', type=int, default=0, help='Device id')
|
||||
|
||||
parser.add_argument('--run_modelarts', type=ast.literal_eval, default=False, help='Run mode')
|
||||
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
|
||||
|
||||
# init distributed
|
||||
if args_opt.run_modelarts:
|
||||
import moxing as mox
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id)
|
||||
local_data_url = '/cache/data'
|
||||
local_train_url = '/cache/ckpt'
|
||||
if device_num > 1:
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=device_num, parallel_mode='data_parallel', gradients_mean=True)
|
||||
local_data_url = os.path.join(local_data_url, str(device_id))
|
||||
mox.file.copy_parallel(args_opt.data_url, local_data_url)
|
||||
else:
|
||||
if args_opt.run_distribute:
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
device_num = int(os.getenv('RANK_SIZE'))
|
||||
context.set_context(device_id=device_id)
|
||||
init()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
device_num = 1
|
||||
device_id = 0
|
||||
|
||||
# define network
|
||||
net = EfficientNet(1, 1)
|
||||
net.to_float(mstype.float16)
|
||||
|
||||
# define loss
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
|
||||
# define dataset
|
||||
if args_opt.run_modelarts:
|
||||
dataset = create_dataset(dataset_path=local_data_url,
|
||||
do_train=True,
|
||||
batch_size=config.batch_size,
|
||||
device_num=device_num, rank=device_id)
|
||||
else:
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
do_train=True,
|
||||
batch_size=config.batch_size,
|
||||
device_num=device_num, rank=device_id)
|
||||
step_size = dataset.get_dataset_size()
|
||||
|
||||
# resume
|
||||
if args_opt.resume:
|
||||
ckpt = load_checkpoint(args_opt.resume)
|
||||
load_param_into_net(net, ckpt)
|
||||
|
||||
# get learning rate
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
lr = Tensor(get_lr(lr_init=config.lr_init,
|
||||
lr_end=config.lr_end,
|
||||
lr_max=config.lr_max,
|
||||
warmup_epochs=config.warmup_epochs,
|
||||
total_epochs=config.epoch_size,
|
||||
steps_per_epoch=step_size,
|
||||
lr_decay_mode=config.lr_decay_mode))
|
||||
|
||||
# define optimization
|
||||
if config.opt == 'sgd':
|
||||
optimizer = SGD(net.trainable_params(), learning_rate=lr, momentum=config.momentum,
|
||||
weight_decay=config.weight_decay, loss_scale=config.loss_scale)
|
||||
elif config.opt == 'rmsprop':
|
||||
optimizer = RMSProp(net.trainable_params(), learning_rate=lr, decay=0.9, weight_decay=config.weight_decay,
|
||||
momentum=config.momentum, epsilon=config.opt_eps, loss_scale=config.loss_scale)
|
||||
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale,
|
||||
metrics={'acc'}, amp_level='O3')
|
||||
|
||||
# define callbacks
|
||||
cb = [Monitor(lr_init=lr.asnumpy())]
|
||||
if config.save_checkpoint and (device_num == 1 or device_id == 0):
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
if args_opt.run_modelarts:
|
||||
ckpt_cb = ModelCheckpoint(f"Efficientnet_b0-rank{device_id}", directory=local_train_url, config=config_ck)
|
||||
else:
|
||||
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(device_id) + '/')
|
||||
ckpt_cb = ModelCheckpoint(f"Efficientnet_b0-rank{device_id}", directory=save_ckpt_path, config=config_ck)
|
||||
cb += [ckpt_cb]
|
||||
|
||||
# begine train
|
||||
model.train(config.epoch_size, dataset, callbacks=cb)
|
||||
if args_opt.run_modelarts and config.save_checkpoint and (device_num == 1 or device_id == 0):
|
||||
mox.file.copy_parallel(local_train_url, args_opt.train_url)
|
Loading…
Reference in New Issue