diff --git a/model_zoo/official/cv/tinydarknet/README_CN.md b/model_zoo/official/cv/tinydarknet/README_CN.md new file mode 100644 index 00000000000..66c61864d59 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/README_CN.md @@ -0,0 +1,275 @@ +# 目录 + +- [目录](#目录) +- [Tiny-DarkNet描述](#tiny-darknet描述) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [环境要求](#环境要求) +- [快速入门](#快速入门) +- [脚本描述](#脚本说明) + - [脚本及样例代码](#脚本及样例代码) + - [脚本参数](#脚本参数) + - [训练过程](#训练过程) + - [单机训练](#单机训练) + - [分布式训练](#分布式训练) + - [评估过程](#评估过程) + - [评估](#评估) +- [模型描述](#模型描述) + - [性能](#性能) + - [训练性能](#训练性能) + - [评估性能](#评估性能) +- [ModelZoo主页](#modelzoo主页) + +# [Tiny-DarkNet描述](#目录) + +Tiny-DarkNet是Joseph Chet Redmon等人提出的一个16层的针对于经典的图像分类数据集ImageNet所进行的图像分类网络模型。 Tiny-DarkNet作为作者为了满足用户对较小模型规模的需求而尽量降低模型的大小设计的简易版本的Darknet,具有优于AlexNet和SqueezeNet的图像分类能力,同时其只使用少于它们的模型参数。为了减少模型的规模,该Tiny-DarkNet网络没有使用全连接层,仅由卷积层、最大池化层、平均池化层组成。 + +更多Tiny-DarkNet详细信息可以参考[官方介绍](https://pjreddie.com/darknet/tiny-darknet/) + +# [模型架构](#目录) + +具体而言, Tiny-DarkNet网络由**1×1 conv**, **3×3 conv**, **2×2 max**和全局平均池化层组成,这些模块相互组成将输入的图片转换成一个**1×1000**的向量。 + +# [数据集](#目录) + +以下将介绍模型中使用数据集以及其出处: + + + + + + +所使用的数据集可参考[论文]() + +- 数据集规模:125G,1250k张分别属于1000个类的彩色图像 + - 训练集: 120G,1200k张图片 + - 测试集: 5G, 50k张图片 +- 数据格式: RGB格式图片 + - 注意: 数据将会被 src/dataset.py 中的函数进行处理 + + + + + +# [环境要求](#目录) + +- 硬件(Ascend) + - 请准备具有Ascend处理器的硬件环境.如果想使用Ascend资源,请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) 至ascend@huawei.com. 当收到许可即可使用Ascend资源. +- 框架 + - [MindSpore](https://www.mindspore.cn/install/en) +- 更多的信息请访问以下链接: + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) + +# [快速入门](#目录) + +根据官方网站成功安装MindSpore以后,可以按照以下步骤进行训练和测试模型: + +- 在Ascend资源上运行: + + ```python + # 单机训练 + python train.py > train.log 2>&1 & + + # 分布式训练 + bash scripts/run_train.sh rank_table.json + + # 评估 + python eval.py > eval.log 2>&1 & + OR + bash run_eval.sh + ``` + + 进行并行训练时, 需要提前创建JSON格式的hccl配置文件。 + + 请按照以下链接的指导进行设置: + + + +更多的细节请参考具体的script文件 + +# [脚本描述](#目录) + +## [脚本及样例代码](#目录) + +```bash + +├── Tiny-DarkNet + ├── README.md // Tiny-Darknet相关说明 + ├── scripts + │ ├──run_train.sh // Ascend分布式训练shell脚本 + │ ├──run_eval.sh // Ascend评估shell脚本 + ├── src + │ ├──dataset.py // 创建数据集 + │ ├──tinydarknet.py // Tiny-Darknet网络结构 + │ ├──config.py // 参数配置 + ├── train.py // 训练脚本 + ├── eval.py // 评估脚本 + ├── export.py // 导出checkpoint文件 + +``` + +## [脚本参数](#目录) + +训练和测试的参数可在 config.py 中进行设置 + +- Tiny-Darknet的配置文件 + + ```python + 'pre_trained': 'False' # 是否载入预训练模型 + 'num_classes': 1000 # 数据集中类的数量 + 'lr_init': 0.1 # 初始学习率 + 'batch_size': 128 # 训练的batch_size + 'epoch_size': 500 # 总共的训练epoch + 'momentum': 0.9 # 动量 + 'weight_decay': 1e-4 # 权重衰减率 + 'image_height': 224 # 输入图像的高度 + 'image_width': 224 # 输入图像的宽度 + 'data_path': './ImageNet_Original/train/' # 训练数据集的绝对路径 + 'val_data_path': './ImageNet_Original/val/' # 评估数据集的绝对路径 + 'device_target': 'Ascend' # 程序运行的设备 + 'device_id': 0 # 用来训练和评估的设备编号 + 'keep_checkpoint_max': 10 # 仅仅保持最新的keep_checkpoint_max个checkpoint文件 + 'checkpoint_path': './train_tinydarknet_imagenet-125_390.ckpt' # 保存checkpoint文件的绝对路径 + 'onnx_filename': 'tinydarknet.onnx' # 用于export.py 文件中的onnx模型的文件名 + 'air_filename': 'tinydarknet.air' # 用于export.py 文件中的air模型的文件名 + 'lr_scheduler': 'exponential' # 学习率策略 + 'lr_epochs': [70, 140, 210, 280] # 学习率进行变化的epoch数 + 'lr_gamma': 0.3 # lr_scheduler为exponential时的学习率衰减因子 + 'eta_min': 0.0 # cosine_annealing策略中的eta_min + 'T_max': 150 # cosine_annealing策略中的T-max + 'warmup_epochs': 0 # 热启动的epoch数 + 'is_dynamic_loss_scale': 0 # 动态损失尺度 + 'loss_scale': 1024 # 损失尺度 + 'label_smooth_factor': 0.1 # 训练标签平滑因子 + 'use_label_smooth': True # 是否采用训练标签平滑 + ``` + +更多的细节, 请参考`config.py`. + +## [训练过程](#目录) + +### [单机训练](#目录) + +- 在Ascend资源上运行: + + ```python + python train.py > train.log 2>&1 & + ``` + + 上述的python命令将运行在后台中,可以通过 `train.log` 文件查看运行结果. + + 训练完成后,默认情况下,可在script文件夹下得到一些checkpoint文件. 训练的损失值将以如下的形式展示: + + + ```python + # grep "loss is " train.log + epoch: 498 step: 1251, loss is 2.7798953 + Epoch time: 130690.544, per step time: 104.469 + epoch: 499 step: 1251, loss is 2.9261637 + Epoch time: 130511.081, per step time: 104.325 + epoch: 500 step: 1251, loss is 2.69412 + Epoch time: 127067.548, per step time: 101.573 + ... + ``` + + 模型checkpoint文件将会保存在当前文件夹下. + + +### [分布式训练](#目录) + +- 在Ascend资源上运行: + + ```python + sh scripts/run_train.sh + ``` + + 上述的脚本命令将在后台中进行分布式训练,可以通过`train_parallel[X]/log`文件查看运行结果. 训练的损失值将以如下的形式展示: + + ```python + # grep "result: " train_parallel*/log + epoch: 498 step: 1251, loss is 2.7798953 + Epoch time: 130690.544, per step time: 104.469 + epoch: 499 step: 1251, loss is 2.9261637 + Epoch time: 130511.081, per step time: 104.325 + epoch: 500 step: 1251, loss is 2.69412 + Epoch time: 127067.548, per step time: 101.573 + ... + ``` + +## [评估过程](#目录) + +### [评估](#目录) + +- 在Ascend资源上进行评估: + + 在运行如下命令前,请确认用于评估的checkpoint文件的路径.请将checkpoint路径设置为绝对路径,例如:"username/imagenet/train_tiny-darknet_imagenet-125_390.ckpt". + + ```python + python eval.py > eval.log 2>&1 & + OR + sh scripts/run_eval.sh + ``` + + 上述的python命令将运行在后台中,可以通过"eval.log"文件查看结果. 测试数据集的准确率将如下面所列: + + ```python + # grep "accuracy: " eval.log + accuracy: {'top_1_accuracy': 0.5871979166666667, 'top_5_accuracy': 0.8175280448717949} + ``` + + 请注意在并行训练后,测试请将checkpoint_path设置为最后保存的checkpoint文件的路径,准确率将如下面所列: + + ```python + # grep "accuracy: " eval.log + accuracy: {'top_1_accuracy': 0.5871979166666667, 'top_5_accuracy': 0.8175280448717949} + ``` + +# [模型描述](#目录) + +## [性能](#目录) + +### [训练性能](#目录) + +| 参数 | Ascend | +| -------------------------- | ----------------------------------------------------------- | +| 模型版本 | V1 | +| 资源 | Ascend 910, CPU 2.60GHz, 56cores, Memory 314G | +| 上传日期 | 2020/12/22 | +| MindSpore版本 | 1.1.0 | +| 数据集 | 1200k张图片 | +| 训练参数 | epoch=500, steps=1251, batch_size=128, lr=0.1 | +| 优化器 | Momentum | +| 损失函数 | Softmax Cross Entropy | +| 速度 | 8卡: 104 ms/step | +| 总时间 | 8卡: 17.8小时 | +| 参数(M) | 4.0 | +| 脚本 | [Tiny-Darknet脚本](https://gitee.com/mindspore/mindspore/tree/r0.7/model_zoo/official/cv/googlenet) | + +### [评估性能](#目录) + +| 参数 | Ascend | +| ------------------- | --------------------------- | +| 模型版本 | V1 | +| 资源 | Ascend 910 | +| 上传日期 | 2020/12/22 | +| MindSpore版本 | 1.1.0 | +| 数据集 | 200k张图片 | +| batch_size | 128 | +| 输出 | 分类概率 | +| 准确率 | 8卡 Top-5: 81.7% | +| 推理模型 | 11.6M (.ckpt文件) | + +# [ModelZoo主页](#目录) + + 请参考官方[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/tinydarknet/eval.py b/model_zoo/official/cv/tinydarknet/eval.py new file mode 100644 index 00000000000..b9ea91e4136 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/eval.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +##############test tinydarknet example on cifar10################# +python eval.py +""" +import argparse + + +from mindspore import context + +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed + +from src.config import imagenet_cfg +from src.dataset import create_dataset_imagenet + +from src.tinydarknet import TinyDarkNet +from src.CrossEntropySmooth import CrossEntropySmooth + +set_seed(1) + +parser = argparse.ArgumentParser(description='tinydarknet') +parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'], + help='dataset name.') +parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +args_opt = parser.parse_args() + +if __name__ == '__main__': + + if args_opt.dataset_name == "imagenet": + cfg = imagenet_cfg + dataset = create_dataset_imagenet(cfg.val_data_path, 1, False) + if not cfg.use_label_smooth: + cfg.label_smooth_factor = 0.0 + loss = CrossEntropySmooth(sparse=True, reduction="mean", + smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) + net = TinyDarkNet(num_classes=cfg.num_classes) + model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'}) + + else: + raise ValueError("dataset is not support.") + + device_target = cfg.device_target + context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) + if device_target == "Ascend": + context.set_context(device_id=cfg.device_id) + + if args_opt.checkpoint_path is not None: + param_dict = load_checkpoint(args_opt.checkpoint_path) + print("load checkpoint from [{}].".format(args_opt.checkpoint_path)) + else: + param_dict = load_checkpoint(cfg.checkpoint_path) + print("load checkpoint from [{}].".format(cfg.checkpoint_path)) + + load_param_into_net(net, param_dict) + net.set_train(False) + + acc = model.eval(dataset) + print("accuracy: ", acc) diff --git a/model_zoo/official/cv/tinydarknet/export.py b/model_zoo/official/cv/tinydarknet/export.py new file mode 100644 index 00000000000..4dd8545985d --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/export.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +##############export checkpoint file into air and onnx models################# +python export.py +""" +import argparse +import numpy as np + +import mindspore as ms +from mindspore import Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export + +from src.config import imagenet_cfg +from src.tinydarknet import TinydarkNet + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Classification') + parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'], + help='dataset name.') + args_opt = parser.parse_args() + + if args_opt.dataset_name == 'imagenet': + cfg = imagenet_cfg + else: + raise ValueError("dataset is not support.") + + net = TinydarkNet(num_classes=cfg.num_classes) + + assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None." + param_dict = load_checkpoint(cfg.checkpoint_path) + load_param_into_net(net, param_dict) + + input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]), ms.float32) + export(net, input_arr, file_name=cfg.onnx_filename, file_format="ONNX") + export(net, input_arr, file_name=cfg.air_filename, file_format="AIR") diff --git a/model_zoo/official/cv/tinydarknet/mindspore_hub_conf.py b/model_zoo/official/cv/tinydarknet/mindspore_hub_conf.py new file mode 100644 index 00000000000..6238a0ed093 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/mindspore_hub_conf.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.tinydarknet import TinyDarkNet + +def tinydarknet(*args, **kwargs): + return TinyDarkNet(*args, **kwargs) + + +def create_network(name, *args, **kwargs): + if name == "tinydarknet": + return tinydarknet(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/tinydarknet/scripts/run_eval.sh b/model_zoo/official/cv/tinydarknet/scripts/run_eval.sh new file mode 100644 index 00000000000..1e58fee4ab4 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/scripts/run_eval.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +python ../eval.py > ./eval.log 2>&1 & diff --git a/model_zoo/official/cv/tinydarknet/scripts/run_train.sh b/model_zoo/official/cv/tinydarknet/scripts/run_train.sh new file mode 100644 index 00000000000..300ec840db3 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/scripts/run_train.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "$1 $2" + +if [ $# != 1 ] && [ $# != 2 ] +then + echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]" +exit 1 +fi + +if [ ! -f $1 ] +then + echo "error:RANK_TABLE_FILE=$1 is not a file" +exit 1 +fi + + +dataset_type='cifar10' +if [ $# == 2 ] +then + if [ $2 != "cifar10" ] && [ $2 != "imagenet" ] + then + echo "error: the selected dataset is neither cifar10 nor imagenet" + exit 1 + fi + dataset_type=$2 +fi + + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +RANK_TABLE_FILE=$(realpath $1) +export RANK_TABLE_FILE +echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}" + +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) +for((i=0; i<${SERVER_NUM}; i++)) +do + export DEVICE_ID=$i + export RANK_ID=$((rank_start + i)) + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp -r ../src ./train_parallel$i + cp ../train.py ./train_parallel$i + echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type" + cd .train_parallel$i || exit + env > env.log + python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 & + cd .. +done diff --git a/model_zoo/official/cv/tinydarknet/src/CrossEntropySmooth.py b/model_zoo/official/cv/tinydarknet/src/CrossEntropySmooth.py new file mode 100644 index 00000000000..55d5d8b8082 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/src/CrossEntropySmooth.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""define loss function for network""" +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.nn.loss.loss import _Loss +from mindspore.ops import functional as F +from mindspore.ops import operations as P + + +class CrossEntropySmooth(_Loss): + """CrossEntropy""" + def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000): + super(CrossEntropySmooth, self).__init__() + self.onehot = P.OneHot() + self.sparse = sparse + self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) + self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) + self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction) + + def construct(self, logit, label): + if self.sparse: + label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) + loss = self.ce(logit, label) + return loss diff --git a/model_zoo/official/cv/tinydarknet/src/config.py b/model_zoo/official/cv/tinydarknet/src/config.py new file mode 100644 index 00000000000..901c4d8dba0 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/src/config.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +network config setting, will be used in main.py +""" +from easydict import EasyDict as edict + +imagenet_cfg = edict({ + 'name': 'imagenet', + 'pre_trained': False, + 'num_classes': 1000, + 'lr_init': 0.1, + 'batch_size': 128, + 'epoch_size': 500, + 'momentum': 0.9, + 'weight_decay': 1e-4, + 'image_height': 224, + 'image_width': 224, + 'data_path': './dataset/imagenet_original/train/', + 'val_data_path': './dataset/imagenet_original/val/', + 'device_target': 'Ascend', + 'device_id': 0, + 'device_num': 8, + 'keep_checkpoint_max': 1, + 'checkpoint_path': './scripts/train_parallel4/ckpt_4/train_tinydarknet_imagenet-300_1251.ckpt', + 'onnx_filename': 'tinydarknet.onnx', + 'air_filename': 'tinydarknet.air', + + # optimizer and lr related + 'lr_scheduler': 'exponential', + 'lr_epochs': [70, 140, 210, 280], + 'lr_gamma': 0.3, + 'eta_min': 0.0, + 'T_max': 150, + 'warmup_epochs': 0, + + # loss related + 'is_dynamic_loss_scale': False, + 'loss_scale': 1024, + 'label_smooth_factor': 0.1, + 'use_label_smooth': True, +}) diff --git a/model_zoo/official/cv/tinydarknet/src/dataset.py b/model_zoo/official/cv/tinydarknet/src/dataset.py new file mode 100644 index 00000000000..75c1ae30765 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/src/dataset.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Data operations, will be used in train.py and eval.py +""" +import os + +import mindspore.common.dtype as mstype +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as C +import mindspore.dataset.vision.c_transforms as vision +from src.config import imagenet_cfg + +def create_dataset_imagenet(dataset_path, repeat_num=1, training=True, + num_parallel_workers=None, shuffle=None): + """ + create a train or eval imagenet2012 dataset for resnet50 + + Args: + dataset_path(string): the path of dataset. + do_train(bool): whether dataset is used for train or eval. + repeat_num(int): the repeat times of dataset. Default: 1 + batch_size(int): the batch size of dataset. Default: 32 + target(str): the device target. Default: Ascend + + Returns: + dataset + """ + + device_num, rank_id = _get_rank_info() + + if device_num == 1: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle) + else: + data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle, + num_shards=device_num, shard_id=rank_id) + + assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width" + image_size = imagenet_cfg.image_height + mean = [0.485 * 255, 0.456 * 255, 0.406 * 255] + std = [0.229 * 255, 0.224 * 255, 0.225 * 255] + + # define map operations + if training: + transform_img = [ + vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), + vision.RandomHorizontalFlip(prob=0.5), + vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + else: + transform_img = [ + vision.Decode(), + vision.Resize(256), + vision.CenterCrop(image_size), + vision.Normalize(mean=mean, std=std), + vision.HWC2CHW() + ] + + transform_label = [C.TypeCast(mstype.int32)] + + data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img) + data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label) + + # apply batch operations + data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True) + + # apply dataset repeat operation + data_set = data_set.repeat(repeat_num) + + return data_set + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + from mindspore.communication.management import get_rank, get_group_size + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = rank_id = None + + return rank_size, rank_id diff --git a/model_zoo/official/cv/tinydarknet/src/lr_scheduler/__init__.py b/model_zoo/official/cv/tinydarknet/src/lr_scheduler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model_zoo/official/cv/tinydarknet/src/lr_scheduler/linear_warmup.py b/model_zoo/official/cv/tinydarknet/src/lr_scheduler/linear_warmup.py new file mode 100644 index 00000000000..78e7b85f6dd --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/src/lr_scheduler/linear_warmup.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr""" + +def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr): + lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) + lr = float(init_lr) + lr_inc * current_step + return lr diff --git a/model_zoo/official/cv/tinydarknet/src/lr_scheduler/warmup_cosine_annealing_lr.py b/model_zoo/official/cv/tinydarknet/src/lr_scheduler/warmup_cosine_annealing_lr.py new file mode 100644 index 00000000000..349270b6e11 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/src/lr_scheduler/warmup_cosine_annealing_lr.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr""" + +import math +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0): + """ warmup cosine annealing lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + + lr_each_step = [] + for i in range(total_steps): + last_epoch = i // steps_per_epoch + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2 + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) diff --git a/model_zoo/official/cv/tinydarknet/src/lr_scheduler/warmup_step_lr.py b/model_zoo/official/cv/tinydarknet/src/lr_scheduler/warmup_step_lr.py new file mode 100644 index 00000000000..df78f17d795 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/src/lr_scheduler/warmup_step_lr.py @@ -0,0 +1,59 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""lr""" + +from collections import Counter +import numpy as np + +from .linear_warmup import linear_warmup_lr + + +def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1): + """warmup step lr""" + base_lr = lr + warmup_init_lr = 0 + total_steps = int(max_epoch * steps_per_epoch) + warmup_steps = int(warmup_epochs * steps_per_epoch) + milestones = lr_epochs + milestones_steps = [] + for milestone in milestones: + milestones_step = milestone * steps_per_epoch + milestones_steps.append(milestones_step) + + lr_each_step = [] + lr = base_lr + milestones_steps_counter = Counter(milestones_steps) + for i in range(total_steps): + if i < warmup_steps: + lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr) + else: + lr = lr * gamma ** milestones_steps_counter[i] + lr_each_step.append(lr) + + return np.array(lr_each_step).astype(np.float32) + + +def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1): + """lr""" + return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma) + + +def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1): + """lr""" + lr_epochs = [] + for i in range(1, max_epoch): + if i % epoch_size == 0: + lr_epochs.append(i) + return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma) diff --git a/model_zoo/official/cv/tinydarknet/src/tinydarknet.py b/model_zoo/official/cv/tinydarknet/src/tinydarknet.py new file mode 100644 index 00000000000..a6d740e6f9d --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/src/tinydarknet.py @@ -0,0 +1,142 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TinydarkNet""" +import mindspore.nn as nn +from mindspore.common.initializer import TruncatedNormal +from mindspore.ops import operations as P + + +def weight_variable(): + """Weight variable.""" + return TruncatedNormal(0.02) + + +class Conv1x1dBlock(nn.Cell): + """ + Basic convolutional block + Args: + in_channles (int): Input channel. + out_channels (int): Output channel. + kernel_size (int): Input kernel size. Default: 1 + stride (int): Stride size for the first convolutional layer. Default: 1. + padding (int): Implicit paddings on both sides of the input. Default: 0. + pad_mode (str): Padding mode. Optional values are "same", "valid", "pad". Default: "same". + Returns: + Tensor, output tensor. + """ + + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode="same"): + super(Conv1x1dBlock, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_variable()) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + self.leakyrelu = nn.LeakyReLU() + + def construct(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.leakyrelu(x) + return x + +class Conv3x3dBlock(nn.Cell): + """ + Basic convolutional block + Args: + in_channles (int): Input channel. + out_channels (int): Output channel. + kernel_size (int): Input kernel size. Default: 1 + stride (int): Stride size for the first convolutional layer. Default: 1. + padding (int): Implicit paddings on both sides of the input. Default: 0. + pad_mode (str): Padding mode. Optional values are "same", "valid", "pad". Default: "same". + Returns: + Tensor, output tensor. + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, pad_mode="pad"): + super(Conv3x3dBlock, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, pad_mode=pad_mode, weight_init=weight_variable()) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + self.leakyrelu = nn.LeakyReLU() + + def construct(self, x): + x = self.conv(x) + x = self.bn(x) + x = self.leakyrelu(x) + return x + + + +class TinyDarkNet(nn.Cell): + """ + Tinydarknet architecture + """ + + def __init__(self, num_classes, include_top=True): + super(TinyDarkNet, self).__init__() + self.conv1 = Conv3x3dBlock(3, 16) + self.conv2 = Conv3x3dBlock(16, 32) + self.conv3 = Conv1x1dBlock(32, 16) + self.conv4 = Conv3x3dBlock(16, 128) + self.conv5 = Conv1x1dBlock(128, 16) + self.conv6 = Conv3x3dBlock(16, 128) + self.conv7 = Conv1x1dBlock(128, 32) + self.conv8 = Conv3x3dBlock(32, 256) + self.conv9 = Conv1x1dBlock(256, 32) + self.conv10 = Conv3x3dBlock(32, 256) + self.conv11 = Conv1x1dBlock(256, 64) + self.conv12 = Conv3x3dBlock(64, 512) + self.conv13 = Conv1x1dBlock(512, 64) + self.conv14 = Conv3x3dBlock(64, 512) + self.conv15 = Conv1x1dBlock(512, 128) + self.conv16 = Conv1x1dBlock(128, 1000) + + self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="same") + self.avgpool2d = P.ReduceMean(keep_dims=True) + + self.flatten = nn.Flatten() + + + def construct(self, x): + """construct""" + x = self.conv1(x) + x = self.maxpool2d(x) + + x = self.conv2(x) + x = self.maxpool2d(x) + + x = self.conv3(x) + x = self.conv4(x) + x = self.conv5(x) + x = self.conv6(x) + x = self.maxpool2d(x) + + x = self.conv7(x) + x = self.conv8(x) + x = self.conv9(x) + x = self.conv10(x) + x = self.maxpool2d(x) + + x = self.conv11(x) + x = self.conv12(x) + x = self.conv13(x) + x = self.conv14(x) + x = self.conv15(x) + x = self.conv16(x) + + x = self.avgpool2d(x, (2, 3)) + x = self.flatten(x) + + return x diff --git a/model_zoo/official/cv/tinydarknet/train.py b/model_zoo/official/cv/tinydarknet/train.py new file mode 100644 index 00000000000..dee1865cfe0 --- /dev/null +++ b/model_zoo/official/cv/tinydarknet/train.py @@ -0,0 +1,165 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +#################train tinydarknet example on cifar10######################## +python train.py +""" +import argparse + +from mindspore import Tensor +from mindspore import context +from mindspore.communication.management import init, get_rank +from mindspore.nn.optim.momentum import Momentum +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor +from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager +from mindspore.train.model import Model +from mindspore.context import ParallelMode +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.common import set_seed + +from src.config import imagenet_cfg +from src.dataset import create_dataset_imagenet +from src.tinydarknet import TinyDarkNet +from src.CrossEntropySmooth import CrossEntropySmooth + +set_seed(1) + + +def lr_steps_imagenet(_cfg, steps_per_epoch): + """lr step for imagenet""" + from src.lr_scheduler.warmup_step_lr import warmup_step_lr + from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr + if _cfg.lr_scheduler == 'exponential': + _lr = warmup_step_lr(_cfg.lr_init, + _cfg.lr_epochs, + steps_per_epoch, + _cfg.warmup_epochs, + _cfg.epoch_size, + gamma=_cfg.lr_gamma, + ) + elif _cfg.lr_scheduler == 'cosine_annealing': + _lr = warmup_cosine_annealing_lr(_cfg.lr_init, + steps_per_epoch, + _cfg.warmup_epochs, + _cfg.epoch_size, + _cfg.T_max, + _cfg.eta_min) + else: + raise NotImplementedError(_cfg.lr_scheduler) + + return _lr + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Classification') + parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet', 'cifar10'], + help='dataset name.') + parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)') + args_opt = parser.parse_args() + + if args_opt.dataset_name == "imagenet": + cfg = imagenet_cfg + else: + raise ValueError("Unsupport dataset.") + + # set context + device_target = cfg.device_target + + context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) + device_num = cfg.device_num + + rank = 0 + if device_target == "Ascend": + if args_opt.device_id is not None: + context.set_context(device_id=args_opt.device_id) + else: + context.set_context(device_id=cfg.device_id) + + if device_num > 1: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True) + init() + rank = get_rank() + else: + raise ValueError("Unsupported platform.") + + if args_opt.dataset_name == "imagenet": + dataset = create_dataset_imagenet(cfg.data_path, 1) + else: + raise ValueError("Unsupport dataset.") + + batch_num = dataset.get_dataset_size() + + net = TinyDarkNet(num_classes=cfg.num_classes) + # Continue training if set pre_trained to be True + if cfg.pre_trained: + param_dict = load_checkpoint(cfg.checkpoint_path) + load_param_into_net(net, param_dict) + + loss_scale_manager = None + if args_opt.dataset_name == 'imagenet': + lr = lr_steps_imagenet(cfg, batch_num) + + def get_param_groups(network): + """ get param groups """ + decay_params = [] + no_decay_params = [] + for x in network.trainable_params(): + parameter_name = x.name + if parameter_name.endswith('.bias'): + # all bias not using weight decay + no_decay_params.append(x) + elif parameter_name.endswith('.gamma'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + elif parameter_name.endswith('.beta'): + # bn weight bias not using weight decay, be carefully for now x not include BN + no_decay_params.append(x) + else: + decay_params.append(x) + + return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}] + + + if cfg.is_dynamic_loss_scale: + cfg.loss_scale = 1 + + opt = Momentum(params=get_param_groups(net), + learning_rate=Tensor(lr), + momentum=cfg.momentum, + weight_decay=cfg.weight_decay, + loss_scale=cfg.loss_scale) + if not cfg.use_label_smooth: + cfg.label_smooth_factor = 0.0 + loss = CrossEntropySmooth(sparse=True, reduction="mean", + smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) + + if cfg.is_dynamic_loss_scale: + loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000) + else: + loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False) + + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, + amp_level="O3", loss_scale_manager=loss_scale_manager) + + config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 50, keep_checkpoint_max=cfg.keep_checkpoint_max) + time_cb = TimeMonitor(data_size=batch_num) + ckpt_save_dir = "./ckpt_" + str(rank) + "/" + ckpoint_cb = ModelCheckpoint(prefix="train_tinydarknet_" + args_opt.dataset_name, directory=ckpt_save_dir, + config=config_ck) + loss_cb = LossMonitor() + model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) + print("train success")