This commit is contained in:
Yunzhu0v0 2020-12-18 22:54:21 +08:00 committed by GRAY0v0
parent 50ccfecb7b
commit 5d6fd30b59
15 changed files with 1121 additions and 0 deletions

View File

@ -0,0 +1,275 @@
# Contents
- [Contents](#contents)
- [Tiny-DarkNet Description](#tiny-darknet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Standalone Training](#standalone-training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Tiny-DarkNet Description](#contents)
Tiny-DarkNet is a 16-layer image-classification network proposed by Joseph Chet Redmon et al. for the classic ImageNet dataset. It is a simplified version of Darknet, designed to keep the model as small as possible for users who need a compact network, and it achieves better image-classification accuracy than AlexNet and SqueezeNet while using fewer parameters than either of them. To reduce the model size, Tiny-DarkNet contains no fully connected layers and consists only of convolutional layers, max-pooling layers and an average-pooling layer.
For more details on Tiny-DarkNet, see the [official introduction](https://pjreddie.com/darknet/tiny-darknet/).
# [Model Architecture](#contents)
Specifically, the Tiny-DarkNet network is built from **1×1 conv**, **3×3 conv**, **2×2 max-pooling** and global average-pooling blocks; stacked together, these modules transform an input image into a **1×1000** vector. A quick shape check is sketched below.
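The network itself is implemented in `src/tinydarknet.py` (included in this commit). Below is a minimal, hedged sketch of the expected input and output shapes, assuming the repository root is the working directory and that `device_target` matches an available device:

```python
# Shape check for TinyDarkNet (sketch only; "Ascend" is the target used throughout this repository).
import numpy as np
import mindspore as ms
from mindspore import Tensor, context
from src.tinydarknet import TinyDarkNet

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
net = TinyDarkNet(num_classes=1000)
images = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32)
print(net(images).shape)  # expected: (1, 1000)
```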
# [Dataset](#contents)
The dataset used by this model and its source are described below:
The dataset used can be found in the [paper](<https://ieeexplore.ieee.org/abstract/document/5206848>).
- Dataset size: 125G, 1,250k color images in 1,000 classes
    - Train: 120G, 1,200k images
    - Test: 5G, 50k images
- Data format: RGB images
- Note: the data is processed by the functions in src/dataset.py, as sketched after this list.
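A minimal sketch of building the evaluation pipeline with the helper from `src/dataset.py` (the directory below is a placeholder for an ImageFolder-style layout with one sub-folder per class; batch size and image size come from `src/config.py`):

```python
# Build the ImageNet evaluation pipeline (sketch; adjust the path to your dataset location).
from src.dataset import create_dataset_imagenet

dataset = create_dataset_imagenet("./dataset/imagenet_original/val/", repeat_num=1, training=False)
print("batches per epoch:", dataset.get_dataset_size())
for image, label in dataset.create_tuple_iterator():
    print(image.shape, label.shape)  # (128, 3, 224, 224) and (128,) with the default config
    break
```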
# [Environment Requirements](#contents)
- Hardware (Ascend)
    - Prepare a hardware environment with Ascend processors. To request access to Ascend resources, send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. You can use the Ascend resources once the application is approved.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can follow the steps below to train and evaluate the model:
- Running on Ascend
```python
# standalone training
python train.py > train.log 2>&1 &
# distributed training
bash scripts/run_train.sh rank_table.json
# evaluation
python eval.py > eval.log 2>&1 &
OR
bash run_eval.sh
```
For distributed parallel training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below to set it up:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>
For more details, please refer to the script files themselves.
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```bash
├── Tiny-DarkNet
    ├── README.md              // description of Tiny-Darknet
    ├── scripts
    │   ├── run_train.sh       // shell script for distributed training on Ascend
    │   ├── run_eval.sh        // shell script for evaluation on Ascend
    ├── src
    │   ├── dataset.py         // dataset creation
    │   ├── tinydarknet.py     // Tiny-Darknet network architecture
    │   ├── config.py          // parameter configuration
    ├── train.py               // training script
    ├── eval.py                // evaluation script
    ├── export.py              // export checkpoint file to ONNX/AIR models
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py.
- Config for Tiny-Darknet
```python
'pre_trained': False                       # whether to load a pretrained checkpoint
'num_classes': 1000                        # number of classes in the dataset
'lr_init': 0.1                             # initial learning rate
'batch_size': 128                          # training batch size
'epoch_size': 500                          # total number of training epochs
'momentum': 0.9                            # momentum
'weight_decay': 1e-4                       # weight decay
'image_height': 224                        # height of the input image
'image_width': 224                         # width of the input image
'data_path': './ImageNet_Original/train/'  # absolute path of the training dataset
'val_data_path': './ImageNet_Original/val/' # absolute path of the evaluation dataset
'device_target': 'Ascend'                  # device on which the program runs
'device_id': 0                             # device ID used for training or evaluation
'keep_checkpoint_max': 10                  # keep only the latest keep_checkpoint_max checkpoint files
'checkpoint_path': './train_tinydarknet_imagenet-125_390.ckpt'  # absolute path of the saved checkpoint file
'onnx_filename': 'tinydarknet.onnx'        # file name of the ONNX model used in export.py
'air_filename': 'tinydarknet.air'          # file name of the AIR model used in export.py
'lr_scheduler': 'exponential'              # learning-rate scheduler
'lr_epochs': [70, 140, 210, 280]           # epochs at which the learning rate changes
'lr_gamma': 0.3                            # learning-rate decay factor when lr_scheduler is exponential
'eta_min': 0.0                             # eta_min of the cosine_annealing scheduler
'T_max': 150                               # T_max of the cosine_annealing scheduler
'warmup_epochs': 0                         # number of warm-up epochs
'is_dynamic_loss_scale': 0                 # whether to use dynamic loss scaling
'loss_scale': 1024                         # loss scale
'label_smooth_factor': 0.1                 # label-smoothing factor for training
'use_label_smooth': True                   # whether to use label smoothing
```
For more details, please refer to `config.py`.
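For example, the effective configuration can be inspected or tweaked from a Python shell before launching the scripts (a sketch; it assumes the repository root is the working directory):

```python
# Inspect the training configuration defined in src/config.py.
from src.config import imagenet_cfg

print(imagenet_cfg.batch_size)    # 128
print(imagenet_cfg.lr_scheduler)  # 'exponential'
imagenet_cfg.epoch_size = 300     # fields are plain EasyDict attributes and can be overridden
```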
## [Training Process](#contents)
### [Standalone Training](#contents)
- Running on Ascend
```python
python train.py > train.log 2>&1 &
```
The python command above runs in the background; the results can be checked in the `train.log` file.
After training, by default you will get some checkpoint files under the script folder. The training loss is printed as follows:
```python
# grep "loss is " train.log
epoch: 498 step: 1251, loss is 2.7798953
Epoch time: 130690.544, per step time: 104.469
epoch: 499 step: 1251, loss is 2.9261637
Epoch time: 130511.081, per step time: 104.325
epoch: 500 step: 1251, loss is 2.69412
Epoch time: 127067.548, per step time: 101.573
...
```
The model checkpoint files are saved in the current directory.
### [Distributed Training](#contents)
- Running on Ascend
```python
sh scripts/run_train.sh rank_table.json
```
The shell command above runs distributed training in the background; the results can be checked in the `train_parallel[X]/log` files. The training loss is printed as follows:
```python
# grep "result: " train_parallel*/log
epoch: 498 step: 1251, loss is 2.7798953
Epoch time: 130690.544, per step time: 104.469
epoch: 499 step: 1251, loss is 2.9261637
Epoch time: 130511.081, per step time: 104.325
epoch: 500 step: 1251, loss is 2.69412
Epoch time: 127067.548, per step time: 101.573
...
```
## [Evaluation Process](#contents)
### [Evaluation](#contents)
- Evaluation on Ascend:
Before running the command below, please check the checkpoint path used for evaluation. Please set the checkpoint path to an absolute path, e.g. "username/imagenet/train_tiny-darknet_imagenet-125_390.ckpt".
```python
python eval.py > eval.log 2>&1 &
OR
sh scripts/run_eval.sh
```
The python command above runs in the background; the results can be checked in the "eval.log" file. The accuracy on the test dataset is reported as follows:
```python
# grep "accuracy: " eval.log
accuracy: {'top_1_accuracy': 0.5871979166666667, 'top_5_accuracy': 0.8175280448717949}
```
Note that for evaluation after distributed training, please set checkpoint_path to the last saved checkpoint file. The accuracy is reported as follows:
```python
# grep "accuracy: " eval.log
accuracy: {'top_1_accuracy': 0.5871979166666667, 'top_5_accuracy': 0.8175280448717949}
```
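Alternatively, instead of editing `config.py`, the checkpoint can be passed on the command line, since `eval.py` exposes a `--checkpoint_path` argument (the directory below is only a placeholder):

```python
# evaluation with an explicit checkpoint file
python eval.py --checkpoint_path=/absolute/path/to/train_tinydarknet_imagenet-125_390.ckpt > eval.log 2>&1 &
```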
# [Model Description](#contents)
## [Performance](#contents)
### [Training Performance](#contents)
| Parameters                 | Ascend                                                      |
| -------------------------- | ----------------------------------------------------------- |
| Model Version              | V1                                                          |
| Resource                   | Ascend 910; CPU 2.60GHz, 56 cores; memory 314G              |
| Uploaded Date              | 2020/12/22                                                  |
| MindSpore Version          | 1.1.0                                                       |
| Dataset                    | 1,200k images                                               |
| Training Parameters        | epoch=500, steps=1251, batch_size=128, lr=0.1               |
| Optimizer                  | Momentum                                                    |
| Loss Function              | Softmax Cross Entropy                                       |
| Speed                      | 8 pcs: 104 ms/step                                          |
| Total Time                 | 8 pcs: 17.8 hours                                           |
| Parameters (M)             | 4.0                                                         |
| Scripts                    | [Tiny-Darknet scripts](https://gitee.com/mindspore/mindspore/tree/r0.7/model_zoo/official/cv/googlenet) |
### [Evaluation Performance](#contents)
| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | V1                          |
| Resource            | Ascend 910                  |
| Uploaded Date       | 2020/12/22                  |
| MindSpore Version   | 1.1.0                       |
| Dataset             | 200k images                 |
| batch_size          | 128                         |
| Outputs             | probability                 |
| Accuracy            | 8 pcs Top-5: 81.7%          |
| Model for inference | 11.6M (.ckpt file)          |
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############evaluate tinydarknet example on imagenet#################
python eval.py
"""
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet
from src.tinydarknet import TinyDarkNet
from src.CrossEntropySmooth import CrossEntropySmooth
set_seed(1)
parser = argparse.ArgumentParser(description='tinydarknet')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
args_opt = parser.parse_args()
if __name__ == '__main__':
if args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
net = TinyDarkNet(num_classes=cfg.num_classes)
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
else:
raise ValueError("dataset is not support.")
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
if device_target == "Ascend":
context.set_context(device_id=cfg.device_id)
if args_opt.checkpoint_path is not None:
param_dict = load_checkpoint(args_opt.checkpoint_path)
print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
else:
param_dict = load_checkpoint(cfg.checkpoint_path)
print("load checkpoint from [{}].".format(cfg.checkpoint_path))
load_param_into_net(net, param_dict)
net.set_train(False)
acc = model.eval(dataset)
print("accuracy: ", acc)

View File

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air and onnx models#################
python export.py
"""
import argparse
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import imagenet_cfg
from src.tinydarknet import TinyDarkNet
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
help='dataset name.')
args_opt = parser.parse_args()
if args_opt.dataset_name == 'imagenet':
cfg = imagenet_cfg
else:
raise ValueError("dataset is not support.")
net = TinydarkNet(num_classes=cfg.num_classes)
assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None."
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32)
export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX")
export(net, input_arr, file_name=cfg.air_filename, file_format="AIR")

View File

@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.tinydarknet import TinyDarkNet
def tinydarknet(*args, **kwargs):
return TinyDarkNet(*args, **kwargs)
def create_network(name, *args, **kwargs):
if name == "tinydarknet":
return tinydarknet(*args, **kwargs)
raise NotImplementedError(f"{name} is not implemented in the repo")
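# Illustrative usage of the hub entry point above (a sketch; num_classes is just an example value):
#
#     network = create_network("tinydarknet", num_classes=1000)   # returns a TinyDarkNet instance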

View File

@ -0,0 +1,17 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
python ../eval.py > ./eval.log 2>&1 &

View File

@ -0,0 +1,66 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "$1 $2"
if [ $# != 1 ] && [ $# != 2 ]
then
echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]"
exit 1
fi
if [ ! -f $1 ]
then
echo "error:RANK_TABLE_FILE=$1 is not a file"
exit 1
fi
dataset_type='cifar10'
if [ $# == 2 ]
then
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
then
echo "error: the selected dataset is neither cifar10 nor imagenet"
exit 1
fi
dataset_type=$2
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
RANK_TABLE_FILE=$(realpath $1)
export RANK_TABLE_FILE
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
export SERVER_ID=0
rank_start=$((DEVICE_NUM * SERVER_ID))
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp -r ../src ./train_parallel$i
cp ../train.py ./train_parallel$i
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
cd ./train_parallel$i || exit
env > env.log
python train.py --device_id=$i --dataset_name=$dataset_type > log 2>&1 &
cd ..
done

View File

@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss
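# A minimal, illustrative check of the smoothed loss above (a sketch with toy values; it runs only
# when this file is executed directly and assumes a MindSpore context for an available device):
if __name__ == "__main__":
    import numpy as np
    loss_fn = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, num_classes=10)
    example_logits = Tensor(np.random.randn(4, 10).astype(np.float32))
    example_labels = Tensor(np.array([0, 1, 2, 3], dtype=np.int32))
    print(loss_fn(example_logits, example_labels))  # scalar smoothed cross-entropy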

View File

@ -0,0 +1,54 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as edict
imagenet_cfg = edict({
'name': 'imagenet',
'pre_trained': False,
'num_classes': 1000,
'lr_init': 0.1,
'batch_size': 128,
'epoch_size': 500,
'momentum': 0.9,
'weight_decay': 1e-4,
'image_height': 224,
'image_width': 224,
'data_path': './dataset/imagenet_original/train/',
'val_data_path': './dataset/imagenet_original/val/',
'device_target': 'Ascend',
'device_id': 0,
'device_num': 8,
'keep_checkpoint_max': 1,
'checkpoint_path': './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt',
'onnx_filename': 'tinydarknet.onnx',
'air_filename': 'tinydarknet.air',
# optimizer and lr related
'lr_scheduler': 'exponential',
'lr_epochs': [70, 140, 210, 280],
'lr_gamma': 0.3,
'eta_min': 0.0,
'T_max': 150,
'warmup_epochs': 0,
# loss related
'is_dynamic_loss_scale': False,
'loss_scale': 1024,
'label_smooth_factor': 0.1,
'use_label_smooth': True,
})

View File

@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vision
from src.config import imagenet_cfg
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
num_parallel_workers=None, shuffle=None):
"""
create a train or eval imagenet2012 dataset for tinydarknet
Args:
dataset_path(string): the path of dataset.
repeat_num(int): the repeat times of dataset. Default: 1
training(bool): whether the dataset is used for training or evaluation. Default: True
num_parallel_workers(int): number of workers used to read the data. Default: None
shuffle(bool): whether to shuffle the dataset. Default: None
Returns:
dataset
"""
device_num, rank_id = _get_rank_info()
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
num_shards=device_num, shard_id=rank_id)
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
image_size = imagenet_cfg.image_height
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if training:
transform_img = [
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
vision.RandomHorizontalFlip(prob=0.5),
vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
else:
transform_img = [
vision.Decode(),
vision.Resize(256),
vision.CenterCrop(image_size),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
transform_label = [C.TypeCast(mstype.int32)]
data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
# apply batch operations
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
from mindspore.communication.management import get_rank, get_group_size
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = rank_id = None
return rank_size, rank_id

View File

@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr

View File

@ -0,0 +1,39 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
import math
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
""" warmup cosine annealing lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)

View File

@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""lr"""
from collections import Counter
import numpy as np
from .linear_warmup import linear_warmup_lr
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""warmup step lr"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma ** milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
"""lr"""
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
"""lr"""
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
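# Illustrative behaviour of the step schedule above (sketch values, assuming the default ImageNet
# config: lr_init=0.1, lr_epochs=[70, 140, 210, 280], gamma=0.3, and 10 steps per epoch):
#
#     schedule = warmup_step_lr(0.1, [70, 140, 210, 280], steps_per_epoch=10, warmup_epochs=0,
#                               max_epoch=300, gamma=0.3)
#     schedule[0]    -> 0.1    (before the first milestone)
#     schedule[700]  -> 0.03   (after epoch 70, decayed once by gamma)
#     schedule[1400] -> 0.009  (after epoch 140, decayed twice)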

View File

@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TinydarkNet"""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
from mindspore.ops import operations as P
def weight_variable():
"""Weight variable."""
return TruncatedNormal(0.02)
class Conv1x1dBlock(nn.Cell):
"""
Basic convolutional block
Args:
in_channels (int): Input channel.
out_channels (int): Output channel.
kernel_size (int): Input kernel size. Default: 1
stride (int): Stride size for the first convolutional layer. Default: 1.
padding (int): Implicit paddings on both sides of the input. Default: 0.
pad_mode (str): Padding mode. Optional values are "same", "valid", "pad". Default: "same".
Returns:
Tensor, output tensor.
"""
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode="same"):
super(Conv1x1dBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_variable())
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
self.leakyrelu = nn.LeakyReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.leakyrelu(x)
return x
class Conv3x3dBlock(nn.Cell):
"""
Basic convolutional block
Args:
in_channels (int): Input channel.
out_channels (int): Output channel.
kernel_size (int): Input kernel size. Default: 3.
stride (int): Stride size for the first convolutional layer. Default: 1.
padding (int): Implicit paddings on both sides of the input. Default: 1.
pad_mode (str): Padding mode. Optional values are "same", "valid", "pad". Default: "pad".
Returns:
Tensor, output tensor.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, pad_mode="pad"):
super(Conv3x3dBlock, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
padding=padding, pad_mode=pad_mode, weight_init=weight_variable())
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
self.leakyrelu = nn.LeakyReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.leakyrelu(x)
return x
class TinyDarkNet(nn.Cell):
"""
Tinydarknet architecture
"""
def __init__(self, num_classes, include_top=True):
super(TinyDarkNet, self).__init__()
self.conv1 = Conv3x3dBlock(3, 16)
self.conv2 = Conv3x3dBlock(16, 32)
self.conv3 = Conv1x1dBlock(32, 16)
self.conv4 = Conv3x3dBlock(16, 128)
self.conv5 = Conv1x1dBlock(128, 16)
self.conv6 = Conv3x3dBlock(16, 128)
self.conv7 = Conv1x1dBlock(128, 32)
self.conv8 = Conv3x3dBlock(32, 256)
self.conv9 = Conv1x1dBlock(256, 32)
self.conv10 = Conv3x3dBlock(32, 256)
self.conv11 = Conv1x1dBlock(256, 64)
self.conv12 = Conv3x3dBlock(64, 512)
self.conv13 = Conv1x1dBlock(512, 64)
self.conv14 = Conv3x3dBlock(64, 512)
self.conv15 = Conv1x1dBlock(512, 128)
self.conv16 = Conv1x1dBlock(128, 1000)
self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same")
self.avgpool2d = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
def construct(self, x):
"""construct"""
x = self.conv1(x)
x = self.maxpool2d(x)
x = self.conv2(x)
x = self.maxpool2d(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = self.conv6(x)
x = self.maxpool2d(x)
x = self.conv7(x)
x = self.conv8(x)
x = self.conv9(x)
x = self.conv10(x)
x = self.maxpool2d(x)
x = self.conv11(x)
x = self.conv12(x)
x = self.conv13(x)
x = self.conv14(x)
x = self.conv15(x)
x = self.conv16(x)
x = self.avgpool2d(x, (2, 3))
x = self.flatten(x)
return x

View File

@ -0,0 +1,165 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
#################train tinydarknet example on imagenet########################
python train.py
"""
import argparse
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.common import set_seed
from src.config import imagenet_cfg
from src.dataset import create_dataset_imagenet
from src.tinydarknet import TinyDarkNet
from src.CrossEntropySmooth import CrossEntropySmooth
set_seed(1)
def lr_steps_imagenet(_cfg, steps_per_epoch):
"""lr step for imagenet"""
from src.lr_scheduler.warmup_step_lr import warmup_step_lr
from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
if _cfg.lr_scheduler == 'exponential':
_lr = warmup_step_lr(_cfg.lr_init,
_cfg.lr_epochs,
steps_per_epoch,
_cfg.warmup_epochs,
_cfg.epoch_size,
gamma=_cfg.lr_gamma,
)
elif _cfg.lr_scheduler == 'cosine_annealing':
_lr = warmup_cosine_annealing_lr(_cfg.lr_init,
steps_per_epoch,
_cfg.warmup_epochs,
_cfg.epoch_size,
_cfg.T_max,
_cfg.eta_min)
else:
raise NotImplementedError(_cfg.lr_scheduler)
return _lr
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'],
help='dataset name.')
parser.add_argument('--device_id', type=int, default=0, help='device id of Ascend. (Default: 0)')
args_opt = parser.parse_args()
if args_opt.dataset_name == "imagenet":
cfg = imagenet_cfg
else:
raise ValueError("Unsupport dataset.")
# set context
device_target = cfg.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
device_num = cfg.device_num
rank = 0
if device_target == "Ascend":
if args_opt.device_id is not None:
context.set_context(device_id=args_opt.device_id)
else:
context.set_context(device_id=cfg.device_id)
if device_num > 1:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
rank = get_rank()
else:
raise ValueError("Unsupported platform.")
if args_opt.dataset_name == "imagenet":
dataset = create_dataset_imagenet(cfg.data_path, 1)
else:
raise ValueError("Unsupport dataset.")
batch_num = dataset.get_dataset_size()
net = TinyDarkNet(num_classes=cfg.num_classes)
# Continue training if set pre_trained to be True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
loss_scale_manager = None
if args_opt.dataset_name == 'imagenet':
lr = lr_steps_imagenet(cfg, batch_num)
def get_param_groups(network):
""" get param groups """
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all biases do not use weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# BatchNorm gamma (scale) parameters do not use weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# BatchNorm beta (shift) parameters do not use weight decay
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
if cfg.is_dynamic_loss_scale:
cfg.loss_scale = 1
opt = Momentum(params=get_param_groups(net),
learning_rate=Tensor(lr),
momentum=cfg.momentum,
weight_decay=cfg.weight_decay,
loss_scale=cfg.loss_scale)
if not cfg.use_label_smooth:
cfg.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
if cfg.is_dynamic_loss_scale:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O3", loss_scale_manager=loss_scale_manager)
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpt_save_dir = "./ckpt_" + str(rank) + "/"
ckpoint_cb = ModelCheckpoint(prefix="train_tinydarknet_" + args_opt.dataset_name, directory=ckpt_save_dir,
config=config_ck)
loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")