!18850 master_resnext152

Merge pull request !18850 from zhanggy/master
This commit is contained in:
i-robot 2021-08-10 01:09:02 +00:00 committed by Gitee
commit b8767cfd53
26 changed files with 2893 additions and 0 deletions

View File

@ -0,0 +1,248 @@
# Contents
- [ResNeXt152 Description](#resnext152-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Model Export](#model-export)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [ResNeXt152 Description](#contents)
ResNeXt is a simple, highly modularized network architecture for image classification. It designs results in a homogeneous, multi-branch architecture that has only a few hyper-parameters to set in ResNeXt. This strategy exposes a new dimension, which we call “cardinality” (the size of the set of transformations), as an essential factor in addition to the dimensions of depth and width.
[Paper](https://arxiv.org/abs/1611.05431): Xie S, Girshick R, Dollár, Piotr, et al. Aggregated Residual Transformations for Deep Neural Networks. 2016.
# [Model architecture](#contents)
The overall network architecture of ResNeXt is show below:
[Link](https://arxiv.org/abs/1611.05431)
# [Dataset](#contents)
Dataset used: [imagenet](http://www.image-net.org/)
- Dataset size: ~125G, 1.2W colorful images in 1000 classes
- Train: 120G, 1.2W images
- Test: 5G, 50000 images
- Data format: RGB images
- Note: Data will be processed in src/dataset.py
# [Features](#contents)
## [Mixed Precision](#contents)
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching reduce precision.
# [Environment Requirements](#contents)
- HardwareAscend
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)
## [Script and sample code](#contents)
```python
.
└─resnext152
├─README.md
├─scripts
├─run_standalone_train.sh # launch standalone training for ascend(1p)
├─run_distribute_train.sh # launch distributed training for ascend(8p)
└─run_eval.sh # launch evaluating
├─src
├─backbone
├─_init_.py # initialize
├─resnet.py # resnext152 backbone
├─utils
├─_init_.py # initialize
├─cunstom_op.py # network operation
├─logging.py # print log
├─optimizers_init_.py # get parameters
├─sampler.py # distributed sampler
├─var_init_.py # calculate gain value
├─_init_.py # initialize
├─config.py # parameter configuration
├─crossentropy.py # CrossEntropy loss function
├─dataset.py # data preprocessing
├─eval_callback.py # Inference during training
├─head.py # common head
├─image_classification.py # get resnet
├─metric.py # Inference
├─linear_warmup.py # linear warmup learning rate
├─warmup_cosine_annealing.py # learning rate each step
├─warmup_step_lr.py # warmup step learning rate
├─eval.py # eval net
├──train.py # train net
├──export.py # export mindir script
├──mindspore_hub_conf.py # mindspore hub interface
```
## [Script Parameters](#contents)
Parameters for both training and evaluating can be set in config.py.
```config
"image_height": '224,224' # image size
"num_classes": 1000, # dataset class number
"per_batch_size": 128, # batch size of input tensor
"lr": 0.05, # base learning rate
"lr_scheduler": 'cosine_annealing', # learning rate mode
"lr_epochs": '30,60,90,120', # epoch of lr changing
"lr_gamma": 0.1, # decrease lr by a factor of exponential lr_scheduler
"eta_min": 0, # eta_min in cosine_annealing scheduler
"T_max": 150, # T-max in cosine_annealing scheduler
"max_epoch": 150, # max epoch num to train the model
"warmup_epochs" : 1, # warmup epoch
"weight_decay": 0.0001, # weight decay
"momentum": 0.9, # momentum
"is_dynamic_loss_scale": 0, # dynamic loss scale
"loss_scale": 1024, # loss scale
"label_smooth": 1, # label_smooth
"label_smooth_factor": 0.1, # label_smooth_factor
"ckpt_interval": 2000, # ckpt_interval
"ckpt_path": 'outputs/', # checkpoint save location
"is_save_on_master": 1,
"rank": 0, # local rank of distributed
"group_size": 1 # world size of distributed
```
## [Training Process](#contents)
### Usage
You can start training by python script:
```script
python train.py --data_dir ~/imagenet/train/ --platform Ascend --is_distributed 0
```
or shell script:
```script
Ascend:
# distribute training example(8p)
sh run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
# standalone training
sh run_standalone_train.sh DEVICE_ID DATA_PATH
```
#### Launch
```bash
# distributed training example(8p) for Ascend
sh scripts/run_distribute_train.sh RANK_TABLE_FILE /dataset/train
# standalone training example for Ascend
sh scripts/run_standalone_train.sh 0 /dataset/train
```
You can find checkpoint file together with result in log.
## [Evaluation Process](#contents)
### Usage
You can start training by python script:
```script
python eval.py --data_dir ~/imagenet/val/ --platform Ascend --pretrained resnext.ckpt
```
or shell script:
```script
# Evaluation
sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
```
PLATFORM is Ascend, default is Ascend.
#### Launch
```bash
# Evaluation with checkpoint
sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext152_100.ckpt Ascend
#Directly use the script to run
python eval.py --data_dir /opt/npu/pvc/dataset/storage/imagenet/val/ --platform Ascend --pretrained /root/test/resnext152_64x4d/outputs_demo/best_acc_4.ckpt
```
#### Result
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log.
```log
acc=80.08%(TOP1)
acc=94.71%(TOP5)
```
## [Model Export](#contents)
```shell
python export.py --device_target [PLATFORM] --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT]
```
`EXPORT_FORMAT` should be in ["AIR", "ONNX", "MINDIR"]
# [Model description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters | ResNeXt152 | |
| -------------------------- | --------------------------------------------- | ---- |
| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G | |
| uploaded Date | 06/30/2021 | |
| MindSpore Version | 1.2 | |
| Dataset | ImageNet | |
| Training Parameters | src/config.py | |
| Optimizer | Momentum | |
| Loss Function | SoftmaxCrossEntropy | |
| Loss | 1.28923 | |
| Accuracy | 80.08%(TOP1) | |
| Total time | 7.8 h 8ps | |
| Checkpoint for Fine tuning | 192 M(.ckpt file) | |
#### Inference Performance
| Parameters | | | |
| ----------------- | ---- | ---- | ---------------- |
| Resource | | | Ascend 910 |
| uploaded Date | | | 06/20/2021 |
| MindSpore Version | | | 1.2 |
| Dataset | | | ImageNet, 1.2W |
| batch_size | | | 1 |
| outputs | | | probability |
| Accuracy | | | acc=80.08%(TOP1) |
# [Description of Random Situation](#contents)
In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,252 @@
# 目录
- [目录](#目录)
- [ResNeXt152说明](#resnext152说明)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [特性](#特性)
- [混合精度](#混合精度)
- [环境要求](#环境要求)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [用法](#用法)
- [样例](#样例)
- [评估过程](#评估过程)
- [用法](#用法-1)
- [样例](#样例-1)
- [结果](#结果)
- [模型导出](#模型导出)
- [模型描述](#模型描述)
- [性能](#性能)
- [训练性能](#训练性能)
- [推理性能](#推理性能)
- [随机情况说明](#随机情况说明)
- [ModelZoo主页](#modelzoo主页)
# ResNeXt152说明
ResNeXt是一个简单、高度模块化的图像分类网络架构。ResNeXt的设计为统一的、多分支的架构该架构仅需设置几个超参数。此策略提供了一个新维度我们将其称为“基数”转换集的大小它是深度和宽度维度之外的一个重要因素。
[论文](https://arxiv.org/abs/1611.05431) Xie S, Girshick R, Dollár, Piotr, et al. Aggregated Residual Transformations for Deep Neural Networks. 2016.
# 模型架构
ResNeXt整体网络架构如下
[链接](https://arxiv.org/abs/1611.05431)
# 数据集
使用的数据集:[ImageNet](http://www.image-net.org/)
- 数据集大小约125G, 共1000个类包含1.2万张彩色图像
- 训练集120G1.2万张图像
- 测试集5G5万张图像
- 数据格式RGB图像。
- 注数据在src/dataset.py中处理。
# 特性
## 混合精度
采用[混合精度](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html)的训练方法使用支持单精度和半精度数据来提高深度学习神经网络的训练速度,同时保持单精度训练所能达到的网络精度。混合精度训练提高计算速度、减少内存使用的同时,支持在特定硬件上训练更大的模型或实现更大批次的训练。
以FP16算子为例如果输入数据类型为FP32MindSpore后台会自动降低精度来处理数据。用户可打开INFO日志搜索“reduce precision”查看精度降低的算子。
# 环境要求
- 硬件Ascend
- 准备Ascend处理器搭建硬件环境。如需试用昇腾处理器请发送[申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)至ascend@huawei.com审核通过即可获得资源。
- 框架
- [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# 脚本说明
## 脚本及样例代码
```path
.
└─resnext152
├─README.md
├─scripts
├─run_standalone_train.sh # 启动Ascend单机训练单卡
├─run_distribute_train.sh # 启动Ascend分布式训练8卡
└─run_eval.sh # 启动评估
├─src
├─backbone
├─_init_.py # 初始化
├─resnet.py # ResNeXt152骨干
├─utils
├─_init_.py # 初始化
├─cunstom_op.py # 网络操作
├─logging.py # 打印日志
├─optimizers_init_.py # 获取参数
├─sampler.py # 分布式采样器
├─var_init_.py # 计算增益值
├─_init_.py # 初始化
├─config.py # 参数配置
├─crossentropy.py # 交叉熵损失函数
├─dataset.py # 数据预处理
├─eval_callback.py # 训练时推理
├─head.py # 常见头
├─image_classification.py # 获取ResNet
├─metric.py # 推理
├─linear_warmup.py # 线性热身学习率
├─warmup_cosine_annealing.py # 每次迭代的学习率
├─warmup_step_lr.py # 热身迭代学习率
├─eval.py # 评估网络
├──train.py # 训练网络
├──mindspore_hub_conf.py # MindSpore Hub接口
```
## 脚本参数
在config.py中可以同时配置训练和评估参数。
```python
"image_height": '224,224' # 图像大小
"num_classes": 1000, # 数据集类数
"per_batch_size": 128, # 输入张量的批次大小
"lr": 0.05, # 基础学习率
"lr_scheduler": 'cosine_annealing', # 学习率模式
"lr_epochs": '30,60,90,120', # LR变化轮次
"lr_gamma": 0.1, # 减少LR的exponential lr_scheduler因子
"eta_min": 0, # cosine_annealing调度器中的eta_min
"T_max": 150, # cosine_annealing调度器中的T-max
"max_epoch": 150, # 训练模型的最大轮次数量
"backbone": 'resnext152', # 骨干网络
"warmup_epochs" : 1, # 热身轮次
"weight_decay": 0.0001, # 权重衰减
"momentum": 0.9, # 动量
"is_dynamic_loss_scale": 0, # 动态损失放大
"loss_scale": 1024, # 损失放大
"label_smooth": 1, # 标签平滑
"label_smooth_factor": 0.1, # 标签平滑因子
"ckpt_interval": 2000, # 检查点间隔
"ckpt_path": 'outputs/', # 检查点保存位置
"is_save_on_master": 1,
"rank": 0, # 分布式本地进程序号
"group_size": 1 # 分布式进程总数
```
## 训练过程
### 用法
您可以通过python脚本开始训练
```shell
python train.py --data_dir ~/imagenet/train/ --platform Ascend --is_distributed 0
```
或通过shell脚本开始训练
```shell
Ascend:
# 分布式训练示例8卡
sh run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
# 单机训练
sh run_standalone_train.sh DEVICE_ID DATA_PATH
```
### 样例
```shell
# Ascend分布式训练示例8卡
sh scripts/run_distribute_train.sh RANK_TABLE_FILE /dataset/train
# Ascend单机训练示例
sh scripts/run_standalone_train.sh 0 /dataset/train
```
您可以在日志中找到检查点文件和结果。
## 评估过程
### 用法
您可以通过python脚本开始训练
```shell
python eval.py --data_dir ~/imagenet/val/ --platform Ascend --pretrained resnext.ckpt
```
或通过shell脚本开始训练
```shell
# 评估
sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
```
PLATFORM is Ascend, default is Ascend.
#### 样例
```shell
# 检查点评估
sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext152_100.ckpt Ascend
#或者直接使用脚本运行
python eval.py --data_dir /opt/npu/pvc/dataset/storage/imagenet/val/ --platform Ascend --pretrained /root/test/resnext152_64x4d/outputs_demo/best_acc_0.ckpt
```
#### 结果
评估结果保存在脚本路径下。您可以在日志中找到类似以下的结果。
```log
acc=80.08%(TOP1)
acc=94.71%(TOP5)
```
## 模型导出
```shell
python export.py --device_target [PLATFORM] --ckpt_file [CKPT_PATH] --file_format [EXPORT_FORMAT]
```
`EXPORT_FORMAT` 可选 ["AIR", "ONNX", "MINDIR"].
# 模型描述
## 性能
### 训练性能
| 参数 | ResNeXt152 | |
| -------------------------- | ---------------------------------------------------------- | ------------------------- |
| 资源 | Ascend 910CPU2.60GHz192核内存755GB | |
| 上传日期 | 2021-6-30 | |
| MindSpore版本 | 1.2 | |
| 数据集 | ImageNet | |
| 训练参数 | src/config.py | |
| 优化器 | Momentum | |
| 损失函数 | Softmax交叉熵 | |
| 损失 | 1.2892 | |
| 准确率 | 80.08%(TOP1) | |
| 总时长 | 7.8小时 8卡 | |
| 调优检查点 | 192 M.ckpt文件 | |
#### 推理性能
| 参数 | | | |
| -------------------------- | ----------------------------- | ------------------------- | -------------------- |
| 资源 | | | Ascend 910 |
| 上传日期 | | | 2021-6-20 |
| MindSpore版本 | | | 1.2 |
| 数据集 | | | ImageNet 1.2万 |
| batch_size | | | 1 |
| 输出 | | | 概率 |
| 准确率 | | | acc=80.08%(TOP1) |
# 随机情况说明
dataset.py中设置了“create_dataset”函数内的种子同时还使用了train.py中的随机种子。
# ModelZoo主页
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。

View File

@ -0,0 +1,265 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Eval"""
import os
import time
import argparse
import datetime
import glob
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size, release
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from src.utils.logging import get_logger
from src.utils.auto_mixed_precision import auto_mixed_precision
from src.utils.var_init import load_pretrain_model
from src.image_classification import get_network
from src.dataset import classification_dataset
from src.config import config
class ParameterReduce(nn.Cell):
"""ParameterReduce"""
def __init__(self):
super(ParameterReduce, self).__init__()
self.cast = P.Cast()
self.reduce = P.AllReduce()
def construct(self, x):
one = self.cast(F.scalar_to_array(1.0), mstype.float32)
out = x * one
ret = self.reduce(out)
return ret
def parse_args(cloud_args=None):
"""parse_args"""
parser = argparse.ArgumentParser('mindspore classification test')
parser.add_argument('--platform',
type=str,
default='Ascend',
choices=('Ascend', 'GPU'),
help='run platform')
# dataset related
parser.add_argument('--data_dir',
type=str,
default='/opt/npu/datasets/classification/val',
help='eval data dir')
parser.add_argument('--per_batch_size',
default=32,
type=int,
help='batch size for per npu')
# network related
parser.add_argument('--graph_ckpt',
type=int,
default=1, help='graph ckpt or feed ckpt')
parser.add_argument('--pretrained',
default='',
type=str,
help='fully path of pretrained model to load. '
'If it is a direction, it will test all ckpt')
# logging related
parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config['image_size']
args.num_classes = config['num_classes']
args.rank = config['rank']
args.group_size = config['group_size']
args.image_size = list(map(int, args.image_size.split(',')))
# init distributed
if args.is_distributed:
if args.platform == "Ascend":
init()
elif args.platform == "GPU":
init("nccl")
args.rank = get_rank()
args.group_size = get_group_size()
else:
args.rank = 0
args.group_size = 1
args.outputs_dir = os.path.join(args.log_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
return args
def get_top5_acc(top5_arg, gt_class):
sub_count = 0
for top5, gt_file in zip(top5_arg, gt_class):
if gt_file in top5:
sub_count += 1
return sub_count
def merge_args(args, cloud_args):
"""merge_args"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def get_result(args, model, top1_correct, top5_correct, img_tot):
"""calculate top1 and top5 value."""
results = [[top1_correct], [top5_correct], [img_tot]]
args.logger.info('before results={}'.format(results))
if args.is_distributed:
model_md5 = model.replace('/', '')
tmp_dir = '/cache'
if not os.path.exists(tmp_dir):
os.mkdir(tmp_dir)
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(args.rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(args.rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(args.rank, model_md5)
np.save(top1_correct_npy, top1_correct)
np.save(top5_correct_npy, top5_correct)
np.save(img_tot_npy, img_tot)
while True:
rank_ok = True
for other_rank in range(args.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
if not os.path.exists(top1_correct_npy) or not os.path.exists(top5_correct_npy) or \
not os.path.exists(img_tot_npy):
rank_ok = False
if rank_ok:
break
top1_correct_all = 0
top5_correct_all = 0
img_tot_all = 0
for other_rank in range(args.group_size):
top1_correct_npy = '/cache/top1_rank_{}_{}.npy'.format(other_rank, model_md5)
top5_correct_npy = '/cache/top5_rank_{}_{}.npy'.format(other_rank, model_md5)
img_tot_npy = '/cache/img_tot_rank_{}_{}.npy'.format(other_rank, model_md5)
top1_correct_all += np.load(top1_correct_npy)
top5_correct_all += np.load(top5_correct_npy)
img_tot_all += np.load(img_tot_npy)
results = [[top1_correct_all], [top5_correct_all], [img_tot_all]]
results = np.array(results)
else:
results = np.array(results)
args.logger.info('after results={}'.format(results))
return results
def test(cloud_args=None):
"""test"""
args = parse_args(cloud_args)
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.platform, save_graphs=False)
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
# init distributed
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
gradients_mean=True)
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
if os.path.isdir(args.pretrained):
models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
print(models)
if args.graph_ckpt:
f_key = (lambda x: -1 *
int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0]))
else:
f_key = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
args.models = sorted(models, key=f_key)
else:
args.models = [args.pretrained,]
for model in args.models:
de_dataset = classification_dataset(args.data_dir, image_size=args.image_size,
per_batch_size=args.per_batch_size,
max_epoch=1, rank=args.rank, group_size=args.group_size,
mode='eval')
eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True, num_epochs=1)
network = get_network(num_classes=args.num_classes, platform=args.platform)
load_pretrain_model(model, network, args)
img_tot = 0
top1_correct = 0
top5_correct = 0
if args.platform == "Ascend":
network.to_float(mstype.float16)
else:
auto_mixed_precision(network)
network.set_train(False)
t_end = time.time()
it_name = 0
for data, gt_classes in eval_dataloader:
output = network(Tensor(data, mstype.float32))
output = output.asnumpy()
top1_output = np.argmax(output, (-1))
top5_output = np.argsort(output)[:, -5:]
t1_correct = np.equal(top1_output, gt_classes).sum()
top1_correct += t1_correct
top5_correct += get_top5_acc(top5_output, gt_classes)
img_tot += args.per_batch_size
if args.rank == 0 and it_name == 0:
t_end = time.time()
it_name = 1
if args.rank == 0:
time_used = time.time() - t_end
fps = (img_tot - args.per_batch_size) * args.group_size / time_used
args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
results = get_result(args, model, top1_correct, top5_correct, img_tot)
top1_correct = results[0, 0]
top5_correct = results[1, 0]
img_tot = results[2, 0]
acc1 = 100.0 * top1_correct / img_tot
acc5 = 100.0 * top5_correct / img_tot
args.logger.info('after allreduce eval: top1_correct={}, tot={},'
'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1))
args.logger.info('after allreduce eval: top5_correct={}, tot={},'
'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5))
if args.is_distributed:
release()
if __name__ == "__main__":
test()

View File

@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
resnext export mindir.
"""
import argparse
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config
from src.image_classification import get_network
parser = argparse.ArgumentParser(description='checkpoint export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True,
help="Checkpoint file path.")
parser.add_argument('--width', type=int, default=224, help='input width')
parser.add_argument('--height', type=int, default=224, help='input height')
parser.add_argument("--file_name", type=str,
default="resnext152", help="output file name.")
parser.add_argument("--file_format", type=str,
choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
net = get_network(num_classes=config['num_classes'],
platform=args.device_target)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_shp = [args.batch_size, 3, args.height, args.width]
input_array = Tensor(
np.random.uniform(-1.0, 1.0, size=input_shp).astype(np.float32))
export(net, input_array, file_name=args.file_name,
file_format=args.file_format)

View File

@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$2
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
echo "hccl connect time out has changed to 600 second"
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
start=`expr $i \* $avg_core_per_rank`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ $core_gap`
cmdopt=$start"-"$end
rm -rf LOG_less$i
mkdir ./LOG_less$i
cp *.py ./LOG_less$i
cd ./LOG_less$i || exit
echo "start training for rank $i, device $DEVICE_ID"
env > env_less.log
taskset -c $cmdopt python ../train.py \
--is_distribute=1 \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log_less.txt 2>&1 &
cd ../
done

View File

@ -0,0 +1,29 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
PLATFORM=Ascend
if [ $# == 4 ]
then
PLATFORM=$4
fi
python eval.py \
--pretrained=$PATH_CHECKPOINT \
--platform=$PLATFORM \
--data_dir=$DATA_DIR > log.txt 2>&1 &

View File

@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
PATH_CHECKPOINT=$3
fi
python train.py \
--is_distribute=0 \
--device_id=$DEVICE_ID \
--pretrained=$PATH_CHECKPOINT \
--data_dir=$DATA_DIR > log.txt 2>&1 &

View File

@ -0,0 +1,16 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""resnet"""
from .resnet import *

View File

@ -0,0 +1,302 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
ResNet based ResNext
"""
import mindspore.nn as nn
from mindspore.ops.operations import TensorAdd, Split, Concat
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal
from src.utils.cunstom_op import SEBlock, GroupConv
__all__ = ['ResNet', 'resnext50', 'resnext101', 'resnext152']
def weight_variable(shape, factor=0.1):
return TruncatedNormal(0.02)
def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
class _DownSample(nn.Cell):
"""
Downsample for ResNext-ResNet.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the 1*1 convolutional layer.
Returns:
Tensor, output tensor.
Examples:
>>>DownSample(32, 64, 2)
"""
def __init__(self, in_channels, out_channels, stride):
super(_DownSample, self).__init__()
self.conv = conv1x1(in_channels, out_channels,
stride=stride, padding=0)
self.bn = nn.BatchNorm2d(out_channels)
def construct(self, x):
out = self.conv(x)
out = self.bn(out)
return out
class BasicBlock(nn.Cell):
"""
ResNet basic block definition.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>>BasicBlock(32, 256, stride=2)
"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False,
platform="Ascend", **kwargs):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = P.ReLU()
self.conv2 = conv3x3(out_channels, out_channels, stride=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.use_se = use_se
if self.use_se:
self.se = SEBlock(out_channels)
self.down_sample_flag = False
if down_sample is not None:
self.down_sample = down_sample
self.down_sample_flag = True
self.add = TensorAdd()
def construct(self, x):
"""describe network construct"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.down_sample_flag:
identity = self.down_sample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class Bottleneck(nn.Cell):
"""
ResNet Bottleneck block definition.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the initial convolutional layer. Default: 1.
Returns:
Tensor, the ResNet unit's output.
Examples:
>>>Bottleneck(3, 256, stride=2)
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1, down_sample=None,
base_width=64, groups=1, use_se=False, platform="Ascend", **kwargs):
super(Bottleneck, self).__init__()
width = int(out_channels * (base_width / 64.0)) * groups
self.groups = groups
self.conv1 = conv1x1(in_channels, width, stride=1)
self.bn1 = nn.BatchNorm2d(width)
self.relu = P.ReLU()
self.conv3x3s = nn.CellList()
if platform == "GPU":
self.conv2 = nn.Conv2d(
width, width, 3, stride, pad_mode='pad', padding=1, group=groups)
else:
self.conv2 = GroupConv(
width, width, 3, stride, pad=1, groups=groups)
self.op_split = Split(axis=1, output_num=self.groups)
self.op_concat = Concat(axis=1)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = conv1x1(width, out_channels * self.expansion, stride=1)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.use_se = use_se
if self.use_se:
self.se = SEBlock(out_channels * self.expansion)
self.down_sample_flag = False
if down_sample is not None:
self.down_sample = down_sample
self.down_sample_flag = True
self.cast = P.Cast()
self.add = TensorAdd()
def construct(self, x):
"""describe network construct"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.use_se:
out = self.se(out)
if self.down_sample_flag:
identity = self.down_sample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResNet(nn.Cell):
"""
ResNet architecture.
Args:
block (cell): Block for network.
layers (list): Numbers of block in different layers.
width_per_group (int): Width of every group.
groups (int): Groups number.
Returns:
Tuple, output tensor tuple.
Examples:
>>>ResNet()
"""
def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False, platform="Ascend"):
super(ResNet, self).__init__()
self.in_channels = 64
self.groups = groups
self.base_width = width_per_group
self.conv = conv7x7(3, self.in_channels, stride=2, padding=3)
self.bn = nn.BatchNorm2d(self.in_channels)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.layer1 = self._make_layer(
block, 64, layers[0], use_se=use_se, platform=platform)
self.layer2 = self._make_layer(
block, 128, layers[1], stride=2, use_se=use_se, platform=platform)
self.layer3 = self._make_layer(
block, 256, layers[2], stride=2, use_se=use_se, platform=platform)
self.layer4 = self._make_layer(
block, 512, layers[3], stride=2, use_se=use_se, platform=platform)
self.out_channels = 512 * block.expansion
self.cast = P.Cast()
def construct(self, x):
"""describe network construct"""
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x
def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False, platform="Ascend"):
"""_make_layer"""
down_sample = None
if stride != 1 or self.in_channels != out_channels * block.expansion:
down_sample = _DownSample(self.in_channels,
out_channels * block.expansion,
stride=stride)
layers = []
layers.append(block(self.in_channels,
out_channels,
stride=stride,
down_sample=down_sample,
base_width=self.base_width,
groups=self.groups,
use_se=use_se,
platform=platform))
self.in_channels = out_channels * block.expansion
for _ in range(1, blocks_num):
layers.append(block(self.in_channels, out_channels, base_width=self.base_width,
groups=self.groups, use_se=use_se, platform=platform))
return nn.SequentialCell(layers)
def get_out_channels(self):
return self.out_channels
def resnext50(platform="Ascend"):
return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32, platform=platform)
def resnext101(platform="Ascend"):
return ResNet(Bottleneck, [3, 4, 23, 3], width_per_group=4, groups=64, platform=platform)
def resnext152(platform="Ascend"):
return ResNet(Bottleneck, [3, 8, 36, 3], width_per_group=4, groups=64, platform=platform)

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""config"""
config = {
"image_size": '224,224',
"num_classes": 1000,
"lr": 0.4,
"lr_scheduler": 'cosine_annealing',
"lr_epochs": '30,60,90,120',
"lr_gamma": 0.1,
"eta_min": 0,
"T_max": 150,
"max_epoch": 150,
"warmup_epochs": 1,
"weight_decay": 0.0001,
"momentum": 0.9,
"is_dynamic_loss_scale": 0,
"loss_scale": 1024,
"label_smooth": 1,
"label_smooth_factor": 0.1,
"ckpt_interval": 2,
"ckpt_save_max": 30,
"ckpt_path": 'outputs_demo/',
"is_save_on_master": 1,
"rank": 0,
"group_size": 1
}

View File

@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
define loss function for network.
"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""
the redefined loss function with SoftmaxCrossEntropyWithLogits.
"""
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes -1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss

View File

@ -0,0 +1,158 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
dataset processing.
"""
import os
from mindspore.common import dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as V_C
from PIL import Image, ImageFile
from src.utils.sampler import DistributedSampler
ImageFile.LOAD_TRUNCATED_IMAGES = True
class TxtDataset():
"""
create txt dataset.
Args:
Returns:
de_dataset.
"""
def __init__(self, root, txt_name):
super(TxtDataset, self).__init__()
self.imgs = []
self.labels = []
fin = open(txt_name, "r")
for line in fin:
img_name, label = line.strip().split(' ')
self.imgs.append(os.path.join(root, img_name))
self.labels.append(int(label))
fin.close()
def __getitem__(self, index):
img = Image.open(self.imgs[index]).convert('RGB')
return img, self.labels[index]
def __len__(self):
return len(self.imgs)
def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank, group_size,
mode='train',
input_mode='folder',
root='',
num_parallel_workers=None,
shuffle=None,
sampler=None,
class_indexing=None,
drop_remainder=True,
transform=None,
target_transform=None):
"""
A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
are written into a textfile.
Args:
data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
Or path of the textfile that contains every image's path of the dataset.
image_size (Union(int, sequence)): Size of the input images.
per_batch_size (int): the batch size of evey step during training.
max_epoch (int): the number of epochs.
rank (int): The shard ID within num_shards (default=None).
group_size (int): Number of shards that the dataset should be divided
into (default=None).
mode (str): "train" or others. Default: " train".
input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
root (str): the images path for "input_mode="txt"". Default: " ".
num_parallel_workers (int): Number of workers to read the data. Default: None.
shuffle (bool): Whether or not to perform shuffle on the dataset
(default=None, performs shuffle).
sampler (Sampler): Object used to choose samples from the dataset. Default: None.
class_indexing (dict): A str-to-int mapping from folder name to index
(default=None, the folder names will be sorted
alphabetically and each class will be given a
unique index starting from 0).
Examples:
>>> from src.dataset import classification_dataset
>>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
>>> data_dir = "/path/to/imagefolder_directory"
>>> de_dataset = classification_dataset(data_dir, image_size=[224, 244],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4)
>>> # Path of the textfile that contains every image's path of the dataset.
>>> data_dir = "/path/to/dataset/images/train.txt"
>>> images_dir = "/path/to/dataset/images"
>>> de_dataset = classification_dataset(data_dir, image_size=[224, 244],
>>> per_batch_size=64, max_epoch=100,
>>> rank=0, group_size=4,
>>> input_mode="txt", root=images_dir)
"""
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if transform is None:
if mode == 'train':
transform_img = [
V_C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
V_C.RandomHorizontalFlip(prob=0.5),
V_C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
else:
transform_img = [
V_C.Decode(),
V_C.Resize((256, 256)),
V_C.CenterCrop(image_size),
V_C.Normalize(mean=mean, std=std),
V_C.HWC2CHW()
]
else:
transform_img = transform
if target_transform is None:
transform_label = [C.TypeCast(mstype.int32)]
else:
transform_label = target_transform
if input_mode == 'folder':
de_dataset = de.ImageFolderDataset(data_dir, num_parallel_workers=num_parallel_workers,
shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
num_shards=group_size, shard_id=rank)
else:
dataset = TxtDataset(root, data_dir)
sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
de_dataset = de_dataset.map(operations=transform_img, input_columns="image",
num_parallel_workers=num_parallel_workers)
de_dataset = de_dataset.map(operations=transform_label, input_columns="label",
num_parallel_workers=num_parallel_workers)
columns_to_project = ["image", "label"]
de_dataset = de_dataset.project(columns=columns_to_project)
de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
de_dataset = de_dataset.repeat(max_epoch)
return de_dataset

View File

@ -0,0 +1,95 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation callback when training"""
import os
import stat
import time
from mindspore import save_checkpoint
from mindspore import log as logger
from mindspore.train.callback import Callback
class EvalCallBack(Callback):
"""
Evaluation callback when training.
Args:
eval_function (function): evaluation function.
eval_param_dict (dict): evaluation parameters' configure dict.
interval (int): run evaluation interval, default is 1.
eval_start_epoch (int): evaluation start epoch, default is 1.
save_best_ckpt (bool): Whether to save best checkpoint, default is True.
best_ckpt_name (str): bast checkpoint name, default is `best.ckpt`.
metrics_name (str): evaluation metrics name, default is `acc`.
Returns:
None
Examples:
>>> EvalCallBack(eval_function, eval_param_dict)
"""
def __init__(self, eval_function, eval_param_dict, interval=1, eval_start_epoch=1, save_best_ckpt=True,
ckpt_directory="./", best_ckpt_name="best.ckpt", metrics_name="acc"):
super(EvalCallBack, self).__init__()
self.eval_param_dict = eval_param_dict
self.eval_function = eval_function
self.eval_start_epoch = eval_start_epoch
if interval < 1:
raise ValueError("interval should >= 1.")
self.interval = interval
self.save_best_ckpt = save_best_ckpt
self.best_res = 0
self.best_epoch = 0
if not os.path.isdir(ckpt_directory):
os.makedirs(ckpt_directory)
self.bast_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
self.metrics_name = metrics_name
def remove_ckpoint_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)
def epoch_end(self, run_context):
"""Callback when epoch end."""
cb_params = run_context.original_args()
cur_epoch = cb_params.cur_epoch_num
if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
eval_start = time.time()
res = self.eval_function(self.eval_param_dict)
eval_cost = time.time() - eval_start
print("epoch: {}, {}: {}, eval_cost:{:.2f}".format(cur_epoch, self.metrics_name, res, eval_cost),
flush=True)
if res >= self.best_res:
self.best_res = res
self.best_epoch = cur_epoch
print("update best result: {}".format(res), flush=True)
if self.save_best_ckpt:
if os.path.exists(self.bast_ckpt_path):
self.remove_ckpoint_file(self.bast_ckpt_path)
save_checkpoint(cb_params.train_network, self.bast_ckpt_path)
print("update best checkpoint at: {}".format(self.bast_ckpt_path), flush=True)
def end(self, run_context):
print("End training, the best {0} is: {1}, the best {0} epoch is {2}".format(self.metrics_name,
self.best_res,
self.best_epoch), flush=True)

View File

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
common architecture.
"""
import mindspore.nn as nn
from src.utils.cunstom_op import GlobalAvgPooling
__all__ = ['CommonHead']
class CommonHead(nn.Cell):
"""
common architecture definition.
Args:
num_classes (int): Number of classes.
out_channels (int): Output channels.
Returns:
Tensor, output tensor.
"""
def __init__(self, num_classes, out_channels):
super(CommonHead, self).__init__()
self.avgpool = GlobalAvgPooling()
self.fc = nn.Dense(out_channels, num_classes, has_bias=True).add_flags_recursive(fp16=True)
def construct(self, x):
x = self.avgpool(x)
x = self.fc(x)
return x

View File

@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Image classifiation.
"""
import math
import mindspore.nn as nn
from mindspore.common import initializer as init
import src.backbone as backbones
import src.head as heads
from src.utils.var_init import default_recurisive_init, KaimingNormal
class ImageClassificationNetwork(nn.Cell):
"""
architecture of image classification network.
Args:
Returns:
Tensor, output tensor.
"""
def __init__(self, backbone, head, include_top=True, activation="None"):
super(ImageClassificationNetwork, self).__init__()
self.backbone = backbone
self.include_top = include_top
self.need_activation = False
if self.include_top:
self.head = head
if activation != "None":
self.need_activation = True
if activation == "Sigmoid":
self.activation = P.Sigmoid()
elif activation == "Softmax":
self.activation = P.Softmax()
else:
raise NotImplementedError(
f"The activation {activation} not in [Sigmoid, Softmax].")
def construct(self, x):
x = self.backbone(x)
if self.include_top:
x = self.head(x)
if self.need_activation:
x = self.activation(x)
return x
class Resnet(ImageClassificationNetwork):
"""
Resnet architecture.
Args:
backbone_name (string): backbone.
num_classes (int): number of classes, Default is 1000.
Returns:
Resnet.
"""
def __init__(self, backbone_name, num_classes=1000, platform="Ascend", include_top=True, activation="None"):
self.backbone_name = backbone_name
backbone = backbones.__dict__[self.backbone_name](platform=platform)
out_channels = backbone.get_out_channels()
head = heads.CommonHead(num_classes=num_classes,
out_channels=out_channels)
super(Resnet, self).__init__(backbone, head, include_top, activation)
default_recurisive_init(self)
for cell in self.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(
KaimingNormal(a=math.sqrt(5), mode='fan_out',
nonlinearity='relu'),
cell.weight.shape, cell.weight.dtype))
elif isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(init.initializer('zeros', cell.beta.shape))
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
for cell in self.cells_and_names():
if isinstance(cell, backbones.resnet.Bottleneck):
cell.bn3.gamma.set_data(init.initializer(
'zeros', cell.bn3.gamma.shape))
elif isinstance(cell, backbones.resnet.BasicBlock):
cell.bn2.gamma.set_data(init.initializer(
'zeros', cell.bn2.gamma.shape))
def get_network(**kwargs):
return Resnet('resnext152', **kwargs)

View File

@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
learning rate generator.
"""
import math
from collections import Counter
import numpy as np
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
"""
Applies liner Increasing to generate learning rate array in warmup stage.
Args:
current_step(int): current step in warmup stage.
warmup_steps(int): all steps in warmup stage.
base_lr(float): init learning rate.
init_lr(float): end learning rate
Returns:
float, learning rate.
"""
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
"""
Applies cosine decay to generate learning rate array with warmup.
Args:
lr(float): init learning rate
steps_per_epoch(int): steps of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epoch of training
T_max(int): max epoch in decay.
eta_min(float): end learning rate
Returns:
np.array, learning rate array.
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
lr_each_step = []
for i in range(total_steps):
last_epoch = i // steps_per_epoch
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
"""
Applies step decay to generate learning rate array with warmup.
Args:
lr(float): init learning rate
lr_epochs(list): learning rate decay epoches list
steps_per_epoch(int): steps of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epoch of training
gamma(float): attenuation constants.
Returns:
np.array, learning rate array.
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
milestones = lr_epochs
milestones_steps = []
for milestone in milestones:
milestones_step = milestone * steps_per_epoch
milestones_steps.append(milestones_step)
lr_each_step = []
lr = base_lr
milestones_steps_counter = Counter(milestones_steps)
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
lr = lr * gamma**milestones_steps_counter[i]
lr_each_step.append(lr)
return np.array(lr_each_step).astype(np.float32)
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
lr_epochs = []
for i in range(1, max_epoch):
if i % epoch_size == 0:
lr_epochs.append(i)
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
def get_lr(args):
"""generate learning rate array."""
if args.lr_scheduler == 'exponential':
lr = warmup_step_lr(args.lr,
args.lr_epochs,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
gamma=args.lr_gamma,
)
elif args.lr_scheduler == 'cosine_annealing':
lr = warmup_cosine_annealing_lr(args.lr,
args.steps_per_epoch,
args.warmup_epochs,
args.max_epoch,
args.T_max,
args.eta_min)
else:
raise NotImplementedError(args.lr_scheduler)
return lr

View File

@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluation metric."""
from mindspore.communication.management import GlobalComm
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.common.dtype as mstype
class ClassifyCorrectCell(nn.Cell):
r"""
Cell that returns correct count of the prediction in classification network.
This Cell accepts a network as arguments.
It returns orrect count of the prediction to calculate the metrics.
Args:
network (Cell): The network Cell.
Inputs:
- **data** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
- **label** (Tensor) - Tensor of shape :math:`(N, \ldots)`.
Outputs:
Tuple, containing a scalar correct count of the prediction
Examples:
>>> # For a defined network Net without loss function
>>> net = Net()
>>> eval_net = nn.ClassifyCorrectCell(net)
"""
def __init__(self, network):
super(ClassifyCorrectCell, self).__init__(auto_prefix=False)
self._network = network
self.argmax = P.Argmax()
self.equal = P.Equal()
self.cast = P.Cast()
self.reduce_sum = P.ReduceSum()
self.allreduce = P.AllReduce(P.ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
def construct(self, data, label):
outputs = self._network(data)
y_pred = self.argmax(outputs)
y_pred = self.cast(y_pred, mstype.int32)
y_correct = self.equal(y_pred, label)
y_correct = self.cast(y_correct, mstype.float32)
y_correct = self.reduce_sum(y_correct)
total_correct = self.allreduce(y_correct)
return (total_correct,)
class DistAccuracy(nn.Metric):
r"""
Calculates the accuracy for classification data in distributed mode.
The accuracy class creates two local variables, correct number and total number that are used to compute the
frequency with which predictions matches labels. This frequency is ultimately returned as the accuracy: an
idempotent operation that simply divides correct number by total number.
.. math::
\text{accuracy} =\frac{\text{true_positive} + \text{true_negative}}
{\text{true_positive} + \text{true_negative} + \text{false_positive} + \text{false_negative}}
Args:
eval_type (str): Metric to calculate the accuracy over a dataset, for classification (single-label).
Examples:
>>> y_correct = Tensor(np.array([20]))
>>> metric = nn.DistAccuracy(batch_size=3, device_num=8)
>>> metric.clear()
>>> metric.update(y_correct)
>>> accuracy = metric.eval()
"""
def __init__(self, batch_size, device_num):
super(DistAccuracy, self).__init__()
self.clear()
self.batch_size = batch_size
self.device_num = device_num
def clear(self):
"""Clears the internal evaluation result."""
self._correct_num = 0
self._total_num = 0
def update(self, *inputs):
"""
Updates the internal evaluation result :math:`y_{pred}` and :math:`y`.
Args:
inputs: Input `y_correct`. `y_correct` is a `scalar Tensor`.
`y_correct` is the right prediction count that gathered from all devices
it's a scalar in float type
Raises:
ValueError: If the number of the input is not 1.
"""
if len(inputs) != 1:
raise ValueError('Distribute accuracy needs 1 input (y_correct), but got {}'.format(len(inputs)))
y_correct = self._convert_data(inputs[0])
self._correct_num += y_correct
self._total_num += self.batch_size * self.device_num
def eval(self):
"""
Computes the accuracy.
Returns:
Float, the computed result.
Raises:
RuntimeError: If the sample size is 0.
"""
if self._total_num == 0:
raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num

View File

@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype
class OutputTo(nn.Cell):
"Cast cell output back to float16 or float32"
def __init__(self, op, to_type=mstype.float16):
super(OutputTo, self).__init__(auto_prefix=False)
self._op = op
validator.check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None)
self.to_type = to_type
def construct(self, x):
return F.cast(self._op(x), self.to_type)
def auto_mixed_precision(network):
"""Do keep batchnorm fp32."""
cells = network.name_cells()
change = False
network.to_float(mstype.float16)
for name in cells:
subcell = cells[name]
if subcell == network:
continue
elif name == 'fc':
network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
change = True
elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
change = True
else:
auto_mixed_precision(subcell)
if isinstance(network, nn.SequentialCell) and change:
network.cell_list = list(network.cells())

View File

@ -0,0 +1,105 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network operations
"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
class GlobalAvgPooling(nn.Cell):
"""
global average pooling feature map.
Args:
mean (tuple): means for each channel.
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(False)
def construct(self, x):
x = self.mean(x, (2, 3))
return x
class SEBlock(nn.Cell):
"""
squeeze and excitation block.
Args:
channel (int): number of feature maps.
reduction (int): weight.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = GlobalAvgPooling()
self.fc1 = nn.Dense(channel, channel // reduction)
self.relu = P.ReLU()
self.fc2 = nn.Dense(channel // reduction, channel)
self.sigmoid = P.Sigmoid()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.sum = P.Sum()
self.cast = P.Cast()
def construct(self, x):
"""describe network construct"""
b, c = self.shape(x)
y = self.avg_pool(x)
y = self.reshape(y, (b, c))
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y)
y = self.reshape(y, (b, c, 1, 1))
return x * y
class GroupConv(nn.Cell):
"""
group convolution operation.
Args:
in_channels (int): Input channels of feature map.
out_channels (int): Output channels of feature map.
kernel_size (int): Size of convolution kernel.
stride (int): Stride size for the group convolution layer.
Returns:
tensor, output tensor.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
super(GroupConv, self).__init__()
assert in_channels % groups == 0 and out_channels % groups == 0
self.groups = groups
self.convs = nn.CellList()
self.op_split = P.Split(axis=1, output_num=self.groups)
self.op_concat = P.Concat(axis=1)
self.cast = P.Cast()
for _ in range(groups):
self.convs.append(nn.Conv2d(in_channels//groups, out_channels//groups,
kernel_size=kernel_size, stride=stride, has_bias=has_bias,
padding=pad, pad_mode=pad_mode, group=1))
def construct(self, x):
features = self.op_split(x)
outputs = ()
for i in range(self.groups):
outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),)
out = self.op_concat(outputs)
return out

View File

@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime
class LOGGER(logging.Logger):
"""
set up logging file.
Args:
logger_name (string): logger name.
log_dir (string): path of logger.
Returns:
string, logger path
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""set up log file"""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
logger = LOGGER("mindversion", rank)
logger.setup_logging_file(path, rank)
return logger

View File

@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
optimizer parameters.
"""
def get_param_groups(network):
"""get param groups"""
decay_params = []
no_decay_params = []
for x in network.trainable_params():
parameter_name = x.name
if parameter_name.endswith('.bias'):
# all bias not using weight decay
no_decay_params.append(x)
elif parameter_name.endswith('.gamma'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
elif parameter_name.endswith('.beta'):
# bn weight bias not using weight decay, be carefully for now x not include BN
no_decay_params.append(x)
else:
decay_params.append(x)
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]

View File

@ -0,0 +1,53 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
choose samples from the dataset
"""
import math
import numpy as np
class DistributedSampler():
"""
sampling the dataset.
Args:
Returns:
num_samples, number of samples.
"""
def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
self.dataset = dataset
self.rank = rank
self.group_size = group_size
self.dataset_length = len(self.dataset)
self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
self.total_size = self.num_samples * self.group_size
self.shuffle = shuffle
self.seed = seed
def __iter__(self):
if self.shuffle:
self.seed = (self.seed + 1) & 0xffffffff
np.random.seed(self.seed)
indices = np.random.permutation(self.dataset_length).tolist()
else:
indices = list(range(len(self.dataset_length)))
indices += indices[:(self.total_size - len(indices))]
indices = indices[self.rank::self.group_size]
return iter(indices)
def __len__(self):
return self.num_samples

View File

@ -0,0 +1,228 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import os
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore.common import initializer as init
from mindspore.train.serialization import load_checkpoint, load_param_into_net
def _calculate_gain(nonlinearity, param=None):
r"""
Return the recommended gain value for the given nonlinearity function.
The values are as follows:
================= ====================================================
nonlinearity gain
================= ====================================================
Linear / Identity :math:`1`
Conv{1,2,3}D :math:`1`
Sigmoid :math:`1`
Tanh :math:`\frac{5}{3}`
ReLU :math:`\sqrt{2}`
Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
================= ====================================================
Args:
nonlinearity: the non-linear function
param: optional parameter for the non-linear function
Examples:
>>> gain = calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
return 1
if nonlinearity == 'tanh':
return 5.0 / 3
if nonlinearity == 'relu':
return math.sqrt(2.0)
if nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope ** 2))
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
def _assignment(arr, num):
"""Assign the value of `num` to `arr`."""
if arr.shape == ():
arr = arr.reshape((1))
arr[:] = num
arr = arr.reshape(())
else:
if isinstance(num, np.ndarray):
arr[:] = num[:]
else:
arr[:] = num
return arr
def _calculate_in_and_out(arr):
"""
Calculate n_in and n_out.
Args:
arr (Array): Input array.
Returns:
Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
"""
dim = len(arr.shape)
if dim < 2:
raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")
n_in = arr.shape[1]
n_out = arr.shape[0]
if dim > 2:
counter = reduce(lambda x, y: x * y, arr.shape[2:])
n_in *= counter
n_out *= counter
return n_in, n_out
def _select_fan(array, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_in_and_out(array)
return fan_in if mode == 'fan_in' else fan_out
class KaimingInit(init.Initializer):
r"""
Base Class. Initialize the array with He kaiming algorithm.
Args:
a: the negative slope of the rectifier used after this layer (only
used with ``'leaky_relu'``)
mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
preserves the magnitude of the variance of the weights in the
forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
backwards pass.
nonlinearity: the non-linear function, recommended to use only with
``'relu'`` or ``'leaky_relu'`` (default).
"""
def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
super(KaimingInit, self).__init__()
self.mode = mode
self.gain = _calculate_gain(nonlinearity, a)
def _initialize(self, arr):
pass
class KaimingUniform(KaimingInit):
r"""
Initialize the array with He kaiming uniform algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
.. math::
\text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
data = np.random.uniform(-bound, bound, arr.shape)
_assignment(arr, data)
class KaimingNormal(KaimingInit):
r"""
Initialize the array with He kaiming normal algorithm. The resulting tensor will
have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
Input:
arr (Array): The array to be assigned.
Returns:
Array, assigned array.
Examples:
>>> w = np.empty(3, 5)
>>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
"""
def _initialize(self, arr):
fan = _select_fan(arr, self.mode)
std = self.gain / math.sqrt(fan)
data = np.random.normal(0, std, arr.shape)
_assignment(arr, data)
def default_recurisive_init(custom_cell):
"""default_recurisive_init"""
for _, cell in custom_cell.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, nn.Dense):
cell.weight.set_data(init.initializer(KaimingUniform(a=math.sqrt(5)),
cell.weight.shape,
cell.weight.dtype))
if cell.bias is not None:
fan_in, _ = _calculate_in_and_out(cell.weight)
bound = 1 / math.sqrt(fan_in)
cell.bias.set_data(init.initializer(init.Uniform(bound),
cell.bias.shape,
cell.bias.dtype))
elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
pass
def load_pretrain_model(ckpt_file, network, args):
"""load pretrain model."""
if os.path.isfile(ckpt_file):
param_dict = load_checkpoint(ckpt_file)
param_dict_new = {}
for key, values in param_dict.items():
if key.startswith('moments.'):
continue
elif key.startswith('network.'):
param_dict_new[key[8:]] = values
else:
param_dict_new[key] = values
load_param_into_net(network, param_dict_new)
args.logger.info('load model {} success'.format(ckpt_file))

View File

@ -0,0 +1,331 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train ImageNet."""
import os
import time
import ast
import argparse
import datetime
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.context import ParallelMode
from mindspore.nn.optim import Momentum
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, Callback
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
from mindspore.common import set_seed
from src.dataset import classification_dataset
from src.crossentropy import CrossEntropy
from src.lr_generator import get_lr
from src.utils.logging import get_logger
from src.utils.optimizers__init__ import get_param_groups
from src.utils.var_init import load_pretrain_model
from src.image_classification import get_network
from src.config import config
from src.eval_callback import EvalCallBack
set_seed(1)
class BuildTrainNetwork(nn.Cell):
"""build training network"""
def __init__(self, network, criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
def construct(self, input_data, label):
output = self.network(input_data)
loss = self.criterion(output, label)
return loss
class ProgressMonitor(Callback):
"""monitor loss and time"""
def __init__(self, args):
super(ProgressMonitor, self).__init__()
self.me_epoch_start_time = 0
self.me_epoch_start_step_num = 0
self.args = args
self.ckpt_history = []
def begin(self, run_context):
self.args.logger.info('start network train...')
def epoch_begin(self, run_context):
pass
def epoch_end(self, run_context, *me_args):
"""describe network construct"""
cb_params = run_context.original_args()
me_step = cb_params.cur_step_num - 1
real_epoch = me_step // self.args.steps_per_epoch
time_used = time.time() - self.me_epoch_start_time
fps_mean = (self.args.per_batch_size * (me_step-self.me_epoch_start_step_num))
fps_mean = fps_mean * self.args.group_size
fps_mean = fps_mean / time_used
self.args.logger.info('epoch[{}], iter[{}], loss:{}, '
'mean_fps:{:.2f}'
'imgs/sec'.format(real_epoch,
me_step,
cb_params.net_outputs,
fps_mean))
if self.args.rank_save_ckpt_flag:
import glob
ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
for ckpt in ckpts:
ckpt_fn = os.path.basename(ckpt)
if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
continue
if ckpt in self.ckpt_history:
continue
self.ckpt_history.append(ckpt)
self.args.logger.info('epoch[{}], iter[{}], loss:{}, '
'ckpt:{},'
'ckpt_fn:{}'.format(real_epoch,
me_step,
cb_params.net_outputs,
ckpt,
ckpt_fn))
self.me_epoch_start_step_num = me_step
self.me_epoch_start_time = time.time()
def step_begin(self, run_context):
pass
def step_end(self, run_context, *me_args):
pass
def end(self, run_context):
self.args.logger.info('end network train...')
def parse_args(cloud_args=None):
"""parameters"""
parser = argparse.ArgumentParser('mindspore classification training')
parser.add_argument('--platform', type=str, default='Ascend',
choices=('Ascend', 'GPU'), help='run platform')
# dataset related
parser.add_argument('--data_dir', type=str, default='', help='train data dir')
parser.add_argument('--per_batch_size', default=128, type=int, help='batch size for per gpu')
# network related
parser.add_argument('--pretrained',
default='',
type=str,
help='model_path, local pretrained model to load')
# distributed related
parser.add_argument('--is_distributed', type=int, default=1, help='if multi device')
# roma obs
parser.add_argument('--train_url', type=str, default="", help='train url')
#new argument
parser.add_argument("--eval_interval", type=int, default=1,
help="Evaluation interval when run_eval is True, default is 1.")
parser.add_argument("--eval_start_epoch", type=int, default=120,
help="Evaluation start epoch when run_eval is True, default is 120.")
parser.add_argument("--save_best_ckpt", type=ast.literal_eval, default=True,
help="Save best checkpoint when run_eval is True, default is True.")
#dataset of eval dataset
parser.add_argument('--eval_data_dir',
type=str,
default='/opt/npu/pvc/dataset/storage/imagenet/val',
help='eval data dir')
parser.add_argument('--eval_per_batch_size',
default=32,
type=int,
help='batch size for per npu')
parser.add_argument("--run_eval",
type=ast.literal_eval,
default=True,
help="Run evaluation when training, default is True.")
#best ckpt
parser.add_argument('--eval_log_path',
type=str,
default='eval_outputs/',
help='path to save log')
parser.add_argument('--eval_is_distributed',
type=int,
default=0,
help='if multi device')
args, _ = parser.parse_known_args()
args = merge_args(args, cloud_args)
args.image_size = config['image_size']
args.num_classes = config['num_classes']
args.lr = config['lr']
args.lr_scheduler = config['lr_scheduler']
args.lr_epochs = config['lr_epochs']
args.lr_gamma = config['lr_gamma']
args.eta_min = config['eta_min']
args.T_max = config['T_max']
args.max_epoch = config['max_epoch']
args.warmup_epochs = config['warmup_epochs']
args.weight_decay = config['weight_decay']
args.momentum = config['momentum']
args.is_dynamic_loss_scale = config['is_dynamic_loss_scale']
args.loss_scale = config['loss_scale']
args.label_smooth = config['label_smooth']
args.label_smooth_factor = config['label_smooth_factor']
args.ckpt_interval = config['ckpt_interval']
args.ckpt_save_max = config['ckpt_save_max']
args.ckpt_path = config['ckpt_path']
args.is_save_on_master = config['is_save_on_master']
args.rank = config['rank']
args.group_size = config['group_size']
args.lr_epochs = list(map(int, args.lr_epochs.split(',')))
args.image_size = list(map(int, args.image_size.split(',')))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
device_target=args.platform, save_graphs=False)
# init distributed
if args.is_distributed:
init()
args.rank = get_rank()
args.group_size = get_group_size()
else:
args.rank = 0
args.group_size = 1
if args.is_dynamic_loss_scale == 1:
args.loss_scale = 1 # for dynamic loss scale can not set loss scale in momentum opt
# select for master rank save ckpt or all rank save, compatible for model parallel
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if args.rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
args.logger = get_logger(args.outputs_dir, args.rank)
return args
def merge_args(args, cloud_args):
"""dictionary"""
args_dict = vars(args)
if isinstance(cloud_args, dict):
for key in cloud_args.keys():
val = cloud_args[key]
if key in args_dict and val:
arg_type = type(args_dict[key])
if arg_type is not type(None):
val = arg_type(val)
args_dict[key] = val
return args
def apply_eval(eval_param):
eval_model = eval_param["model"]
eval_ds = eval_param["dataset"]
metrics_name = eval_param["metrics_name"]
res = eval_model.eval(eval_ds)
return res[metrics_name]
def train(cloud_args=None):
"""training process"""
args = parse_args(cloud_args)
if os.getenv('DEVICE_ID', "not_set").isdigit():
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
# init distributed
if args.is_distributed:
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
gradients_mean=True)
# dataloader
de_dataset = classification_dataset(args.data_dir, args.image_size,
args.per_batch_size, 1,
args.rank, args.group_size, num_parallel_workers=8)
de_dataset.map_model = 4 # !!!important
args.steps_per_epoch = de_dataset.get_dataset_size()
#eval_dataset
args.logger.save_args(args)
# network
args.logger.important_info('start create network')
# get network and init
network = get_network(num_classes=args.num_classes, platform=args.platform)
load_pretrain_model(args.pretrained, network, args)
# lr scheduler
lr = get_lr(args)
# optimizer
opt = Momentum(params=get_param_groups(network),
learning_rate=Tensor(lr),
momentum=args.momentum,
weight_decay=args.weight_decay,
loss_scale=args.loss_scale)
# loss
if not args.label_smooth:
args.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
if args.is_dynamic_loss_scale == 1:
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536,
scale_factor=2,
scale_window=2000)
else:
loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
metrics={'acc'}, amp_level="O3")
# checkpoint save
progress_cb = ProgressMonitor(args)
callbacks = [progress_cb,]
#eval dataset
if args.eval_data_dir is None or (not os.path.isdir(args.eval_data_dir)):
raise ValueError("{} is not a existing path.".format(args.eval_data_dir))
#code like eval.py
#if run eval
if args.run_eval:
if args.eval_data_dir is None or (not os.path.isdir(args.eval_data_dir)):
raise ValueError("{} is not a existing path.".format(args.eval_data_dir))
eval_de_dataset = classification_dataset(args.eval_data_dir,
image_size=args.image_size,
per_batch_size=args.eval_per_batch_size,
max_epoch=1,
rank=args.rank,
group_size=args.group_size,
mode='eval')
eval_param_dict = {"model": model, "dataset": eval_de_dataset, "metrics_name": "acc"}
eval_callback = EvalCallBack(apply_eval,
eval_param_dict,
interval=args.eval_interval,
eval_start_epoch=args.eval_start_epoch,
save_best_ckpt=args.save_best_ckpt,
ckpt_directory=args.ckpt_path,
best_ckpt_name="best_acc.ckpt",
metrics_name="acc"
)
callbacks.append(eval_callback)
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
keep_checkpoint_max=args.ckpt_save_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix='{}'.format(args.rank))
callbacks.append(ckpt_cb)
model.train(args.max_epoch, de_dataset, callbacks=callbacks, dataset_sink_mode=True)
if __name__ == "__main__":
train()