!18929 New model Ibnnet with GPU version
Merge pull request !18929 from 水扬波/ibnnet_gpu
commit 315bbf12ab
@ -0,0 +1,285 @@
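# Contents

- [Contents](#contents)
- [IBN-Net Overview](#ibn-net-overview)
- [IBN-Net Example](#ibn-net-example)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Scripts and Sample Code](#scripts-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Pretrained Model](#pretrained-model)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
    - [Usage](#usage)
        - [Inference](#inference)
        - [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->
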
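# IBN-Net Overview

Convolutional neural networks (CNNs) have achieved great success in many computer vision problems. Unlike existing work on designing CNN architectures, the paper proposes a new convolutional architecture, IBN-Net, which markedly improves a CNN's modeling ability in one domain (e.g. Cityscapes) as well as its generalization to another domain (e.g. GTA5) without fine-tuning. IBN-Net integrates Instance Normalization (IN) and Batch Normalization (BN) as building blocks and can be wrapped into many advanced deep networks to improve their performance. The work makes three key contributions. (1) By studying IN and BN in depth, the authors find that IN learns features that are invariant to appearance changes such as color, style, and virtuality/reality, while BN is essential for preserving content-related information. (2) IBN-Net can be applied to many advanced deep architectures, such as DenseNet, ResNet, ResNeXt, and SENet, and consistently improves their performance without increasing computational cost. (3) When a trained network is applied to a new domain, e.g. from GTA5 to Cityscapes, IBN-Net achieves improvements comparable to domain adaptation methods, even without using data from the target domain.

[Paper](https://arxiv.org/abs/1807.09441): Pan X, Luo P, Shi J, et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net[C]// European Conference on Computer Vision. Springer, Cham, 2018.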
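
The way IN and BN are combined in one block can be illustrated with a small, dependency-free sketch. It assumes the IBN-a style layer described in the paper, in which the first half of the channels is instance-normalized and the remaining channels are batch-normalized. The names `ibn_forward` and `_normalize`, the `ratio` argument, and the flattened `[N][C][P]` input layout are illustrative assumptions and are not taken from `src/resnet_ibn.py`; learnable scale and shift parameters are left out.

```python
from statistics import fmean, pvariance


def _normalize(values, eps=1e-5):
    """Shift a flat list of floats to zero mean and scale it to unit variance."""
    mean = fmean(values)
    var = pvariance(values, mu=mean)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


def ibn_forward(x, ratio=0.5):
    """Apply IN to the first `ratio` of the channels and BN to the rest.

    x is a batch shaped [N][C][P], where P = H * W flattened pixels.
    """
    n, c = len(x), len(x[0])
    half = int(c * ratio)
    out = [[None] * c for _ in range(n)]

    # Instance norm: statistics are computed per sample and per channel.
    for i in range(n):
        for ch in range(half):
            out[i][ch] = _normalize(x[i][ch])

    # Batch norm (training-mode statistics): one mean/variance per channel,
    # pooled over every sample in the batch.
    for ch in range(half, c):
        pixels = len(x[0][ch])
        pooled = [v for i in range(n) for v in x[i][ch]]
        normed = _normalize(pooled)
        for i in range(n):
            out[i][ch] = normed[i * pixels:(i + 1) * pixels]
    return out


# Two samples, two channels, two pixels each.
x = [[[1.0, 3.0], [10.0, 20.0]],
     [[100.0, 300.0], [30.0, 40.0]]]
y = ibn_forward(x)
```

In this toy batch, channel 0 of both samples ends up with the same normalized values even though their raw intensities differ by a factor of 100, which is the appearance invariance attributed to IN, while channel 1 keeps the relative ordering across samples because BN shares its statistics over the batch.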

# IBN-Net Example

# Dataset

Dataset used: [ImageNet2012](http://www.image-net.org/)

- Training set: 1,281,167 images with labels
- Validation set: 50,000 images with labels
- Test set: 100,000 images

# Environment Requirements

- Hardware: Ascend/GPU
    - Set up the hardware environment with Ascend or GPU processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For details, see the following resources:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)

# Quick Start

After installing MindSpore from the official website, you can train and evaluate the model as follows:

```shell
# Distributed training example
sh scripts/run_distribute_train.sh /path/dataset /path/evalset pretrained_model.ckpt rank_size

# Standalone training example
sh scripts/run_standalone_train.sh /path/dataset /path/evalset pretrained_model.ckpt

# Evaluation example
sh scripts/run_eval.sh /path/evalset /path/ckpt
```

# Script Description

## Scripts and Sample Code

```path
└── IBNNet
    ├── README.md                          // description of IBNNet
    ├── scripts
    │   ├── run_distribute_train.sh        // shell script for distributed training
    │   ├── run_distribute_train_gpu.sh    // shell script for distributed training on GPU
    │   ├── run_standalone_train.sh        // shell script for standalone training
    │   ├── run_standalone_train_gpu.sh    // shell script for standalone training on GPU
    │   ├── run_eval.sh                    // shell script for evaluation
    │   └── run_eval_gpu.sh                // shell script for evaluation on GPU
    ├── src
    │   ├── loss.py                        // loss function
    │   ├── lr_generator.py                // learning rate generation
    │   ├── config.py                      // parameter configuration
    │   ├── dataset.py                     // dataset creation
    │   └── resnet_ibn.py                  // IBNNet architecture
    ├── utils
    │   └── pth2ckpt.py                    // converts a pth file to a ckpt file
    ├── export.py
    ├── eval.py                            // evaluation script
    └── train.py                           // training script
```

## Script Parameters

```text
The main parameters in train.py and config.py are as follows:

--use_modelarts: whether to train on the ModelArts platform. Allowed values: True, False.
--device_id: ID of the device used for training or evaluation. Ignored when run_distribute_train.sh is used for distributed training.
--device_num: number of devices used for distributed training with run_distribute_train.sh.
--train_url: output path for checkpoints.
--data_url: path to the training set.
--ckpt_url: path to the checkpoint.
--eval_url: path to the validation set.
```
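
The options listed above could be declared with `argparse` roughly as follows. This is a minimal sketch, not the parser in `train.py`: the defaults, the `required` flags, and the `str2bool` helper are assumptions made for illustration. The helper is there because `type=bool` in `argparse` would treat the string `"False"` as true.

```python
import argparse


def str2bool(text):
    """Parse True/False style strings explicitly (hypothetical helper)."""
    lowered = text.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected True or False, got %r" % text)


def build_parser():
    """Build a parser for the options named in the list above (defaults are assumed)."""
    parser = argparse.ArgumentParser(description="IBN-Net training options (illustrative)")
    parser.add_argument("--use_modelarts", type=str2bool, default=False,
                        help="whether to train on the ModelArts platform")
    parser.add_argument("--device_id", type=int, default=0,
                        help="device ID used for training or evaluation")
    parser.add_argument("--device_num", type=int, default=1,
                        help="number of devices for distributed training")
    parser.add_argument("--train_url", type=str, default="./",
                        help="output path for checkpoints")
    parser.add_argument("--data_url", type=str, required=True,
                        help="path to the training set")
    parser.add_argument("--ckpt_url", type=str, default="",
                        help="path to the checkpoint")
    parser.add_argument("--eval_url", type=str, required=True,
                        help="path to the validation set")
    return parser


example_args = build_parser().parse_args(
    ["--data_url", "/path/dataset", "--eval_url", "/path/evalset",
     "--use_modelarts", "False", "--device_num", "8"])
```

Parsing an explicit list, as in the last statement, makes the sketch easy to try without touching `sys.argv`.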

## Pretrained Model

Use utils/pth2ckpt.py to convert a pretrained pth file into a ckpt file.

The pretrained pth model can be downloaded here: [pretrained model](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth)
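
Converting a PyTorch checkpoint mostly comes down to renaming parameters, because MindSpore normalization layers call their parameters `gamma`, `beta`, `moving_mean`, and `moving_variance` where PyTorch uses `weight`, `bias`, `running_mean`, and `running_var`. The sketch below shows only that renaming step on plain strings. It is not the content of `utils/pth2ckpt.py`, which is not reproduced here; the helpers `is_norm_key` and `rename_key`, and the heuristic that normalization layers are named `bn*` or sit at index 1 of a `downsample` block, are assumptions about the layer naming.

```python
# PyTorch suffix -> MindSpore suffix for normalization layers.
NORM_SUFFIXES = {
    "weight": "gamma",
    "bias": "beta",
    "running_mean": "moving_mean",
    "running_var": "moving_variance",
}


def is_norm_key(name):
    """Heuristic (assumption): norm layers are named bn*, or are downsample.1."""
    parts = name.split(".")
    return any(part.startswith("bn") for part in parts) or "downsample.1." in name


def rename_key(name):
    """Map one PyTorch parameter name to a MindSpore-style name.

    Returns None for entries that have no MindSpore counterpart.
    """
    if name.endswith("num_batches_tracked"):
        return None  # bookkeeping counter kept by PyTorch BatchNorm only
    prefix, _, suffix = name.rpartition(".")
    if is_norm_key(name) and suffix in NORM_SUFFIXES:
        return prefix + "." + NORM_SUFFIXES[suffix]
    return name  # convolution and fully connected weights keep their names


pth_keys = [
    "conv1.weight",
    "layer1.0.bn1.IN.weight",
    "layer1.0.bn3.running_var",
    "layer1.0.downsample.1.bias",
    "layer1.0.bn3.num_batches_tracked",
]
renamed = [rename_key(key) for key in pth_keys]
```

A full converter would additionally load the tensors with PyTorch and write them out with MindSpore's checkpoint API, which is left out here so the sketch stays free of framework dependencies.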

## Training Process

### Training

- Training on Ascend

```shell
sh scripts/run_standalone_train.sh /path/dataset /path/evalset pretrained_model.ckpt
```

- Training on GPU

```shell
sh scripts/run_standalone_train_gpu.sh /path/dataset /path/evalset pretrained_model.ckpt
```

### Distributed Training

- Training on Ascend

```shell
sh scripts/run_distribute_train.sh /path/dataset /path/evalset pretrained_model.ckpt rank_size
```

The shell script above runs distributed training in the background. You can check the results in the `device[X]/train.log[X]` files.
The loss values are reported as follows:

```log
epoch: 12 step: 2502, loss is 1.7709649
epoch time: 331584.555 ms, per step time: 132.528 ms
epoch: 12 step: 2502, loss is 1.2770984
epoch time: 331503.971 ms, per step time: 132.496 ms
...
epoch: 82 step: 2502, loss is 0.98658705
epoch time: 331877.856 ms, per step time: 132.645 ms
epoch: 82 step: 2502, loss is 0.82476664
epoch time: 331689.239 ms, per step time: 132.570 ms
```
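
Log lines in this format are easy to post-process. The following sketch extracts the loss and the per-step time with regular expressions and turns them into a rough wall-clock estimate. The function `parse_log`, the two patterns, and the figure of 100 epochs (taken from the `--epochs 100` flag of the standalone scripts and assumed to apply here as well) are illustrative choices, not part of the repository.

```python
import re

LOSS_RE = re.compile(r"epoch:\s*(\d+)\s+step:\s*(\d+),\s*loss is\s*([0-9.eE+-]+)")
TIME_RE = re.compile(r"epoch time:\s*([0-9.]+)\s*ms,\s*per step time:\s*([0-9.]+)\s*ms")


def parse_log(text):
    """Return [(epoch, step, loss), ...] and [per_step_ms, ...] found in a log."""
    losses, step_times = [], []
    for line in text.splitlines():
        match = LOSS_RE.search(line)
        if match:
            losses.append((int(match.group(1)), int(match.group(2)), float(match.group(3))))
            continue
        match = TIME_RE.search(line)
        if match:
            step_times.append(float(match.group(2)))
    return losses, step_times


SAMPLE = """\
epoch: 12 step: 2502, loss is 1.7709649
epoch time: 331584.555 ms, per step time: 132.528 ms
epoch: 82 step: 2502, loss is 0.82476664
epoch time: 331689.239 ms, per step time: 132.570 ms
"""

losses, step_times = parse_log(SAMPLE)
mean_step_ms = sum(step_times) / len(step_times)
steps_per_epoch = losses[-1][1]

# Rough estimate, assuming 100 epochs at the measured step time.
estimated_hours = steps_per_epoch * mean_step_ms * 100 / 3_600_000
print(f"mean step: {mean_step_ms:.3f} ms, about {estimated_hours:.1f} h for 100 epochs")
```

The numbers are self-consistent: 2502 steps at about 132.5 ms each account for the logged epoch time of roughly 331.6 s, and 100 such epochs come to a little over 9 hours, in line with the 8-device total time reported in the performance table.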

- Training on GPU

```shell
sh scripts/run_distribute_train_gpu.sh /path/dataset /path/evalset pretrained_model.ckpt rank_size
```
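
In distributed runs, `src/dataset.py` passes `num_shards=device_num` and `shard_id=rank_id` to `ImageFolderDataset`, so each device reads only its own slice of the training set. The sketch below illustrates the idea with a simple strided split over sample indices. The function `shard_indices` is hypothetical, and the exact assignment of samples to shards inside MindSpore may differ (for example because of shuffling).

```python
def shard_indices(num_samples, num_shards, shard_id):
    """Indices of the samples one rank reads (strided split, illustrative only)."""
    if not 0 <= shard_id < num_shards:
        raise ValueError("shard_id must be in [0, num_shards)")
    return list(range(shard_id, num_samples, num_shards))


# Ten samples split across four ranks.
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, indices in enumerate(shards):
    print(f"rank {rank}: {indices}")
```

Every sample lands in exactly one shard, so the ranks together cover the dataset once per epoch without overlap.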

## Evaluation Process

### Evaluation

- Evaluating on the ImageNet dataset on Ascend

```bash
sh scripts/run_eval.sh path/evalset path/ckpt
```

The command above runs in the background; check the results in eval.log. The accuracy on the evaluation dataset is as follows:

```bash
{'Accuracy': 0.7785483870967742}
```

- Evaluating on the ImageNet dataset on GPU

```bash
sh scripts/run_eval_gpu.sh path/evalset path/ckpt
```

The command above runs in the background; check the results in eval.log. The accuracy on the evaluation dataset is as follows:

```bash
============== Accuracy:{'top_5_accuracy': 0.93684, 'top_1_accuracy': 0.7743} ==============
```
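
The `top_1_accuracy` and `top_5_accuracy` values in the GPU output are top-k accuracies: a sample counts as correct when its true label is among the k highest-scoring classes. The sketch below computes this metric in plain Python on a toy batch; `topk_accuracy` is an illustrative helper, not the MindSpore metric implementation.

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda cls: row[cls], reverse=True)
        if label in ranked[:k]:
            hits += 1
    return hits / len(labels)


scores = [
    [0.1, 0.7, 0.2],  # highest score: class 1
    [0.5, 0.3, 0.2],  # highest score: class 0
    [0.2, 0.3, 0.5],  # highest score: class 2
    [0.6, 0.3, 0.1],  # highest score: class 0
]
labels = [1, 1, 0, 0]

top1 = topk_accuracy(scores, labels, k=1)
top2 = topk_accuracy(scores, labels, k=2)
```

By construction the top-k accuracy can only grow with k, which is why the reported top-5 value is higher than the top-1 value.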

# Model Description

## Performance

### Evaluation Performance

| Parameter                  | IBN-Net                                                     |
| -------------------------- | ----------------------------------------------------------- |
| Model version              | resnet50_ibn_a                                              |
| Resource                   | Ascend 910; CPU: 2.60 GHz, 192 cores; memory: 755 GB        |
| Upload date                | 2021-03-30                                                  |
| MindSpore version          | 1.1.1-c76-tr5                                               |
| Dataset                    | ImageNet2012                                                |
| Training parameters        | lr=0.1; gamma=0.1                                           |
| Optimizer                  | SGD                                                         |
| Loss function              | SoftmaxCrossEntropyExpand                                   |
| Output                     | Probability                                                 |
| Loss                       | 0.6                                                         |
| Speed                      | 1 device: 127 ms/step; 8 devices: 132 ms/step               |
| Total time                 | 1 device: 65 hours; 8 devices: 9.5 hours                    |
| Parameters (M)             | 46.15                                                       |
| Checkpoint for fine-tuning | 293 MB (.ckpt file)                                         |
| Scripts                    | [script path](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ibnnet) |
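
The training parameters `lr=0.1` and `gamma=0.1` in the table, together with `schedule: [30, 60, 90]` and `warmup_epochs: 5` from `src/config.py`, suggest a step-decay learning rate with a short warmup. The sketch below shows what such a schedule looks like per epoch. It is an assumption about the general pattern, not the actual logic of `src/lr_generator.py`, which is not shown here; the function name `step_lr` and the linear warmup shape are illustrative.

```python
def step_lr(epoch, base_lr=0.1, gamma=0.1, schedule=(30, 60, 90), warmup_epochs=5):
    """Learning rate for a 0-based epoch: linear warmup, then step decay."""
    if epoch < warmup_epochs:
        # Ramp up linearly so that the last warmup epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Multiply by gamma once for every milestone that has been passed.
    drops = sum(1 for milestone in schedule if epoch >= milestone)
    return base_lr * gamma ** drops


for epoch in (0, 4, 29, 30, 60, 95):
    print(epoch, step_lr(epoch))
```

Under these assumptions the rate climbs to 0.1 during the first five epochs and is then divided by ten at epochs 30, 60, and 90.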

### Inference Performance

| Parameter         | IBN-Net                              |
| ----------------- | ------------------------------------ |
| Model version     | resnet50_ibn_a                       |
| Resource          | Ascend 910                           |
| Upload date       | 2021-03-30                           |
| MindSpore version | 1.1.1-c76-tr5                        |
| Dataset           | ImageNet2012                         |
| Output            | Probability                          |
| Accuracy          | 1 device: 77.45%; 8 devices: 77.45%  |

## Usage

### Inference

To run inference with a trained model on hardware platforms such as GPU, Ascend 910, or Ascend 310, see [this guide](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/migrate_3rd_scripts.html). An example follows:

```python
# Load the dataset to run inference on
eval_dataset = create_dataset_ImageNet(args.eval_url, do_train=False, repeat_num=1,
                                       batch_size=cfg.test_batch, target=args.device_target)

# Define the model
net = resnet50_ibn_a(num_classes=1000, pretrained=False)
param_dict = load_checkpoint(args.ckpt_url)
load_param_into_net(net, param_dict)
print('Load Pretrained parameters done!')

criterion = SoftmaxCrossEntropyExpand(sparse=True)

step = train_dataset.get_dataset_size()
lr = lr_generator(args.lr, train_epoch, steps_per_epoch=step)
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr,
                   momentum=args.momentum, weight_decay=args.weight_decay)

# Wrap the network, loss, and optimizer in a Model
model = Model(net, loss_fn=criterion, optimizer=optimizer, metrics={"Accuracy": Accuracy()})

time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
loss_cb = LossMonitor()

# Configure and apply checkpointing
config_ck = CheckpointConfig(save_checkpoint_steps=step, keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix="ResNet50_" + str(device_id), config=config_ck, directory='/cache/train_output/device_' + str(device_id))

cb = [ckpoint_cb, time_cb, loss_cb, eval_cb]
model.train(train_epoch, train_dataset, callbacks=cb)

# Load the pretrained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)

# Run prediction on the unseen dataset
acc = model.eval(eval_dataset)
print("accuracy: ", acc)
```
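
The `SoftmaxCrossEntropyExpand` criterion used above follows the steps visible in `src/loss.py`: subtract the row maximum, exponentiate, normalize, take the log with a small epsilon, and pick out the entry of the true class through a one-hot label. The sketch below mirrors those steps in plain Python for a single sample. The function names are illustrative, and averaging over the batch in `batch_loss` is an assumption based on the `ReduceMean` operator the class creates.

```python
import math

EPS = 1e-24  # same value as self.eps in src/loss.py


def softmax_cross_entropy(logits, label):
    """Cross-entropy of one sample given raw logits and an integer class label."""
    peak = max(logits)
    # Subtracting the maximum keeps exp() from overflowing on large logits.
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    one_hot = [1.0 if cls == label else 0.0 for cls in range(len(logits))]
    return -sum(t * math.log(p + EPS) for t, p in zip(one_hot, probs))


def batch_loss(batch_logits, labels):
    """Mean loss over a batch (assumed reduction)."""
    losses = [softmax_cross_entropy(row, label) for row, label in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


uniform = softmax_cross_entropy([0.0, 0.0, 0.0, 0.0], 2)   # equals ln(4)
confident = softmax_cross_entropy([1000.0, 0.0], 0)        # close to 0, no overflow
mean_loss = batch_loss([[0.0, 0.0], [0.0, 0.0]], [0, 1])   # equals ln(2)
```

The second call shows why the maximum is subtracted first: a naive `exp(1000.0)` would overflow, while the shifted version evaluates without error.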

### Transfer Learning

To be added.

# Description of Random Situation

In dataset.py, the seed is set inside the `create_dataset_ImageNet` function.

# ModelZoo Homepage

Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,80 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
python eval.py
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.resnet_ibn import resnet50_ibn_a
|
||||
from src.loss import SoftmaxCrossEntropyExpand
|
||||
from src.dataset import create_dataset_ImageNet as create_dataset
|
||||
from src.lr_generator import lr_generator
|
||||
from src.config import cfg
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore IBN-Net ImageNet Evaluation')
|
||||
|
||||
# Datasets
|
||||
parser.add_argument('--eval_url', required=True, type=str, help='val data path')
|
||||
# Optimization options
|
||||
parser.add_argument('--epochs', default=100, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
||||
help='manual epoch number (useful on restarts)')
|
||||
parser.add_argument('--train_batch', default=256, type=int, metavar='N',
|
||||
help='train batch size (default: 256)')
|
||||
parser.add_argument('--test_batch', default=100, type=int, metavar='N',
|
||||
help='test batch size (default: 100)')
|
||||
# Checkpoints
|
||||
parser.add_argument('-c', '--checkpoint', required=True, type=str, metavar='PATH',
                    help='path of the checkpoint file to evaluate')
|
||||
|
||||
# Device options
|
||||
parser.add_argument('--device_target', type=str, default='Ascend', choices=['GPU', 'Ascend'])
|
||||
parser.add_argument('--device_num', type=int, default=1)
|
||||
parser.add_argument('--device_id', type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    train_epoch = 1
    step = 60
    target = args.device_target
    context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
    context.set_context(device_id=args.device_id, enable_auto_mixed_precision=True)

    lr = lr_generator(cfg.lr, train_epoch, steps_per_epoch=step)
    net = resnet50_ibn_a(num_classes=cfg.class_num)
    criterion = SoftmaxCrossEntropyExpand(sparse=True)
    optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr,
                       momentum=cfg.momentum, weight_decay=cfg.weight_decay)
    model = Model(net, loss_fn=criterion, optimizer=optimizer, metrics={"top_1_accuracy", "top_5_accuracy"})

    print("============== Starting Testing ==============")
    # load the saved model for evaluation
    param_dict = load_checkpoint(args.checkpoint)
    # load parameter to the network
    load_param_into_net(net, param_dict)
    # load testing dataset
    ds_eval = create_dataset(os.path.join(args.eval_url), do_train=False, repeat_num=1,
                             batch_size=cfg.test_batch, target=target)
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
|
||||
|
||||
from src.config import cfg
|
||||
from src.resnet_ibn import resnet50_ibn_a
|
||||
|
||||
parser = argparse.ArgumentParser(description='Classification')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="ibnnet", help="output file name.")
|
||||
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
|
||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||
help="device target")
|
||||
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet'],
|
||||
help='dataset name.')
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)

if __name__ == '__main__':
    # only the ImageNet configuration is available
    if args.dataset_name != 'imagenet':
        raise ValueError("dataset is not supported.")

    net = resnet50_ibn_a(num_classes=cfg.class_num)

    assert args.ckpt_file is not None, "args.ckpt_file is None."
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, param_dict)

    # dummy input that fixes the exported graph's input shape
    input_arr = Tensor(np.ones([args.batch_size, 3, 224, 224]), mstype.float32)
    export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
|
|
@ -0,0 +1,65 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_distribute_train.sh DATA_PATH EVAL_PATH CKPT_PATH RANK_SIZE"
echo "For example: bash scripts/run_distribute_train.sh /path/dataset /path/evalset /path/ckpt 8"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
DATA_PATH=$1
|
||||
export DATA_PATH=${DATA_PATH}
|
||||
EVAL_PATH=$2
|
||||
export EVAL_PATH=${EVAL_PATH}
|
||||
# shellcheck disable=SC2034
|
||||
CKPT_PATH=$3
|
||||
export CKPT_PATH=${CKPT_PATH}
|
||||
RANK_SIZE=$4
|
||||
|
||||
EXEC_PATH=$(pwd)
|
||||
echo "$EXEC_PATH"
|
||||
|
||||
test_dist_8pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}

test_dist_2pcs()
{
    export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
    export RANK_SIZE=2
}

test_dist_${RANK_SIZE}pcs

for((i=0;i<RANK_SIZE;i++))
do
    rm -rf device$i
    mkdir device$i
    cp -r ./src/ ./device$i
    cp train.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    python train.py --data_url "$DATA_PATH" --eval_url "$EVAL_PATH" --ckpt_url "$CKPT_PATH" \
    --device_num "$RANK_SIZE" --pretrained > train.log$i 2>&1 &
    cd ../
done
|
||||
echo "start training"
|
||||
cd ../
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_distribute_train_gpu.sh DATA_PATH EVAL_PATH CKPT_PATH RANK_SIZE"
echo "For example: bash scripts/run_distribute_train_gpu.sh /path/dataset /path/evalset /path/ckpt 8"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
set -e
|
||||
|
||||
export DEVICE_NUM=$4
|
||||
export RANK_SIZE=$4
|
||||
export DATASET_NAME=$1
|
||||
export EVAL_PATH=$2
|
||||
export CKPT_PATH=$3
|
||||
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
|
||||
|
||||
rm -rf ./train_parallel
|
||||
mkdir ./train_parallel
|
||||
cp -r ./src/ ./train_parallel
|
||||
# shellcheck disable=SC2035
|
||||
cp *.py ./train_parallel
|
||||
cd ./train_parallel
|
||||
env > env.log
|
||||
echo "start training"
|
||||
mpirun -n "$4" --allow-run-as-root \
    python train.py --device_num "$4" --device_target GPU --data_url "$1" \
    --ckpt_url "$3" --eval_url "$2" \
    --pretrained \
    > train.log 2>&1 &
|
|
@ -0,0 +1,25 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
DATA_PATH=$1
|
||||
CHECKPOINT=$2
|
||||
|
||||
python eval.py \
|
||||
--device_id 0 \
|
||||
--checkpoint "$CHECKPOINT" \
|
||||
--eval_url "$DATA_PATH" \
|
||||
> eval.log 2>&1 &
|
||||
echo "start evaluation"
|
|
@ -0,0 +1,26 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
DATA_PATH=$1
|
||||
CHECKPOINT=$2
|
||||
|
||||
python eval.py \
|
||||
--device_id 0 \
|
||||
--checkpoint "$CHECKPOINT" \
|
||||
--eval_url "$DATA_PATH" \
|
||||
--device_target "GPU" \
|
||||
> eval.log 2>&1 &
|
||||
echo "start evaluation"
|
|
@ -0,0 +1,37 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_standalone_train.sh DATA_PATH EVAL_PATH CKPT_PATH"
echo "For example: bash scripts/run_standalone_train.sh /path/dataset /path/evalset /path/ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
EXE_PATH=$(pwd)
|
||||
DATA_PATH=$1
|
||||
EVAL_PATH=$2
|
||||
CKPT_PATH=$3
|
||||
|
||||
python train.py \
|
||||
--epochs 100 \
|
||||
--train_url "$EXE_PATH" \
|
||||
--data_url "$DATA_PATH" \
|
||||
--eval_url "$EVAL_PATH" \
|
||||
--ckpt_url "$CKPT_PATH" \
|
||||
--pretrained \
|
||||
> train.log 2>&1 &
|
||||
echo "start training"
|
||||
cd ../
|
|
@ -0,0 +1,38 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_standalone_train_gpu.sh DATA_PATH EVAL_PATH CKPT_PATH"
echo "For example: bash scripts/run_standalone_train_gpu.sh /path/dataset /path/evalset /path/ckpt"
|
||||
echo "It is better to use the absolute path."
|
||||
echo "=============================================================================================================="
|
||||
EXE_PATH=$(pwd)
|
||||
DATA_PATH=$1
|
||||
EVAL_PATH=$2
|
||||
CKPT_PATH=$3
|
||||
|
||||
python train.py \
|
||||
--epochs 100 \
|
||||
--train_url "$EXE_PATH" \
|
||||
--data_url "$DATA_PATH" \
|
||||
--eval_url "$EVAL_PATH" \
|
||||
--ckpt_url "$CKPT_PATH" \
|
||||
--device_target "GPU" \
|
||||
--pretrained \
|
||||
> train.log 2>&1 &
|
||||
echo "start training"
|
||||
cd ../
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
python config.py
|
||||
"""
|
||||
from easydict import EasyDict
|
||||
|
||||
cfg = EasyDict({
    "class_num": 1000,
    "train_batch": 64,
    "test_batch": 100,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 1,
    "pretrain_epoch_size": 0,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 10,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",
    "warmup_epochs": 5,
    "lr": 0.1,
    "gamma": 0.1,
    "schedule": [30, 60, 90]
})
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
python dataset.py
|
||||
"""
|
||||
import os
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
|
||||
|
||||
def create_dataset_ImageNet(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval imagenet2012 dataset for ibnnet

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    if target == "Ascend":
        device_num, rank_id = _get_rank_info()
    else:
        init("nccl")
        rank_id = get_rank()
        device_num = get_group_size()

    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank_id)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(256),
            C.CenterCrop(image_size),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds


def create_evalset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a eval imagenet2012 dataset for ibnnet

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Default: Ascend

    Returns:
        dataset
    """
    ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    trans = [
        C.Decode(),
        C.Resize(256),
        C.CenterCrop(image_size),
        C.Normalize(mean=mean, std=std),
        C.HWC2CHW()
    ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds


def _get_rank_info():
    """
    get rank size and rank id
    """
    rank_size = int(os.environ.get("RANK_SIZE", 1))

    if rank_size > 1:
        rank_size = int(os.environ.get("RANK_SIZE"))
        rank_id = int(os.environ.get("RANK_ID"))
    else:
        rank_size = 1
        rank_id = 0

    return rank_size, rank_id
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Softmax cross-entropy loss built from primitive ops (loss.py).
|
||||
"""
|
||||
from mindspore import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class SoftmaxCrossEntropyExpand(nn.Cell):
|
||||
"""
|
||||
SoftmaxCrossEntropy
|
||||
|
||||
Args:
|
||||
logit(Tensor): input tensor
|
||||
label(Tensor): label of images
|
||||
|
||||
Returns:
|
||||
loss(Tensor): cross-entropy loss averaged over the batch
|
||||
"""
|
||||
def __init__(self, sparse=False):
|
||||
super(SoftmaxCrossEntropyExpand, self).__init__()
|
||||
self.exp = ops.Exp()
|
||||
self.sum = ops.ReduceSum(keep_dims=True)
|
||||
self.onehot = ops.OneHot()
|
||||
self.on_value = Tensor(1.0, mstype.float32)
|
||||
self.off_value = Tensor(0.0, mstype.float32)
|
||||
self.div = ops.RealDiv()
|
||||
self.log = ops.Log()
|
||||
self.sum_cross_entropy = ops.ReduceSum(keep_dims=False)
|
||||
self.mul = ops.Mul()
|
||||
self.mul2 = ops.Mul()
|
||||
self.mean = ops.ReduceMean(keep_dims=False)
|
||||
self.sparse = sparse
|
||||
self.max = ops.ReduceMax(keep_dims=True)
|
||||
self.sub = ops.Sub()
|
||||
self.eps = Tensor(1e-24, mstype.float32)
|
||||
|
||||
def construct(self, logit, label):
|
||||
"""
|
||||
construct
|
||||
"""
|
||||
logit_max = self.max(logit, -1)
|
||||
exp = self.exp(self.sub(logit, logit_max))
|
||||
exp_sum = self.sum(exp, -1)
|
||||
softmax_result = self.div(exp, exp_sum)
|
||||
if self.sparse:
|
||||
label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
|
||||
|
||||
softmax_result_log = self.log(softmax_result + self.eps)
|
||||
loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
|
||||
loss = self.mul2(ops.scalar_to_array(-1.0), loss)
|
||||
loss = self.mean(loss, -1)
|
||||
|
||||
return loss
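The `construct` method above computes a softmax with the row maximum subtracted first, adds a small `eps` before the logarithm, and averages the per-sample cross-entropy. The same steps can be sketched in plain Python for a single sample; `softmax_cross_entropy` and `batch_loss` are hypothetical names used only for this illustration, and labels are taken as class indices (the `sparse=True` case).

```python
import math


def softmax_cross_entropy(logits, label, eps=1e-24):
    """Cross-entropy of one sample; `logits` is a list of floats, `label` a class index."""
    logit_max = max(logits)
    # Subtracting the maximum keeps exp() from overflowing on large logits.
    exps = [math.exp(v - logit_max) for v in logits]
    exp_sum = sum(exps)
    probs = [e / exp_sum for e in exps]
    one_hot = [1.0 if i == label else 0.0 for i in range(len(logits))]
    # eps keeps log() finite when a probability underflows to zero.
    return -sum(math.log(p + eps) * t for p, t in zip(probs, one_hot))


def batch_loss(batch_logits, labels):
    """Mean of the per-sample losses, matching the final ReduceMean."""
    losses = [softmax_cross_entropy(l, y) for l, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)
```

With uniform logits over four classes the loss equals `log(4)`, and a logit gap of 1000 still yields a finite value because of the max subtraction and `eps`.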
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Step-decay learning rate schedule (lr_generator.py).
|
||||
"""
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from src.config import cfg
|
||||
|
||||
|
||||
def lr_generator(lr_init, total_epochs, steps_per_epoch):
|
||||
lr_each_step = []
|
||||
for i in range(total_epochs):
|
||||
if i in cfg.schedule:
|
||||
lr_init *= cfg.gamma
|
||||
for _ in range(steps_per_epoch):
|
||||
lr_each_step.append(lr_init)
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return Tensor(lr_each_step)
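`lr_generator` produces one learning rate per training step and multiplies the rate by `cfg.gamma` at every epoch listed in `cfg.schedule`. A minimal sketch of the same schedule, with the schedule and decay factor passed explicitly instead of read from `cfg` (an assumption made so the example is self-contained) and returning a plain list rather than a `Tensor`:

```python
def step_decay_lr(lr_init, total_epochs, steps_per_epoch, schedule, gamma):
    """Return a per-step list of learning rates with step decay at `schedule` epochs."""
    lr = lr_init
    lr_each_step = []
    for epoch in range(total_epochs):
        if epoch in schedule:
            lr *= gamma  # decay once at the start of a scheduled epoch
        lr_each_step.extend([lr] * steps_per_epoch)
    return lr_each_step
```

For example, `step_decay_lr(0.1, 4, 2, schedule=(1, 3), gamma=0.1)` holds 0.1 for the first epoch, roughly 0.01 for epochs 1 and 2, and roughly 0.001 for epoch 3.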
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
IBN and squeeze-and-excitation building blocks (modules.py).
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
|
||||
class IBN(nn.Cell):
|
||||
r"""Instance-Batch Normalization layer from
|
||||
`"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"
|
||||
<https://arxiv.org/pdf/1807.09441.pdf>`
|
||||
|
||||
Args:
|
||||
planes (int): Number of channels for the input tensor
|
||||
ratio (float): Ratio of instance normalization in the IBN layer
|
||||
"""
|
||||
|
||||
def __init__(self, planes, ratio=0.5):
|
||||
super(IBN, self).__init__()
|
||||
self.half = int(planes * ratio)
|
||||
self.IN = nn.GroupNorm(self.half, self.half, affine=True)
|
||||
self.BN = nn.BatchNorm2d(planes - self.half)
|
||||
|
||||
def construct(self, x):
|
||||
op_split = ops.Split(1, 2)
|
||||
split = op_split(x)
|
||||
out1 = self.IN(split[0])
|
||||
out2 = self.BN(split[1])
|
||||
op_cat = ops.Concat(1)
|
||||
out = op_cat((out1, out2))
|
||||
return out
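The `IBN` cell splits the channel axis in two, applies instance normalization to the first part and batch normalization to the second, and concatenates the results. In the code above the instance-norm half is a `GroupNorm` with as many groups as channels, which normalizes each channel of each sample separately; note also that `ops.Split(1, 2)` always splits evenly, so the layout matches `ratio` only when it is 0.5. The sketch below illustrates the forward pass on nested lists under simplifying assumptions: training-mode batch statistics, no learned scale or shift, and no moving averages. `ibn_forward` is a hypothetical name.

```python
def _normalize(values, eps=1e-5):
    """Zero-mean, unit-variance normalization of a flat list."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / (var + eps) ** 0.5 for v in values]


def ibn_forward(x, eps=1e-5):
    """x[n][c] is the flattened H*W list for sample n and channel c."""
    num_channels = len(x[0])
    half = num_channels // 2
    out = [[None] * num_channels for _ in x]
    # Instance-norm half: statistics per sample and per channel.
    for n, sample in enumerate(x):
        for c in range(half):
            out[n][c] = _normalize(sample[c], eps)
    # Batch-norm half: statistics per channel, pooled over the whole batch.
    for c in range(half, num_channels):
        pooled = [v for sample in x for v in sample[c]]
        normed = _normalize(pooled, eps)
        size = len(x[0][c])
        for n in range(len(x)):
            out[n][c] = normed[n * size:(n + 1) * size]
    return out
```

The instance-norm channels end up normalized within each image, which removes per-image contrast and color statistics, while the batch-norm channels keep differences between images in the batch.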
|
||||
|
||||
|
||||
class SELayer(nn.Cell):
|
||||
"""SELayer
|
||||
|
||||
Args:
|
||||
channel (int): number of input channels; reduction (int): channel reduction ratio of the bottleneck. Default: 16
|
||||
"""
|
||||
def __init__(self, channel, reduction=16):
|
||||
super(SELayer, self).__init__()
|
||||
self.avg_pool = ops.ReduceMean()
|
||||
self.fc = nn.SequentialCell(
|
||||
[
|
||||
nn.Dense(channel, int(channel / reduction), has_bias=False),
|
||||
nn.ReLU(),
|
||||
nn.Dense(int(channel / reduction), channel, has_bias=False),
|
||||
nn.Sigmoid()
|
||||
]
|
||||
)
|
||||
|
||||
def construct(self, x):
|
||||
[b, c, _, _] = x.shape
|
||||
_reshape = ops.Reshape()
|
||||
y = _reshape(self.avg_pool(x, (2, 3)), (b, c))
|
||||
y = _reshape(self.fc(y), (b, c, 1, 1))
|
||||
broadcast = ops.BroadcastTo(x.shape)
|
||||
return x * broadcast(y)
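`SELayer` averages each channel over its spatial positions, passes the resulting vector through two bias-free dense layers (ReLU, then sigmoid), and rescales every channel by its gate. A plain-Python sketch for one sample follows; `se_forward` and the explicit weight matrices `w1` and `w2` are hypothetical stand-ins for the `nn.Dense` layers, with one row per output unit.

```python
import math


def se_forward(x, w1, w2):
    """x[c] is the flattened H*W list of channel c for a single sample."""
    # Squeeze: global average pooling per channel.
    squeezed = [sum(ch) / len(ch) for ch in x]
    # Excitation, first dense layer followed by ReLU.
    hidden = [max(0.0, sum(w * s for w, s in zip(row, squeezed))) for row in w1]
    # Excitation, second dense layer followed by sigmoid.
    gates = [1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(row, hidden))))
             for row in w2]
    # Scale: multiply each channel by its gate.
    return [[v * g for v in ch] for ch, g in zip(x, gates)]
```

With an all-zero first layer every gate is `sigmoid(0) = 0.5`, so each channel is simply halved.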
|
|
@ -0,0 +1,308 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
ResNet-IBN-a and ResNet-IBN-b model definitions (resnet_ibn.py).
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops as ops
|
||||
|
||||
from src.modules import IBN
|
||||
|
||||
|
||||
class BasicBlock_IBN(nn.Cell):
|
||||
"""
|
||||
BasicBlock_IBN
|
||||
|
||||
Args:
|
||||
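inplanes (int): input channels; planes (int): output channels; ibn (str or None): 'a' uses IBN after conv1, 'b' adds instance normalization after the residual sum; stride (int): stride of conv1; downsample (Cell or None): projection for the shortcut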
|
||||
"""
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
|
||||
super(BasicBlock_IBN, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
|
||||
padding=0, has_bias=False)
|
||||
if ibn == 'a':
|
||||
self.bn1 = IBN(planes)
|
||||
else:
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, pad_mode='pad', padding=1, has_bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.IN = nn.GroupNorm(planes, planes, affine=True) if ibn == 'b' else None
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
construct
|
||||
"""
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
if self.IN is not None:
|
||||
out = self.IN(out)
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck_IBN(nn.Cell):
|
||||
"""
|
||||
Bottleneck_IBN
|
||||
|
||||
Args:
|
||||
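inplanes (int): input channels; planes (int): bottleneck channels (output is planes * 4); ibn (str or None): 'a' uses IBN after conv1, 'b' adds instance normalization after the residual sum; stride (int): stride of conv2; downsample (Cell or None): projection for the shortcut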
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
|
||||
super(Bottleneck_IBN, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
|
||||
if ibn == 'a':
|
||||
self.bn1 = IBN(planes)
|
||||
else:
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||
padding=0, has_bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, has_bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.IN = nn.GroupNorm(planes * 4, planes * 4, affine=True) if ibn == 'b' else None
|
||||
self.relu = nn.ReLU()
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
construct
|
||||
"""
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
if self.IN is not None:
|
||||
out = self.IN(out)
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet_IBN(nn.Cell):
|
||||
"""
|
||||
ResNet_IBN
|
||||
|
||||
Args:
|
||||
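block (Cell): BasicBlock_IBN or Bottleneck_IBN; layers (list): number of blocks per stage; ibn_cfg (tuple): IBN type ('a', 'b' or None) per stage; num_classes (int): number of output classes. Default: 1000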
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
block,
|
||||
layers,
|
||||
ibn_cfg=('a', 'a', 'a', None),
|
||||
num_classes=1000):
|
||||
self.inplanes = 64
|
||||
super(ResNet_IBN, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad',
|
||||
has_bias=False)
|
||||
if ibn_cfg[0] == 'b':
|
||||
self.bn1 = nn.GroupNorm(64, 64, affine=True)
|
||||
else:
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU()
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7)
|
||||
self.fc = nn.Dense(512 * block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
|
||||
"""
|
||||
_make_layer
|
||||
"""
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.SequentialCell([
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, has_bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion)
|
||||
])
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes,
|
||||
None if ibn == 'b' else ibn,
|
||||
stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes,
|
||||
None if (ibn == 'b' and i < blocks - 1) else ibn))
|
||||
|
||||
return nn.SequentialCell(*layers)
|
||||
|
||||
def construct(self, x):
|
||||
"""
|
||||
construct
|
||||
"""
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
reshape = ops.Reshape()
|
||||
x = reshape(x, (x.shape[0], -1))
|
||||
x = self.fc(x)
|
||||
|
||||
return x
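`_make_layer` decides, block by block, which `ibn` flag each residual block in a stage receives: with type `'a'` every block gets it, while with type `'b'` only the last block of the stage does. The helper below reproduces just that selection logic so it can be inspected in isolation; `block_ibn_flags` is a hypothetical name and the function builds no network cells.

```python
def block_ibn_flags(num_blocks, ibn):
    """Return the `ibn` argument passed to each block of one stage."""
    # First block: never receives 'b'.
    flags = [None if ibn == 'b' else ibn]
    # Remaining blocks: 'b' is kept only for the last one.
    for i in range(1, num_blocks):
        flags.append(None if (ibn == 'b' and i < num_blocks - 1) else ibn)
    return flags
```

For a three-block stage this gives `['a', 'a', 'a']` for type `'a'` and `[None, None, 'b']` for type `'b'`.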
|
||||
|
||||
|
||||
def resnet18_ibn_a(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-18-IBN-a model.
|
||||
"""
|
||||
model = ResNet_IBN(block=BasicBlock_IBN,
|
||||
layers=[2, 2, 2, 2],
|
||||
ibn_cfg=('a', 'a', 'a', None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet34_ibn_a(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-34-IBN-a model.
|
||||
"""
|
||||
model = ResNet_IBN(block=BasicBlock_IBN,
|
||||
layers=[3, 4, 6, 3],
|
||||
ibn_cfg=('a', 'a', 'a', None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet50_ibn_a(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-50-IBN-a model.
|
||||
"""
|
||||
model = ResNet_IBN(block=Bottleneck_IBN,
|
||||
layers=[3, 4, 6, 3],
|
||||
ibn_cfg=('a', 'a', 'a', None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet101_ibn_a(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-101-IBN-a model.
|
||||
"""
|
||||
model = ResNet_IBN(block=Bottleneck_IBN,
|
||||
layers=[3, 4, 23, 3],
|
||||
ibn_cfg=('a', 'a', 'a', None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet152_ibn_a(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-152-IBN-a model.
|
||||
"""
|
||||
model = ResNet_IBN(block=Bottleneck_IBN,
|
||||
layers=[3, 8, 36, 3],
|
||||
ibn_cfg=('a', 'a', 'a', None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet18_ibn_b(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-18-IBN-b model.
|
||||
"""
|
||||
model = ResNet_IBN(block=BasicBlock_IBN,
|
||||
layers=[2, 2, 2, 2],
|
||||
ibn_cfg=('b', 'b', None, None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet34_ibn_b(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-34-IBN-b model.
|
||||
"""
|
||||
model = ResNet_IBN(block=BasicBlock_IBN,
|
||||
layers=[3, 4, 6, 3],
|
||||
ibn_cfg=('b', 'b', None, None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet50_ibn_b(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-50-IBN-b model.
|
||||
"""
|
||||
model = ResNet_IBN(block=Bottleneck_IBN,
|
||||
layers=[3, 4, 6, 3],
|
||||
ibn_cfg=('b', 'b', None, None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet101_ibn_b(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-101-IBN-b model.
|
||||
"""
|
||||
model = ResNet_IBN(block=Bottleneck_IBN,
|
||||
layers=[3, 4, 23, 3],
|
||||
ibn_cfg=('b', 'b', None, None),
|
||||
**kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet152_ibn_b(**kwargs):
|
||||
"""
|
||||
Constructs a ResNet-152-IBN-b model.
|
||||
"""
|
||||
model = ResNet_IBN(block=Bottleneck_IBN,
|
||||
layers=[3, 8, 36, 3],
|
||||
ibn_cfg=('b', 'b', None, None),
|
||||
**kwargs)
|
||||
return model
|
|
@ -0,0 +1,147 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Train ResNet50-IBN-a on ImageNet (train.py).
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback
|
||||
from mindspore.nn.metrics import Accuracy
|
||||
from mindspore.communication.management import init, get_rank
|
||||
|
||||
from src.loss import SoftmaxCrossEntropyExpand
|
||||
from src.resnet_ibn import resnet50_ibn_a
|
||||
from src.dataset import create_dataset_ImageNet as create_dataset, create_evalset
|
||||
from src.lr_generator import lr_generator
|
||||
from src.config import cfg
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore ImageNet Training')
|
||||
|
||||
parser.add_argument('--use_modelarts', action="store_true",
|
||||
help="using this argument for modelarts")
|
||||
# Datasets
|
||||
parser.add_argument('--train_url', default='.', type=str)
|
||||
parser.add_argument('--data_url', required=True, type=str, help='data path')
|
||||
parser.add_argument('--ckpt_url', required=True, type=str, help="ckpt path")
|
||||
parser.add_argument('--eval_url', required=True, type=str, help="eval path")
|
||||
# Optimization options
|
||||
parser.add_argument('--epochs', default=100, type=int, metavar='N',
|
||||
help='number of total epochs to run')
|
||||
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
||||
help='manual epoch number (useful on restarts)')
|
||||
parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 90],
|
||||
help='Decrease learning rate at these epochs.')
|
||||
|
||||
# Checkpoints
|
||||
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
|
||||
help='path to save checkpoint (default: checkpoint)')
|
||||
|
||||
parser.add_argument('--pretrained', action="store_true",
|
||||
help='use pre-trained model')
|
||||
# Device options
|
||||
parser.add_argument('--device_target', type=str,
|
||||
default='Ascend', choices=['GPU', 'Ascend'])
|
||||
parser.add_argument('--device_num', type=int, default=1)
|
||||
parser.add_argument('--device_id', type=int, default=0)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Evaluate accuracy on the evaluation dataset at the end of each epoch after epoch 90.
|
||||
"""
|
||||
# define the operator required
|
||||
def __init__(self, models, eval_ds, epochs_per_eval, file_name):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.models = models
|
||||
self.eval_dataset = eval_ds
|
||||
self.epochs_per_eval = epochs_per_eval
|
||||
self.file_name = file_name
|
||||
|
||||
# run evaluation at the end of an epoch
|
||||
def epoch_end(self, run_context):
|
||||
cb_param = run_context.original_args()
|
||||
cur_epoch = cb_param.cur_epoch_num
|
||||
if cur_epoch > 90:
|
||||
acc = self.models.eval(self.eval_dataset, dataset_sink_mode=False)
|
||||
self.epochs_per_eval["epoch"].append(cur_epoch)
|
||||
self.epochs_per_eval["acc"].append(acc["Accuracy"])
|
||||
print(acc)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_epoch = args.epochs
|
||||
target = args.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=target, save_graphs=False)
|
||||
device_id = args.device_id
|
||||
if args.device_num > 1:
|
||||
if target == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id,
|
||||
enable_auto_mixed_precision=True)
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
auto_parallel_search_mode="recursive_programming")
|
||||
init()
|
||||
elif target == 'GPU':
|
||||
init()
|
||||
context.set_auto_parallel_context(device_num=args.device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True,
|
||||
auto_parallel_search_mode="recursive_programming")
|
||||
else:
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
|
||||
train_dataset = create_dataset(dataset_path=args.data_url, do_train=True, repeat_num=1,
|
||||
batch_size=cfg.train_batch, target=target)
|
||||
eval_dataset = create_evalset(dataset_path=args.eval_url, do_train=False, repeat_num=1,
|
||||
batch_size=cfg.test_batch, target=target)
|
||||
|
||||
net = resnet50_ibn_a(num_classes=cfg.class_num)
|
||||
if args.pretrained:
|
||||
param_dict = load_checkpoint(args.ckpt_url)
|
||||
load_param_into_net(net, param_dict)
|
||||
criterion = SoftmaxCrossEntropyExpand(sparse=True)
|
||||
step = train_dataset.get_dataset_size()
|
||||
lr = lr_generator(cfg.lr, train_epoch, steps_per_epoch=step)
|
||||
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr,
|
||||
momentum=cfg.momentum, weight_decay=cfg.weight_decay)
|
||||
model = Model(net, loss_fn=criterion, optimizer=optimizer,
|
||||
metrics={"Accuracy": Accuracy()})
|
||||
|
||||
config_ck = CheckpointConfig(
|
||||
save_checkpoint_steps=step, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
|
||||
ckpoint_cb = ModelCheckpoint(prefix="ResNet50_" + str(device_id), config=config_ck,
|
||||
directory='.')
|
||||
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
|
||||
loss_cb = LossMonitor()
|
||||
epoch_per_eval = {"epoch": [], "acc": []}
|
||||
eval_cb = EvalCallBack(model, eval_dataset, epoch_per_eval, "ibn")
|
||||
cb = [ckpoint_cb, time_cb, loss_cb, eval_cb]
|
||||
if args.device_num == 1:
|
||||
model.train(train_epoch, train_dataset, callbacks=cb)
|
||||
elif args.device_num > 1 and get_rank() % 8 == 0:
|
||||
model.train(train_epoch, train_dataset, callbacks=cb)
|
||||
else:
|
||||
model.train(train_epoch, train_dataset)
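The final branch of the training script passes the callback list (checkpointing, timing, loss logging, evaluation) only in the single-device case or on ranks that are multiples of 8; every other process trains without callbacks. That rule can be written as a small predicate; `attaches_callbacks` is a hypothetical helper used only to make the condition explicit.

```python
def attaches_callbacks(device_num, rank):
    """True when the process would pass the callback list to model.train."""
    if device_num == 1:
        return True
    return device_num > 1 and rank % 8 == 0
```

Under this rule an 8-device job saves checkpoints and runs evaluation on rank 0 only.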
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Convert a PyTorch ResNet50-IBN-a checkpoint to a MindSpore checkpoint (pth2ckpt.py).
|
||||
"""
|
||||
import torch
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
from mindspore import Tensor
|
||||
|
||||
param = {
|
||||
'bn1.bias': 'bn1.beta',
|
||||
'bn1.weight': 'bn1.gamma',
|
||||
'IN.weight': 'IN.gamma',
|
||||
'IN.bias': 'IN.beta',
|
||||
'BN.bias': 'BN.beta',
|
||||
'in.weight': 'in.gamma',
|
||||
'bn.weight': 'bn.gamma',
|
||||
'bn.bias': 'bn.beta',
|
||||
'bn2.weight': 'bn2.gamma',
|
||||
'bn2.bias': 'bn2.beta',
|
||||
'bn3.bias': 'bn3.beta',
|
||||
'bn3.weight': 'bn3.gamma',
|
||||
'BN.running_mean': 'BN.moving_mean',
|
||||
'BN.running_var': 'BN.moving_variance',
|
||||
'bn.running_mean': 'bn.moving_mean',
|
||||
'bn.running_var': 'bn.moving_variance',
|
||||
'bn1.running_mean': 'bn1.moving_mean',
|
||||
'bn1.running_var': 'bn1.moving_variance',
|
||||
'bn2.running_mean': 'bn2.moving_mean',
|
||||
'bn2.running_var': 'bn2.moving_variance',
|
||||
'bn3.running_mean': 'bn3.moving_mean',
|
||||
'bn3.running_var': 'bn3.moving_variance',
|
||||
'downsample.1.running_mean': 'downsample.1.moving_mean',
|
||||
'downsample.1.running_var': 'downsample.1.moving_variance',
|
||||
'downsample.0.weight': 'downsample.1.weight',
|
||||
'downsample.1.bias': 'downsample.1.beta',
|
||||
'downsample.1.weight': 'downsample.1.gamma'
|
||||
}
|
||||
|
||||
|
||||
def pytorch2mindspore():
|
||||
"""
|
||||
|
||||
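Convert the PyTorch checkpoint to a MindSpore checkpoint.
|
||||
Returns: None. The converted parameters are written to 'resnet50_ibn_a.ckpt'.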
|
||||
"""
|
||||
par_dict = torch.load('resnet50_ibn_a-d9d0bb7b.pth', map_location='cpu')
|
||||
new_params_list = []
|
||||
for name in par_dict:
|
||||
param_dict = {}
|
||||
parameter = par_dict[name]
|
||||
print(name)
|
||||
for fix in param:
|
||||
if name.endswith(fix):
|
||||
name = name[:name.rfind(fix)]
|
||||
name = name + param[fix]
|
||||
|
||||
print('========================ibn_name', name)
|
||||
|
||||
param_dict['name'] = name
|
||||
param_dict['data'] = Tensor(parameter.numpy())
|
||||
new_params_list.append(param_dict)
|
||||
|
||||
save_checkpoint(new_params_list, 'resnet50_ibn_a.ckpt')
|
||||
|
||||
|
||||
pytorch2mindspore()
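The conversion loop renames a parameter by matching its name against each suffix in `param`, but it keeps scanning after a match, so a name rewritten by one entry can be rewritten again by a later one. With the two `downsample` weight entries in the order shown above, a convolution weight appears to end up under the same name as the batch-norm scale; whether that is intended cannot be determined from the diff alone. The sketch below contrasts the chained behavior with a variant that stops at the first match. `rename_first_match`, `rename_chained` and `RENAMES` are hypothetical names, and `RENAMES` copies only two entries of the table.

```python
def rename_first_match(name, renames):
    """Apply at most one suffix rename, then stop."""
    for suffix, replacement in renames.items():
        if name.endswith(suffix):
            return name[:len(name) - len(suffix)] + replacement
    return name


def rename_chained(name, renames):
    """Mirror of the loop above: keeps scanning after a match."""
    for suffix, replacement in renames.items():
        if name.endswith(suffix):
            name = name[:name.rfind(suffix)] + replacement
    return name


# Two entries copied from the `param` table above, in the same order.
RENAMES = {
    'downsample.0.weight': 'downsample.1.weight',
    'downsample.1.weight': 'downsample.1.gamma',
}
```

Stopping at the first match makes the result independent of entry order for names that match a single suffix, which is usually easier to reason about when the table grows.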
|