!18929 New model Ibnnet with GPU version

Merge pull request !18929 from 水扬波/ibnnet_gpu
i-robot 2021-07-07 09:12:03 +00:00 committed by Gitee
commit 315bbf12ab
17 changed files with 1543 additions and 0 deletions


@ -0,0 +1,285 @@
# Contents
- [Contents](#contents)
- [IBN-Net Description](#ibn-net-description)
- [IBN-Net Example](#ibn-net-example)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Pretrained Model](#pretrained-model)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#inference-performance)
- [Usage](#usage)
- [Inference](#inference)
- [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# IBN-Net Description
Convolutional neural networks (CNNs) have achieved great success on many computer vision problems. Unlike existing works that design CNN architectures to improve performance on a single task of a single domain, the paper proposes a novel convolutional architecture, IBN-Net, which significantly enhances a CNN's modeling ability on one domain (e.g., Cityscapes) as well as its generalization ability on another domain (e.g., GTA5) without fine-tuning. IBN-Net integrates Instance Normalization (IN) and Batch Normalization (BN) as building blocks that can be wrapped into many advanced deep networks to improve their performance. This work makes three key contributions. (1) By studying IN and BN in depth, it shows that IN learns features invariant to appearance changes, such as color, style, and virtuality/reality, while BN is essential for preserving content-related information. (2) IBN-Net can be applied to many advanced deep architectures, such as DenseNet, ResNet, ResNeXt, and SENet, and consistently improves their performance without adding computational cost. (3) When a trained network is applied to a new domain, e.g., from GTA5 to Cityscapes, IBN-Net achieves improvements comparable to domain adaptation methods, even without using data from the target domain.
[Paper](https://arxiv.org/abs/1807.09441): Pan X, Luo P, Shi J, et al. Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net[C]// European Conference on Computer Vision. Springer, Cham, 2018.
# IBN-Net Example
# Dataset
Dataset used: [ImageNet2012](http://www.image-net.org/)
- Training set: 1,281,167 images with labels
- Validation set: 50,000 images with labels
- Test set: 100,000 images
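The training and evaluation scripts load the data with MindSpore's `ImageFolderDataset` (see `src/dataset.py`), which expects one sub-directory per class. A typical layout is sketched below; the synset-style folder names are only illustrative.
```path
└── ImageNet2012
    ├── train
    │   ├── n01440764
    │   │   ├── xxx.JPEG
    │   │   └── ...
    │   └── ...
    └── val
        ├── n01440764
        │   └── ...
        └── ...
```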
# Environment Requirements
- Hardware (Ascend/GPU)
    - Prepare the hardware environment with an Ascend or GPU processor.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Quick Start
After installing MindSpore from the official website, you can start training and evaluation as follows:
```shell
# distributed training example
sh scripts/run_distribute_train.sh /path/dataset /path/evalset pretrained_model.ckpt rank_size
# standalone training example
sh scripts/run_standalone_train.sh /path/dataset /path/evalset pretrained_model.ckpt
# evaluation example
sh scripts/run_eval.sh /path/evalset /path/ckpt
```
## Script Description
## Script and Sample Code
```path
└── IBNNet
    ├── README.md                        // IBNNet description
    ├── scripts
    │   ├── run_distribute_train.sh      // shell script for distributed training on Ascend
    │   ├── run_distribute_train_gpu.sh  // shell script for distributed training on GPU
    │   ├── run_standalone_train.sh      // shell script for standalone training on Ascend
    │   ├── run_standalone_train_gpu.sh  // shell script for standalone training on GPU
    │   ├── run_eval.sh                  // shell script for evaluation on Ascend
    │   └── run_eval_gpu.sh              // shell script for evaluation on GPU
    ├── src
    │   ├── loss.py                      // loss function
    │   ├── lr_generator.py              // learning rate generator
    │   ├── config.py                    // parameter configuration
    │   ├── dataset.py                   // dataset creation
    │   ├── modules.py                   // IBN and SE building blocks
    │   └── resnet_ibn.py                // IBNNet architecture
    ├── utils
    │   └── pth2ckpt.py                  // converts a pth file to a ckpt file
    ├── export.py                        // exports a checkpoint to AIR/ONNX/MINDIR
    ├── eval.py                          // evaluation script
    ├── train.py                         // training script
```
## Script Parameters
```python
The major parameters in train.py and config.py are as follows:
-- use_modelarts: whether to train on the ModelArts platform. Options: True, False.
-- device_id: device ID used to train or evaluate the dataset. Ignored when train.sh is used for distributed training.
-- device_num: number of devices used for distributed training with train.sh.
-- train_url: output path for checkpoints.
-- data_url: path of the training set.
-- ckpt_url: path of the checkpoint.
-- eval_url: path of the validation set.
```
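For reference, a single-device run of train.py with these parameters might look like the following sketch; all paths are placeholders, and `--pretrained` is only needed when loading a converted checkpoint.
```bash
python train.py \
    --data_url /path/imagenet/train \
    --eval_url /path/imagenet/val \
    --ckpt_url /path/resnet50_ibn_a.ckpt \
    --train_url ./output \
    --device_target Ascend \
    --pretrained > train.log 2>&1 &
```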
## Pretrained Model
You can convert a pretrained pth file into a ckpt file with utils/pth2ckpt.py.
The pretrained pth model file can be downloaded here: [pretrained model](https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth)
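The conversion script loads `resnet50_ibn_a-d9d0bb7b.pth` from the current working directory and saves `resnet50_ibn_a.ckpt` next to it, so a typical conversion (PyTorch must be installed) is simply:
```bash
# run from the directory that contains resnet50_ibn_a-d9d0bb7b.pth
python utils/pth2ckpt.py
```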
## Training Process
### Training
- Training on Ascend
```shell
sh scripts/run_standalone_train.sh /path/dataset /path/evalset pretrained_model.ckpt
```
- Training on GPU
```shell
sh scripts/run_standalone_train_gpu.sh /path/dataset /path/evalset pretrained_model.ckpt
```
### Distributed Training
- Training on Ascend
```shell
sh scripts/run_distribute_train.sh /path/dataset /path/evalset pretrained_model.ckpt rank_size
```
The shell script above runs distributed training in the background. You can view the results in the `device[X]/train.log[X]` files.
The loss values are as follows:
```log
epoch: 12 step: 2502, loss is 1.7709649
epoch time: 331584.555 ms, per step time: 132.528 ms
epoch: 12 step: 2502, loss is 1.2770984
epoch time: 331503.971 ms, per step time: 132.496 ms
...
epoch: 82 step: 2502, loss is 0.98658705
epoch time: 331877.856 ms, per step time: 132.645 ms
epoch: 82 step: 2502, loss is 0.82476664
epoch time: 331689.239 ms, per step time: 132.570 ms
```
- Training on GPU
```shell
sh scripts/run_distribute_train_gpu.sh /path/dataset /path/evalset pretrained_model.ckpt rank_size
```
## Evaluation Process
### Evaluation
- Evaluate the ImageNet dataset when running on Ascend
```bash
sh scripts/run_eval.sh path/evalset path/ckpt
```
The above command runs in the background. You can view the results in the eval.log file. The accuracy on the test dataset is as follows:
```bash
{'Accuracy': 0.7785483870967742}
```
- Evaluate the ImageNet dataset when running on GPU
```bash
sh scripts/run_eval_gpu.sh path/evalset path/ckpt
```
The above command runs in the background. You can view the results in the eval.log file. The accuracy on the test dataset is as follows:
```bash
============== Accuracy:{'top_5_accuracy': 0.93684, 'top_1_accuracy': 0.7743} ==============
```
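If you prefer to call the evaluation script directly rather than through the shell wrappers, an equivalent invocation is sketched below; the paths are placeholders and `--device_target` defaults to Ascend.
```bash
python eval.py \
    --eval_url /path/imagenet/val \
    --checkpoint /path/resnet50_ibn_a.ckpt \
    --device_target Ascend \
    --device_id 0 > eval.log 2>&1 &
```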
# Model Description
## Performance
### Evaluation Performance
| Parameters                 | IBN-Net                                            |
| -------------------------- | -------------------------------------------------- |
| Model Version              | resnet50_ibn_a                                     |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Uploaded Date              | 2021-03-30                                         |
| MindSpore Version          | 1.1.1-c76-tr5                                      |
| Dataset                    | ImageNet2012                                       |
| Training Parameters        | lr=0.1; gamma=0.1                                  |
| Optimizer                  | SGD                                                |
| Loss Function              | SoftmaxCrossEntropyExpand                          |
| Outputs                    | probability                                        |
| Loss                       | 0.6                                                |
| Speed                      | 127 ms/step (1 pc); 132 ms/step (8 pcs)            |
| Total Time                 | 65 h (1 pc); 9.5 h (8 pcs)                         |
| Parameters (M)             | 46.15                                              |
| Checkpoint for Fine Tuning | 293M (.ckpt file)                                  |
| Scripts                    | [Script path](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/ibnnet) |
### Inference Performance
| Parameters        | IBN-Net                       |
| ----------------- | ----------------------------- |
| Model Version     | resnet50_ibn_a                |
| Resource          | Ascend 910                    |
| Uploaded Date     | 2021-03-30                    |
| MindSpore Version | 1.1.1-c76-tr5                 |
| Dataset           | ImageNet2012                  |
| Outputs           | probability                   |
| Accuracy          | 77.45% (1 pc); 77.45% (8 pcs) |
## Usage
### Inference
If you need to use the trained model for inference on multiple hardware platforms, such as GPU, Ascend 910, or Ascend 310, refer to [here](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/migrate_3rd_scripts.html). The following is an example of the procedure:
```python
# create the training and evaluation datasets
train_dataset = create_dataset(dataset_path=args.data_url, do_train=True, repeat_num=1,
                               batch_size=cfg.train_batch, target=target)
eval_dataset = create_evalset(dataset_path=args.eval_url, do_train=False, repeat_num=1,
                              batch_size=cfg.test_batch, target=target)
# define the model
net = resnet50_ibn_a(num_classes=cfg.class_num)
# load the pretrained parameters
param_dict = load_checkpoint(args.ckpt_url)
load_param_into_net(net, param_dict)
print('Load Pretrained parameters done!')
criterion = SoftmaxCrossEntropyExpand(sparse=True)
step = train_dataset.get_dataset_size()
lr = lr_generator(cfg.lr, train_epoch, steps_per_epoch=step)
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr,
                   momentum=cfg.momentum, weight_decay=cfg.weight_decay)
# build the model
model = Model(net, loss_fn=criterion, optimizer=optimizer, metrics={"Accuracy": Accuracy()})
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
loss_cb = LossMonitor()
# set and apply the checkpoint parameters
config_ck = CheckpointConfig(save_checkpoint_steps=step, keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix="ResNet50_" + str(device_id), config=config_ck,
                             directory='/cache/train_output/device_' + str(device_id))
cb = [ckpoint_cb, time_cb, loss_cb]
model.train(train_epoch, train_dataset, callbacks=cb)
# run prediction on the evaluation dataset
acc = model.eval(eval_dataset)
print("accuracy: ", acc)
```
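For Ascend 310 inference, the trained checkpoint can first be exported to AIR/ONNX/MINDIR with export.py. A sketch of the command follows; the checkpoint path is a placeholder.
```bash
python export.py \
    --ckpt_file /path/ResNet50_0.ckpt \
    --file_name ibnnet \
    --file_format MINDIR \
    --device_target Ascend
```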
### Transfer Learning
To be added.
# Description of Random Situation
In dataset.py, we set the seed inside the "create_dataset_ImageNet" function.
# ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python eval.py
"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.resnet_ibn import resnet50_ibn_a
from src.loss import SoftmaxCrossEntropyExpand
from src.dataset import create_dataset_ImageNet as create_dataset
from src.lr_generator import lr_generator
from src.config import cfg
parser = argparse.ArgumentParser(description='MindSpore ImageNet Evaluation')
# Datasets
parser.add_argument('--eval_url', required=True, type=str, help='val data path')
# Optimization options
parser.add_argument('--epochs', default=100, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--train_batch', default=256, type=int, metavar='N',
help='train batch size (default: 256)')
parser.add_argument('--test_batch', default=100, type=int, metavar='N',
help='test batch size (default: 100)')
# Checkpoints
parser.add_argument('-c', '--checkpoint', required=True, type=str, metavar='PATH',
                    help='path of the checkpoint to evaluate')
# Device options
parser.add_argument('--device_target', type=str, default='Ascend', choices=['GPU', 'Ascend'])
parser.add_argument('--device_num', type=int, default=1)
parser.add_argument('--device_id', type=int, default=0)
args = parser.parse_args()
if __name__ == "__main__":
train_epoch = 1
step = 60
target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
context.set_context(device_id=args.device_id, enable_auto_mixed_precision=True)
lr = lr_generator(cfg.lr, train_epoch, steps_per_epoch=step)
net = resnet50_ibn_a(num_classes=cfg.class_num)
criterion = SoftmaxCrossEntropyExpand(sparse=True)
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr,
momentum=cfg.momentum, weight_decay=cfg.weight_decay)
model = Model(net, loss_fn=criterion, optimizer=optimizer, metrics={"top_1_accuracy", "top_5_accuracy"})
print("============== Starting Testing ==============")
# load the saved model for evaluation
param_dict = load_checkpoint(args.checkpoint)
# load parameter to the network
load_param_into_net(net, param_dict)
# load testing dataset
ds_eval = create_dataset(os.path.join(args.eval_url), do_train=False, repeat_num=1,
batch_size=cfg.test_batch, target=target)
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== Accuracy:{} ==============".format(acc))


@ -0,0 +1,57 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np
from mindspore import dtype as mstype
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.config import cfg
from src.resnet_ibn import resnet50_ibn_a
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="ibnnet", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
parser.add_argument('--dataset_name', type=str, default='imagenet', choices=['imagenet'],
help='dataset name.')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
if args.dataset_name == 'imagenet':
cfg = cfg
else:
raise ValueError("dataset is not support.")
net = resnet50_ibn_a(num_classes=cfg.class_num)
assert args.ckpt_file is not None, "args.ckpt_file is None."
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.ones([args.batch_size, 3, 224, 224]), mstype.float32)
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)


@ -0,0 +1,65 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh DATA_PATH EVAL_PATH CKPT_PATH RANK_SIZE"
echo "For example: bash run.sh /path/dataset /path/evalset /path/ckpt 8"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
DATA_PATH=$1
export DATA_PATH=${DATA_PATH}
EVAL_PATH=$2
export EVAL_PATH=${EVAL_PATH}
# shellcheck disable=SC2034
CKPT_PATH=$3
export CKPT_PATH=${CKPT_PATH}
RANK_SIZE=$4
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
test_dist_8pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
export RANK_SIZE=8
}
test_dist_2pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
export RANK_SIZE=2
}
test_dist_${RANK_SIZE}pcs
for((i=0;i<RANK_SIZE;i++))
do
rm -rf device$i
mkdir device$i
cp -r ./src/ ./device$i
cp train.py ./device$i
cd ./device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
python train.py --data_url "$DATA_PATH" --eval_url "$EVAL_PATH" --ckpt_url "$CKPT_PATH" \
--device_num "$RANK_SIZE" --pretrained > train.log$i 2>&1 &
cd ../
done
echo "start training"
cd ../


@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh DATA_PATH EVAL_PATH CKPT_PATH RANK_SIZE"
echo "For example: bash run.sh /path/dataset /path/evalset /path/ckpt 8"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
export DEVICE_NUM=$4
export RANK_SIZE=$4
export DATASET_NAME=$1
export EVAL_PATH=$2
export CKPT_PATH=$3
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
rm -rf ./train_parallel
mkdir ./train_parallel
cp -r ./src/ ./train_parallel
# shellcheck disable=SC2035
cp *.py ./train_parallel
cd ./train_parallel
env > env.log
echo "start training"
mpirun -n $4 --allow-run-as-root \
python train.py --device_num $4 --device_target GPU --data_url $1 \
--ckpt_url $3 --eval_url $2 \
--pretrained \
> train.log 2>&1 &


@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_PATH=$1
CHECKPOINT=$2
python eval.py \
--device_id 0 \
--checkpoint "$CHECKPOINT" \
--eval_url "$DATA_PATH" \
> eval.log 2>&1 &
echo "start evaluation"


@ -0,0 +1,26 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_PATH=$1
CHECKPOINT=$2
python eval.py \
--device_id 0 \
--checkpoint "$CHECKPOINT" \
--eval_url "$DATA_PATH" \
--device_target "GPU" \
> eval.log 2>&1 &
echo "start evaluation"


@ -0,0 +1,37 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh DATA_PATH EVAL_PATH CKPT_PATH"
echo "For example: bash run.sh /path/dataset /path/evalset /path/ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
EXE_PATH=$(pwd)
DATA_PATH=$1
EVAL_PATH=$2
CKPT_PATH=$3
python train.py \
--epochs 100 \
--train_url "$EXE_PATH" \
--data_url "$DATA_PATH" \
--eval_url "$EVAL_PATH" \
--ckpt_url "$CKPT_PATH" \
--pretrained \
> train.log 2>&1 &
echo "start training"
cd ../


@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run.sh DATA_PATH EVAL_PATH CKPT_PATH"
echo "For example: bash run.sh /path/dataset /path/evalset /path/ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
EXE_PATH=$(pwd)
DATA_PATH=$1
EVAL_PATH=$2
CKPT_PATH=$3
python train.py \
--epochs 100 \
--train_url "$EXE_PATH" \
--data_url "$DATA_PATH" \
--eval_url "$EVAL_PATH" \
--ckpt_url "$CKPT_PATH" \
--device_target "GPU" \
--pretrained \
> train.log 2>&1 &
echo "start training"
cd ../


@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python config.py
"""
from easydict import EasyDict
cfg = EasyDict({
"class_num": 1000,
"train_batch": 64,
"test_batch": 100,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 1,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 10,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 5,
"lr": 0.1,
"gamma": 0.1,
"schedule": [30, 60, 90]
})


@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python dataset.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset_ImageNet(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval imagenet2012 dataset for ibnnet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
if device_num == 1:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def create_evalset(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a eval imagenet2012 dataset for ibnnet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def _get_rank_info():
"""
get rank size and rank id
"""
rank_size = int(os.environ.get("RANK_SIZE", 1))
if rank_size > 1:
rank_size = int(os.environ.get("RANK_SIZE"))
rank_id = int(os.environ.get("RANK_ID"))
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id


@ -0,0 +1,69 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python loss.py
"""
from mindspore import dtype as mstype
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
class SoftmaxCrossEntropyExpand(nn.Cell):
"""
SoftmaxCrossEntropy
Args:
logit(Tensor): input tensor
label(Tensor): label of images
Returns:
loss(float): loss
"""
def __init__(self, sparse=False):
super(SoftmaxCrossEntropyExpand, self).__init__()
self.exp = ops.Exp()
self.sum = ops.ReduceSum(keep_dims=True)
self.onehot = ops.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.div = ops.RealDiv()
self.log = ops.Log()
self.sum_cross_entropy = ops.ReduceSum(keep_dims=False)
self.mul = ops.Mul()
self.mul2 = ops.Mul()
self.mean = ops.ReduceMean(keep_dims=False)
self.sparse = sparse
self.max = ops.ReduceMax(keep_dims=True)
self.sub = ops.Sub()
self.eps = Tensor(1e-24, mstype.float32)
def construct(self, logit, label):
"""
construct
"""
logit_max = self.max(logit, -1)
exp = self.exp(self.sub(logit, logit_max))
exp_sum = self.sum(exp, -1)
softmax_result = self.div(exp, exp_sum)
if self.sparse:
label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
softmax_result_log = self.log(softmax_result + self.eps)
loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
loss = self.mul2(ops.scalar_to_array(-1.0), loss)
loss = self.mean(loss, -1)
return loss


@ -0,0 +1,31 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python lr_generator.py
"""
import numpy as np
from mindspore import Tensor
from src.config import cfg
def lr_generator(lr_init, total_epochs, steps_per_epoch):
lr_each_step = []
for i in range(total_epochs):
if i in cfg.schedule:
lr_init *= cfg.gamma
for _ in range(steps_per_epoch):
lr_each_step.append(lr_init)
lr_each_step = np.array(lr_each_step).astype(np.float32)
return Tensor(lr_each_step)


@ -0,0 +1,72 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python modules.py
"""
import mindspore.nn as nn
import mindspore.ops as ops
class IBN(nn.Cell):
r"""Instance-Batch Normalization layer from
`"Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net"
<https://arxiv.org/pdf/1807.09441.pdf>`
Args:
planes (int): Number of channels for the input tensor
ratio (float): Ratio of instance normalization in the IBN layer
"""
def __init__(self, planes, ratio=0.5):
super(IBN, self).__init__()
self.half = int(planes * ratio)
self.IN = nn.GroupNorm(self.half, self.half, affine=True)
self.BN = nn.BatchNorm2d(planes - self.half)
def construct(self, x):
op_split = ops.Split(1, 2)
split = op_split(x)
out1 = self.IN(split[0])
out2 = self.BN(split[1])
op_cat = ops.Concat(1)
out = op_cat((out1, out2))
return out
class SELayer(nn.Cell):
"""SELayer
Args:
x (Tensor): input tensor
"""
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = ops.ReduceMean()
self.fc = nn.SequentialCell(
[
nn.Dense(channel, int(channel / reduction), has_bias=False),
nn.ReLU(),
nn.Dense(int(channel / reduction), channel, has_bias=False),
nn.Sigmoid()
]
)
def construct(self, x):
[b, c, _, _] = x.shape
_reshape = ops.Reshape()
y = _reshape(self.avg_pool(x, (2, 3)), (b, c))
y = _reshape(self.fc(y), (b, c, 1, 1))
broadcast = ops.BroadcastTo(x.shape)
return x * broadcast(y)


@ -0,0 +1,308 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python resnet_ibn.py
"""
import mindspore.nn as nn
import mindspore.ops as ops
from src.modules import IBN
class BasicBlock_IBN(nn.Cell):
"""
BasicBlock_IBN
Args:
x (Tensor): input tensor
"""
expansion = 1
def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
super(BasicBlock_IBN, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=0, has_bias=False)
if ibn == 'a':
self.bn1 = IBN(planes)
else:
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, pad_mode='pad', padding=1, has_bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.IN = nn.GroupNorm(planes, planes, affine=True) if ibn == 'b' else None
self.downsample = downsample
self.stride = stride
def construct(self, x):
"""
construct
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
if self.IN is not None:
out = self.IN(out)
out = self.relu(out)
return out
class Bottleneck_IBN(nn.Cell):
"""
Bottleneck_IBN
Args:
x (Tensor): input tensor
"""
expansion = 4
def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
super(Bottleneck_IBN, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, has_bias=False)
if ibn == 'a':
self.bn1 = IBN(planes)
else:
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=0, has_bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, has_bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.IN = nn.GroupNorm(planes * 4, planes * 4, affine=True) if ibn == 'b' else None
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def construct(self, x):
"""
construct
"""
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
if self.IN is not None:
out = self.IN(out)
out = self.relu(out)
return out
class ResNet_IBN(nn.Cell):
"""
ResNet_IBN
Args:
x (Tensor): input tensor
"""
def __init__(self,
block,
layers,
ibn_cfg=('a', 'a', 'a', None),
num_classes=1000):
self.inplanes = 64
super(ResNet_IBN, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad',
has_bias=False)
if ibn_cfg[0] == 'b':
self.bn1 = nn.GroupNorm(64, 64, affine=True)
else:
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])
self.avgpool = nn.AvgPool2d(kernel_size=7, stride=7)
self.fc = nn.Dense(512 * block.expansion, num_classes)
def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
"""
_make_layer
"""
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.SequentialCell([
nn.Conv2d(self.inplanes, planes * block.expansion,
kernel_size=1, stride=stride, has_bias=False),
nn.BatchNorm2d(planes * block.expansion)
])
layers = []
layers.append(block(self.inplanes, planes,
None if ibn == 'b' else ibn,
stride, downsample))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes,
None if (ibn == 'b' and i < blocks - 1) else ibn))
return nn.SequentialCell(*layers)
def construct(self, x):
"""
construct
"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
reshape = ops.Reshape()
x = reshape(x, (x.shape[0], -1))
x = self.fc(x)
return x
def resnet18_ibn_a(**kwargs):
"""
Constructs a ResNet-18-IBN-a model.
"""
model = ResNet_IBN(block=BasicBlock_IBN,
layers=[2, 2, 2, 2],
ibn_cfg=('a', 'a', 'a', None),
**kwargs)
return model
def resnet34_ibn_a(**kwargs):
"""
Constructs a ResNet-34-IBN-a model.
"""
model = ResNet_IBN(block=BasicBlock_IBN,
layers=[3, 4, 6, 3],
ibn_cfg=('a', 'a', 'a', None),
**kwargs)
return model
def resnet50_ibn_a(**kwargs):
"""
Constructs a ResNet-50-IBN-a model.
"""
model = ResNet_IBN(block=Bottleneck_IBN,
layers=[3, 4, 6, 3],
ibn_cfg=('a', 'a', 'a', None),
**kwargs)
return model
def resnet101_ibn_a(**kwargs):
"""
Constructs a ResNet-101-IBN-a model.
"""
model = ResNet_IBN(block=Bottleneck_IBN,
layers=[3, 4, 23, 3],
ibn_cfg=('a', 'a', 'a', None),
**kwargs)
return model
def resnet152_ibn_a(**kwargs):
"""
Constructs a ResNet-152-IBN-a model.
"""
model = ResNet_IBN(block=Bottleneck_IBN,
layers=[3, 8, 36, 3],
ibn_cfg=('a', 'a', 'a', None),
**kwargs)
return model
def resnet18_ibn_b(**kwargs):
"""
Constructs a ResNet-18-IBN-b model.
"""
model = ResNet_IBN(block=BasicBlock_IBN,
layers=[2, 2, 2, 2],
ibn_cfg=('b', 'b', None, None),
**kwargs)
return model
def resnet34_ibn_b(**kwargs):
"""
Constructs a ResNet-34-IBN-b model.
"""
model = ResNet_IBN(block=BasicBlock_IBN,
layers=[3, 4, 6, 3],
ibn_cfg=('b', 'b', None, None),
**kwargs)
return model
def resnet50_ibn_b(**kwargs):
"""
Constructs a ResNet-50-IBN-b model.
"""
model = ResNet_IBN(block=Bottleneck_IBN,
layers=[3, 4, 6, 3],
ibn_cfg=('b', 'b', None, None),
**kwargs)
return model
def resnet101_ibn_b(**kwargs):
"""
Constructs a ResNet-101-IBN-b model.
"""
model = ResNet_IBN(block=Bottleneck_IBN,
layers=[3, 4, 23, 3],
ibn_cfg=('b', 'b', None, None),
**kwargs)
return model
def resnet152_ibn_b(**kwargs):
"""
Constructs a ResNet-152-IBN-b model.
"""
model = ResNet_IBN(block=Bottleneck_IBN,
layers=[3, 8, 36, 3],
ibn_cfg=('b', 'b', None, None),
**kwargs)
return model


@ -0,0 +1,147 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python train.py
"""
import argparse
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Callback
from mindspore.nn.metrics import Accuracy
from mindspore.communication.management import init, get_rank
from src.loss import SoftmaxCrossEntropyExpand
from src.resnet_ibn import resnet50_ibn_a
from src.dataset import create_dataset_ImageNet as create_dataset, create_evalset
from src.lr_generator import lr_generator
from src.config import cfg
parser = argparse.ArgumentParser(description='MindSpore ImageNet Training')
parser.add_argument('--use_modelarts', action="store_true",
help="using this argument for modelarts")
# Datasets
parser.add_argument('--train_url', default='.', type=str)
parser.add_argument('--data_url', required=True, type=str, help='data path')
parser.add_argument('--ckpt_url', required=True, type=str, help="ckpt path")
parser.add_argument('--eval_url', required=True, type=str, help="eval path")
# Optimization options
parser.add_argument('--epochs', default=100, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--schedule', type=int, nargs='+', default=[30, 60, 90],
help='Decrease learning rate at these epochs.')
# Checkpoints
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
help='path to save checkpoint (default: checkpoint)')
parser.add_argument('--pretrained', action="store_true",
help='use pre-trained model')
# Device options
parser.add_argument('--device_target', type=str,
default='Ascend', choices=['GPU', 'Ascend'])
parser.add_argument('--device_num', type=int, default=1)
parser.add_argument('--device_id', type=int, default=0)
args = parser.parse_args()
class EvalCallBack(Callback):
"""
Precision verification using callback function.
"""
# define the operator required
def __init__(self, models, eval_ds, epochs_per_eval, file_name):
super(EvalCallBack, self).__init__()
self.models = models
self.eval_dataset = eval_ds
self.epochs_per_eval = epochs_per_eval
self.file_name = file_name
# define operator function in epoch end
def epoch_end(self, run_context):
cb_param = run_context.original_args()
cur_epoch = cb_param.cur_epoch_num
if cur_epoch > 90:
acc = self.models.eval(self.eval_dataset, dataset_sink_mode=False)
self.epochs_per_eval["epoch"].append(cur_epoch)
self.epochs_per_eval["acc"].append(acc["Accuracy"])
print(acc)
if __name__ == "__main__":
train_epoch = args.epochs
target = args.device_target
context.set_context(mode=context.GRAPH_MODE,
device_target=target, save_graphs=False)
device_id = args.device_id
if args.device_num > 1:
if target == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id,
enable_auto_mixed_precision=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
auto_parallel_search_mode="recursive_programming")
init()
elif target == 'GPU':
init()
context.set_auto_parallel_context(device_num=args.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
auto_parallel_search_mode="recursive_programming")
else:
context.set_context(device_id=device_id)
train_dataset = create_dataset(dataset_path=args.data_url, do_train=True, repeat_num=1,
batch_size=cfg.train_batch, target=target)
eval_dataset = create_evalset(dataset_path=args.eval_url, do_train=False, repeat_num=1,
batch_size=cfg.test_batch, target=target)
net = resnet50_ibn_a(num_classes=cfg.class_num)
if args.pretrained:
param_dict = load_checkpoint(args.ckpt_url)
load_param_into_net(net, param_dict)
criterion = SoftmaxCrossEntropyExpand(sparse=True)
step = train_dataset.get_dataset_size()
lr = lr_generator(cfg.lr, train_epoch, steps_per_epoch=step)
optimizer = nn.SGD(params=net.trainable_params(), learning_rate=lr,
momentum=cfg.momentum, weight_decay=cfg.weight_decay)
model = Model(net, loss_fn=criterion, optimizer=optimizer,
metrics={"Accuracy": Accuracy()})
config_ck = CheckpointConfig(
save_checkpoint_steps=step, keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="ResNet50_" + str(device_id), config=config_ck,
directory='.')
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
loss_cb = LossMonitor()
epoch_per_eval = {"epoch": [], "acc": []}
eval_cb = EvalCallBack(model, eval_dataset, epoch_per_eval, "ibn")
cb = [ckpoint_cb, time_cb, loss_cb, eval_cb]
if args.device_num == 1:
model.train(train_epoch, train_dataset, callbacks=cb)
elif args.device_num > 1 and get_rank() % 8 == 0:
model.train(train_epoch, train_dataset, callbacks=cb)
else:
model.train(train_epoch, train_dataset)


@ -0,0 +1,79 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
python pth2ckpt.py
"""
import torch
from mindspore.train.serialization import save_checkpoint
from mindspore import Tensor
param = {
'bn1.bias': 'bn1.beta',
'bn1.weight': 'bn1.gamma',
'IN.weight': 'IN.gamma',
'IN.bias': 'IN.beta',
'BN.bias': 'BN.beta',
'in.weight': 'in.gamma',
'bn.weight': 'bn.gamma',
'bn.bias': 'bn.beta',
'bn2.weight': 'bn2.gamma',
'bn2.bias': 'bn2.beta',
'bn3.bias': 'bn3.beta',
'bn3.weight': 'bn3.gamma',
'BN.running_mean': 'BN.moving_mean',
'BN.running_var': 'BN.moving_variance',
'bn.running_mean': 'bn.moving_mean',
'bn.running_var': 'bn.moving_variance',
'bn1.running_mean': 'bn1.moving_mean',
'bn1.running_var': 'bn1.moving_variance',
'bn2.running_mean': 'bn2.moving_mean',
'bn2.running_var': 'bn2.moving_variance',
'bn3.running_mean': 'bn3.moving_mean',
'bn3.running_var': 'bn3.moving_variance',
'downsample.1.running_mean': 'downsample.1.moving_mean',
'downsample.1.running_var': 'downsample.1.moving_variance',
'downsample.0.weight': 'downsample.1.weight',
'downsample.1.bias': 'downsample.1.beta',
'downsample.1.weight': 'downsample.1.gamma'
}
def pytorch2mindspore():
"""
Returns:
object:
"""
par_dict = torch.load('resnet50_ibn_a-d9d0bb7b.pth', map_location='cpu')
new_params_list = []
for name in par_dict:
param_dict = {}
parameter = par_dict[name]
print(name)
for fix in param:
if name.endswith(fix):
name = name[:name.rfind(fix)]
name = name + param[fix]
print('========================ibn_name', name)
param_dict['name'] = name
param_dict['data'] = Tensor(parameter.numpy())
new_params_list.append(param_dict)
save_checkpoint(new_params_list, 'resnet50_ibn_a.ckpt')
pytorch2mindspore()