!16774 push genet code

From: @cuihulan
Reviewed-by: @c_34, @oacjiewen
Signed-off-by: @c_34
This commit is contained in:
mindspore-ci-bot 2021-05-29 12:10:14 +08:00 committed by Gitee
commit a1e6d6a8a5
13 changed files with 1656 additions and 0 deletions

View File: Readme.md

@@ -0,0 +1,270 @@
# Contents
<!-- TOC -->
- [Contents](#contents)
- [GENet_Res50 Description](#genet_res50-description)
    - [Model Architecture](#model-architecture)
    - [Dataset](#dataset)
    - [Features](#features)
        - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Usage](#usage)
        - [Launch](#launch)
        - [Result](#result)
    - [Evaluation Process](#evaluation-process)
        - [Usage](#usage-1)
        - [Launch](#launch-1)
        - [Result](#result-1)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# GENet_Res50 Description
GENet_Res50 is a convolutional neural network that builds GEBlocks on top of ResNet50. It classifies ImageNet images into 1000 object classes with a top-1 accuracy of 78.47%.
[Paper](https://arxiv.org/abs/1810.12348)
## Model Architecture
In the code, setting `extra` to False gives the GEθ- structure; with `extra` set to True, `mlp=False` gives the GEθ structure and `mlp=True` gives the GEθ+ structure.
The overall network architecture of GENet_Res50 is described in:
[Link](https://arxiv.org/abs/1810.12348)
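The mapping between the two flags and the three GE variants can be made concrete with the `GE_resnet50` constructor from `src/GENet.py` (a minimal sketch; it assumes the repository root is on `PYTHONPATH`):

```python
from src.GENet import GE_resnet50

# GEθ-: global-average-pool gather, no extra parameters, no MLP excite
net_theta_minus = GE_resnet50(class_num=1000, extra=False, mlp=False)
# GEθ: depth-wise conv gather, plain sigmoid excite
net_theta = GE_resnet50(class_num=1000, extra=True, mlp=False)
# GEθ+: depth-wise conv gather plus a 1x1-conv MLP excite (best accuracy)
net_theta_plus = GE_resnet50(class_num=1000, extra=True, mlp=True)
```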
## Dataset
Dataset used: [ImageNet 2017](http://www.image-net.org/)
ImageNet 2017 is identical to the ImageNet 2012 dataset.
- Dataset size: 144 GB, 1.25 million color images in 1000 classes
    - Training set: 138 GB, 1.2 million images
    - Test set: 6 GB, 50,000 images
- Data format: RGB
- Note: Data is processed in src/dataset.py.
## Features
### Mixed Precision
[Mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training uses both single-precision and half-precision data to speed up deep neural network training while preserving the accuracy that single-precision training achieves. It increases computing speed and reduces memory usage, enabling larger models or larger batch sizes on supported hardware.
Taking FP16 operators as an example, if the input data type is FP32, the MindSpore backend automatically reduces the precision to process the data. To check which operators had their precision reduced, enable the INFO log and search for "reduce precision".
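In this repository, mixed precision is enabled where `train.py` builds the `Model`. A self-contained sketch of the same pattern (a toy `nn.Dense` network stands in for GENet_Res50 here):

```python
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

net = nn.Dense(10, 2)  # placeholder network, for illustration only
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=1024)
# amp_level="O2" casts the network to FP16 (train.py also sets keep_batchnorm_fp32=False);
# a fixed loss scale guards the FP16 gradients against underflow
loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
              metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)
```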
# Environment Requirements
- Hardware (Ascend)
    - Set up the hardware environment with Ascend processors.
- Framework
    - [MindSpore](https://www.mindspore.cn/install)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# Script Description
## Script and Sample Code
```text
├── GENet_Res50
  ├── Readme.md
  ├── scripts
  │   ├── run_distribute_train.sh    # shell script for 8-device training on Ascend
  │   ├── run_train.sh               # shell script for single-device training on Ascend
  │   ├── run_eval.sh                # shell script for single-device evaluation on Ascend
  ├── src
  │   ├── config.py                  # parameter configuration
  │   ├── dataset.py                 # dataset creation
  │   ├── lr_generator.py            # learning rate generation
  │   ├── CrossEntropySmooth.py      # cross-entropy loss for GENet_Res50
  │   ├── GENet.py                   # GENet_Res50 network model
  │   ├── GEBlock.py                 # block model for GENet_Res50
  ├── train.py                       # training script
  ├── eval.py                        # evaluation script
  ├── export.py                      # export script
```
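The repository also includes `export.py` for exporting a trained checkpoint to a deployable model. A typical invocation (the checkpoint path is illustrative) looks like:

```shell
python export.py --pre_trained ./checkpoints/GENet-150_625.ckpt --mlp True --extra True
```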
## Script Parameters
Training and evaluation parameters can both be configured in config.py.
- Configuration for GENet_Res50 on the ImageNet2012 dataset:
```python
"class_num": 1000,
"batch_size": 256,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 150,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 5,
"decay_mode": "linear",
"save_checkpoint_path": "./checkpoints",
"hold_epochs": 0,
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.8,
"lr_end": 0.0
```
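Three such configs (`config1`/`config2`/`config3`) are defined in `src/config.py`, and `train.py`/`eval.py` pick one based on the `--extra` and `--mlp` flags:

```python
if args_opt.extra.lower() == "false":
    from src.config import config3 as config      # GEθ-
else:
    if args_opt.mlp.lower() == "false":
        from src.config import config2 as config  # GEθ
    else:
        from src.config import config1 as config  # GEθ+
```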
## Training Process
### Usage
- Ascend:
```shell
# distributed training (8 devices):
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [MLP] [EXTRA] [PRETRAINED_CKPT_PATH](optional)
# standalone training (single device):
bash run_train.sh [DATASET_PATH] [MLP] [EXTRA] [DEVICE_ID] [PRETRAINED_CKPT_PATH](optional)
```
### Launch
```shell
# training examples
# 8 devices:
bash run_distribute_train.sh ~/hccl_8p_01234567_127.0.0.1.json /data/imagenet/imagenet_original/train True True
# single device:
bash run_train.sh /data/imagenet/imagenet_original/train True True 5
```
### Result
Distributed (8-device) training results are saved in the example path. Checkpoints are saved in `./train_parallel$i/` by default and the training log is redirected to `./train_parallel$i/log`; single-device results are saved under `./train_standalone`. The log looks like:
```text
epoch: 1 step: 5000, loss is 4.8995576
epoch: 2 step: 5000, loss is 3.9235563
epoch: 3 step: 5000, loss is 3.833077
epoch: 4 step: 5000, loss is 3.2795618
epoch: 5 step: 5000, loss is 3.1978393
```
## Evaluation Process
### Usage
You can start evaluation with the shell script. Usage:
- Ascend: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [MLP] [EXTRA] [DEVICE_ID]
### Launch
```shell
# evaluation example
bash run_eval.sh ~/imagenet/val/ ~/train/GENet-150_625.ckpt True True 0
```
> Checkpoints are generated during training.
### Result
Evaluation results are saved in the example path. You can find results like the following in `./eval/log`:
```text
result: {'top_5_accuracy': 0.9412860576923077, 'top_1_accuracy': 0.7847355769230769}
```
# Model Description
## Performance
### Training Performance

| Parameters | GENet_Res50 θ- version (mlp & extra = False) |
| -------------------------- | ---------------------------------------------------------- |
| Model Version | V1 |
| Resources | Ascend 910, 8 devices; CPU 2.60 GHz, 192 cores; memory 2048 GB; OS Euler2.8 |
| Uploaded Date | 2021-04-26 |
| MindSpore Version | 1.1.1 |
| Dataset | ImageNet |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | SoftmaxCrossEntropy |
| Outputs | ckpt file |
| Loss | 1.6 |
| Accuracy | 77.8% |
| Total Time | 8 h |
| Parameters (M) | batch_size=256, epoch=220 |
| Checkpoint for Fine-tuning | |
| Inference Model | |

| Parameters | GENet_Res50 θ version (mlp=False & extra=True) |
| -------------------------- | ---------------------------------------------------------- |
| Model Version | V1 |
| Resources | Ascend 910, 8 devices; CPU 2.60 GHz, 192 cores; memory 2048 GB; OS Euler2.8 |
| Uploaded Date | 2021-04-26 |
| MindSpore Version | 1.1.1 |
| Dataset | ImageNet |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | SoftmaxCrossEntropy |
| Outputs | ckpt file |
| Loss | 1.6 |
| Accuracy | 78% |
| Total Time | 19 h |
| Parameters (M) | batch_size=256, epoch=150 |
| Checkpoint for Fine-tuning | |
| Inference Model | |

| Parameters | GENet_Res50 θ+ version (mlp=True & extra=True) |
| -------------------------- | ---------------------------------------------------------- |
| Model Version | V1 |
| Resources | Ascend 910, 8 devices; CPU 2.60 GHz, 192 cores; memory 2048 GB; OS Euler2.8 |
| Uploaded Date | 2021-04-26 |
| MindSpore Version | 1.1.1 |
| Dataset | ImageNet |
| Training Parameters | src/config.py |
| Optimizer | Momentum |
| Loss Function | SoftmaxCrossEntropy |
| Outputs | ckpt file |
| Loss | 1.6 |
| Accuracy | 78.47% |
| Total Time | 19 h |
| Parameters (M) | batch_size=256, epoch=150 |
| Checkpoint for Fine-tuning | |
| Inference Model | |
### Evaluation Performance

| Parameters | GENet |
| -------------------------- | ----------------------------- |
| Model Version | V1 |
| Resources | Ascend 910; OS Euler2.8 |
| Uploaded Date | 2021-04-26 |
| MindSpore Version | 1.1.1 |
| Dataset | ImageNet 2012 |
| batch_size | 256 (1 device) |
| Outputs | probability |
| Accuracy | θ-: ACC1 77.8%; θ: ACC1 78%; θ+: ACC1 78.47% |
| Speed | |
| Total Time | 3 min |
| Inference Model | |
# Description of Random Situation
The seed used inside the "create_dataset" function is set in dataset.py, and train.py also sets a global random seed (set_seed).
# ModelZoo Homepage
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File: eval.py

@@ -0,0 +1,120 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train GENet."""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.CrossEntropySmooth import CrossEntropySmooth
from src.GENet import GE_resnet50 as Net
from src.dataset import create_dataset
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
parser.add_argument('--device_target', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
help="Device target, support Ascend, GPU and CPU.")
parser.add_argument('--extra', type=str, default="False",
                    help='whether to use a depth-wise conv to down-sample in the gather step')
parser.add_argument('--mlp', type=str, default="True",
                    help='whether to use a 1x1 conv MLP (SE-style bottleneck) in the excite step')
parser.add_argument('--is_modelarts', type=str, default="False", help='is train on modelarts')
args_opt = parser.parse_args()
if args_opt.extra.lower() == "false":
from src.config import config3 as config
else:
if args_opt.mlp.lower() == "false":
from src.config import config2 as config
else:
from src.config import config1 as config
if args_opt.is_modelarts == "True":
import moxing as mox
set_seed(1)
def trans_char_to_bool(str_):
    """Convert a command-line flag string ("true"/"false") to a bool."""
    return str_.lower() == "true"
if __name__ == '__main__':
target = args_opt.device_target
local_data_url = args_opt.data_url
local_pretrained_url = args_opt.checkpoint_path
if args_opt.is_modelarts == "True":
local_data_url = "/cache/data"
mox.file.copy_parallel(args_opt.data_url, local_data_url)
local_pretrained_path = "/cache/pretrained"
mox.file.make_dirs(local_pretrained_path)
filename = "pretrained.ckpt"
local_pretrained_url = os.path.join(local_pretrained_path, filename)
mox.file.copy(args_opt.checkpoint_path, local_pretrained_url)
# init context
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False)
if target == "Ascend":
        device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(device_id=device_id)
# create dataset
dataset = create_dataset(dataset_path=local_data_url,
do_train=False,
batch_size=config.batch_size,
target=target)
step_size = dataset.get_dataset_size()
    # parse the string flags into bools
mlp = trans_char_to_bool(args_opt.mlp)
extra = trans_char_to_bool(args_opt.extra)
# define net
net = Net(class_num=config.class_num, extra=extra, mlp=mlp)
# load checkpoint
param_dict = load_checkpoint(local_pretrained_url)
load_param_into_net(net, param_dict)
net.set_train(False)
# define loss, model
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True,
reduction='mean',
smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)
# define model
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
# eval model
res = model.eval(dataset)
print("result:", res, "ckpt=", args_opt.checkpoint_path)

View File: export.py

@@ -0,0 +1,64 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export GENet_Res50 on ImageNet"""
import argparse
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.GENet import GE_resnet50 as net
from src.config import config1 as config
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
help="Device target, support Ascend, GPU and CPU.")
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--extra', type=str, default="True",
                    help='whether to use a depth-wise conv to down-sample in the gather step')
parser.add_argument('--mlp', type=str, default="True",
                    help='whether to use a 1x1 conv MLP (SE-style bottleneck) in the excite step')
args_opt = parser.parse_args()
def trans_char_to_bool(str_):
    """Convert a command-line flag string ("true"/"false") to a bool."""
    return str_.lower() == "true"
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
save_graphs=False)
# define fusion network
mlp = trans_char_to_bool(args_opt.mlp)
extra = trans_char_to_bool(args_opt.extra)
network = net(class_num=config.class_num, extra=extra, mlp=mlp)
# load checkpoint
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
not_load_param = load_param_into_net(network, param_dict)
if not_load_param:
raise ValueError("Load param into network fail!")
# export network
print("============== Starting export ==============")
    # dummy NCHW input; cast to float32 since np.ones defaults to float64
    inputs = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32))
export(network, inputs, file_name="GENet_Res50")
print("============== End export ==============")

View File: scripts/run_distribute_train.sh

@@ -0,0 +1,87 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH] [MLP] [EXTRA] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ $# == 5 ]
then
PATH3=$(get_real_path $5)
fi
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]
then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
if [ $# == 5 ] && [ ! -f $PATH3 ]
then
echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
exit 1
fi
export SERVER_ID=0
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
rank_start=$((DEVICE_NUM * SERVER_ID))
first_device=0
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$((first_device+i))
export RANK_ID=$((rank_start + i))
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
if [ $# == 4 ]
then
python train.py --data_url=$PATH2 --mlp=$3 --extra=$4 &> log &
fi
if [ $# == 5 ]
then
python train.py --data_url=$PATH2 --mlp=$3 --extra=$4 --pre_trained=$PATH3 &> log &
fi
cd ..
done

View File: scripts/run_eval.sh

@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ]
then
echo "Usage: bash run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [MLP] [EXTRA] [DEVICE_ID]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=$5
export RANK_SIZE=1
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --data_url=$PATH1 --checkpoint_path=$PATH2 --mlp=$3 --extra=$4 &> log &
cd ..

View File: scripts/run_train.sh

@@ -0,0 +1,74 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 5 ] && [ $# != 4 ]
then
echo "Usage: bash run_train.sh [DATASET_PATH] [MLP] [EXTRA] [DEVICE_ID] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ $# == 5 ]
then
PATH2=$(get_real_path $5)
fi
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ $# == 5 ] && [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_FILE=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=1
export DEVICE_ID=$4
export RANK_ID=0
rm -rf ./train_standalone
mkdir ./train_standalone
cp ../*.py ./train_standalone
cp *.sh ./train_standalone
cp -r ../src ./train_standalone
cd ./train_standalone || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
if [ $# == 4 ]
then
python train.py --data_url=$PATH1 --mlp=$2 --extra=$3 &> log &
fi
if [ $# == 5 ]
then
python train.py --data_url=$PATH1 --mlp=$2 --extra=$3 --pre_trained=$PATH2 &> log &
fi
cd ..

View File: src/CrossEntropySmooth.py

@@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
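        # smoothed targets: the true class gets 1 - smooth_factor, every other class gets smooth_factor / num_classes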
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / num_classes, mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

View File: src/GEBlock.py

@@ -0,0 +1,133 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" GEBlock."""
import mindspore.nn as nn
import mindspore as ms
from mindspore.ops import operations as P
class GEBlock(nn.Cell):
"""
Args:
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
spatial(int) : output_size of block
extra_params(bool) : Whether to use DW Conv to down-sample
mlp(bool) : Whether to combine SENet (using 1*1 conv)
Returns:
Tensor, output tensor.
Examples:
>>> GEBlock(3, 128, 2, 56, True, True)
"""
def __init__(self, in_channel, out_channel, stride, spatial, extra_params, mlp):
super().__init__()
expansion = 4
self.mlp = mlp
self.extra_params = extra_params
# middle channel num
channel = out_channel // expansion
self.conv1 = nn.Conv2dBnAct(in_channel, channel, kernel_size=1, stride=1,
has_bn=True, pad_mode="same", activation='relu')
self.conv2 = nn.Conv2dBnAct(channel, channel, kernel_size=3, stride=stride,
has_bn=True, pad_mode="same", activation='relu')
self.conv3 = nn.Conv2dBnAct(channel, out_channel, kernel_size=1, stride=1, pad_mode='same',
has_bn=True)
# whether down-sample identity
self.down_sample = False
if stride != 1 or in_channel != out_channel:
self.down_sample = True
self.down_layer = None
if self.down_sample:
self.down_layer = nn.Conv2dBnAct(in_channel, out_channel,
kernel_size=1, stride=stride,
pad_mode='same', has_bn=True)
if extra_params:
cellList = []
            # the depth-wise conv implementation misbehaves with very large kernels, so halve the spatial size first when it is >= 56
if spatial >= 56:
cellList.extend([nn.Conv2d(in_channels=out_channel,
out_channels=out_channel,
kernel_size=3,
stride=2,
pad_mode="same"),
nn.BatchNorm2d(out_channel)])
spatial //= 2
cellList.extend([nn.Conv2d(in_channels=out_channel,
out_channels=out_channel,
kernel_size=spatial,
group=out_channel,
stride=1,
padding=0,
pad_mode="pad"),
nn.BatchNorm2d(out_channel)])
self.downop = nn.SequentialCell(cellList)
else:
self.downop = P.ReduceMean(keep_dims=True)
if mlp:
mlpLayer = []
mlpLayer.append(nn.Conv2d(in_channels=out_channel,
out_channels=out_channel//16,
kernel_size=1))
mlpLayer.append(nn.ReLU())
mlpLayer.append(nn.Conv2d(in_channels=out_channel//16,
out_channels=out_channel,
kernel_size=1))
self.mlpLayer = nn.SequentialCell(mlpLayer)
self.sigmoid = nn.Sigmoid()
self.add = ms.ops.Add()
self.relu = nn.ReLU()
self.mul = ms.ops.Mul()
def construct(self, x):
"""
Args:
x : input Tensor.
"""
identity = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
if self.down_sample:
identity = self.down_layer(identity)
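        # gather: pool spatial context, either with the DW-conv path (extra) or a global mean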
if self.extra_params:
out_ge = self.downop(out)
else:
out_ge = self.downop(out, (2, 3))
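        # excite: optionally transform the context with the 1x1-conv MLP, squash with sigmoid, and gate the features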
if self.mlp:
out_ge = self.mlpLayer(out_ge)
out_ge = self.sigmoid(out_ge)
out = self.mul(out, out_ge)
out = self.add(out, identity)
out = self.relu(out)
return out

View File: src/GENet.py

@@ -0,0 +1,317 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GENet."""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from src.GEBlock import GEBlock
def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d',
'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
res = 1
elif nonlinearity == 'tanh':
res = 5.0 / 3
elif nonlinearity == 'relu':
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res
def _calculate_fan_in_and_fan_out(tensor):
"""
_calculate_fan_in_and_fan_out
"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor"
" with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
"""
for pylint.
"""
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
"""
for pylint.
"""
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
"""
for pylint.
"""
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)
def _conv3x3(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv1x1(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride,
padding=0, pad_mode='same', weight_init=weight)
def _conv7x7(in_channel, out_channel, stride=1):
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.95,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.95,
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
def _fc(in_channel, out_channel):
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)
class GENet(nn.Cell):
"""
GENet architecture.
Args:
block (Cell): Block for network.
layer_nums (list): Numbers of block in different layers.
in_channels (list): Input channel in each layer.
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
        spatial (list): Output spatial sizes of the four groups.
        num_classes (int): The number of classes that the training images belong to.
        extra_params (bool): Whether to use a depth-wise conv to down-sample (gather step).
        mlp (bool): Whether to add an SE-style 1x1-conv MLP (excite step).
Returns:
Tensor, output tensor.
Examples:
        >>> GENet(GEBlock,
        >>>       [3, 4, 6, 3],
        >>>       [64, 256, 512, 1024],
        >>>       [256, 512, 1024, 2048],
        >>>       [1, 2, 2, 2],
        >>>       [56, 28, 14, 7],
        >>>       1001, True, True)
"""
def __init__(self,
block,
layer_nums,
in_channels,
out_channels,
strides,
spatial,
num_classes,
extra_params,
mlp):
super(GENet, self).__init__()
if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
self.extra = extra_params
# initial stage
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block=block,
layer_num=layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
spatial=spatial[0],
extra_params=extra_params,
mlp=mlp)
self.layer2 = self._make_layer(block=block,
layer_num=layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
spatial=spatial[1],
extra_params=extra_params,
mlp=mlp)
self.layer3 = self._make_layer(block=block,
layer_num=layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2],
spatial=spatial[2],
extra_params=extra_params,
mlp=mlp)
self.layer4 = self._make_layer(block=block,
layer_num=layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
spatial=spatial[3],
extra_params=extra_params,
mlp=mlp)
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes)
def _make_layer(self, block, layer_num, in_channel, out_channel,
stride, spatial, extra_params, mlp):
"""
Make stage network of GENet.
Args:
block (Cell): GENet block.
layer_num (int): Layer number.
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
            spatial (int): Output spatial size of every block in the same group.
            extra_params (bool): Whether to use a depth-wise conv to down-sample (gather step).
            mlp (bool): Whether to add an SE-style 1x1-conv MLP (excite step).
Returns:
SequentialCell, the output layer.
"""
layers = []
ge_block = block(in_channel=in_channel,
out_channel=out_channel,
stride=stride,
spatial=spatial,
extra_params=extra_params,
mlp=mlp)
layers.append(ge_block)
for _ in range(1, layer_num):
ge_block = block(in_channel=out_channel,
out_channel=out_channel,
stride=1,
spatial=spatial,
extra_params=extra_params,
mlp=mlp)
layers.append(ge_block)
return nn.SequentialCell(layers)
def construct(self, x):
"""
Args:
x : input Tensor.
"""
# initial stage
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1 = self.maxpool(x)
# four groups
c2 = self.layer1(c1)
c3 = self.layer2(c2)
c4 = self.layer3(c3)
c5 = self.layer4(c4)
out = self.mean(c5, (2, 3))
out = self.flatten(out)
out = self.end_point(out)
return out
def GE_resnet50(class_num=1000, extra=True, mlp=True):
"""
Get GE-ResNet50 neural network.
Default : GE Theta+ version (best)
Args:
class_num (int): Class number.
        extra (bool): Whether to use a depth-wise conv to down-sample (gather step).
        mlp (bool): Whether to add an SE-style 1x1-conv MLP (excite step).
Returns:
Cell, cell instance of GENet-ResNet50 neural network.
Examples:
>>> net = GE_resnet50(1000)
"""
return GENet(block=GEBlock,
layer_nums=[3, 4, 6, 3],
in_channels=[64, 256, 512, 1024],
out_channels=[256, 512, 1024, 2048],
strides=[1, 2, 2, 2],
spatial=[56, 28, 14, 7],
num_classes=class_num,
extra_params=extra,
mlp=mlp)

View File: src/config.py

@@ -0,0 +1,82 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional.
cfg = ed({
'optimizer': 'Momentum',
})
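# config1: GEθ+ variant (extra=True, mlp=True); selected by train.py/eval.py when both flags are True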
config1 = ed({
"class_num": 1000,
"batch_size": 256,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 150,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 5,
"decay_mode": "linear",
"save_checkpoint_path": "./checkpoints",
"hold_epochs": 0,
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.8,
"lr_end": 0.0
})
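# config2: GEθ variant (extra=True, mlp=False)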
config2 = ed({
"class_num": 1000,
"batch_size": 256,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 150,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 5,
"decay_mode": "linear",
"save_checkpoint_path": "./checkpoints",
"hold_epochs": 0,
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.8,
"lr_end": 0.0
})
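# config3: GEθ- variant (extra=False); trains longer (220 epochs) with cosine decay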
config3 = ed({
"class_num": 1000,
"batch_size": 256,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 220,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 5,
"keep_checkpoint_max": 5,
"decay_mode": "cosine",
"save_checkpoint_path": "./checkpoints",
"hold_epochs": 0,
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.8,
"lr_end": 0.0
})

View File: src/dataset.py

@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
create train or eval dataset.
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32,
target="Ascend", distribute=False):
"""
create a train or eval imagenet2012 dataset for resnet50
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
distribute(bool): data for distribute or not. Default: False
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
if distribute:
init()
rank_id = get_rank()
device_num = get_group_size()
        else:
            device_num = 1
            rank_id = 0
if device_num == 1:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
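    # ImageNet channel mean/std scaled to the 0-255 pixel range (Normalize runs on undivided uint8 images)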
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=8)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
# apply batch operations
data_set = data_set.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def _get_rank_info():
"""
get rank size and rank id
"""
    rank_size = int(os.getenv("RANK_SIZE", "1"))
if rank_size > 1:
rank_size = get_group_size()
rank_id = get_rank()
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id

View File: src/lr_generator.py

@@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def _generate_linear_lr(lr_init, lr_end, total_steps):
    """
    Applies linear decay to generate the learning rate array.
    Args:
        lr_init(float): init learning rate.
        lr_end(float): end learning rate.
        total_steps(int): total steps of training.
    Returns:
        list, learning rate array.
    """
    lr_each_step = []
    for i in range(total_steps):
        lr = lr_init - (lr_init - lr_end) * i / total_steps
        lr_each_step.append(lr)
    return lr_each_step
def _generate_cosine_lr(lr_init, total_steps):
    """
    Applies cosine decay to generate the learning rate array.
    Args:
        lr_init(float): init learning rate.
        total_steps(int): total steps of training.
    Returns:
        list, learning rate array.
    """
decay_steps = total_steps
lr_each_step = []
for i in range(total_steps):
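        # a linear ramp to zero multiplied by a cosine that sweeps about 0.94*pi over
        # training, plus a small floor so the learning rate never reaches exactly zero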
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_init * decayed
lr_each_step.append(lr)
return lr_each_step
def get_lr(lr_init, lr_end, total_epochs, steps_per_epoch, decay_mode):
    """
    generate learning rate array
    Args:
        lr_init(float): init learning rate
        lr_end(float): end learning rate (used by linear decay)
        total_epochs(int): total epochs of training
        steps_per_epoch(int): steps of one epoch
        decay_mode(str): "cosine" or "linear" (default)
    Returns:
        np.array, learning rate array
    """
total_steps = steps_per_epoch * total_epochs
if decay_mode == "cosine":
lr_each_step = _generate_cosine_lr(lr_init, total_steps)
else:
lr_each_step = _generate_linear_lr(lr_init, lr_end, total_steps)
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step

View File: train.py

@@ -0,0 +1,213 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train GENet."""
import os
import argparse
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim import Momentum
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.common import set_seed
from mindspore.parallel import set_algo_parameters
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from src.CrossEntropySmooth import CrossEntropySmooth
from src.GENet import GE_resnet50 as Net
from src.lr_generator import get_lr
from src.dataset import create_dataset
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--data_url', type=str, default=None, help='Dataset path')
parser.add_argument('--train_url', type=str, default=None, help='Train output path')
parser.add_argument('--device_target', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
help="Device target, support Ascend, GPU and CPU.")
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path')
parser.add_argument('--extra', type=str, default="True",
                    help='whether to use a depth-wise conv to down-sample in the gather step')
parser.add_argument('--mlp', type=str, default="True",
                    help='whether to use a 1x1 conv MLP (SE-style bottleneck) in the excite step')
parser.add_argument('--is_modelarts', type=str, default="False", help='is train on modelarts')
args_opt = parser.parse_args()
if args_opt.extra.lower() == "false":
from src.config import config3 as config
else:
if args_opt.mlp.lower() == "false":
from src.config import config2 as config
else:
from src.config import config1 as config
if args_opt.is_modelarts == "True":
import moxing as mox
set_seed(1)
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
"""remove useless parameters according to filter_list"""
for key in list(origin_dict.keys()):
for name in param_filter:
if name in key:
print("Delete parameter from checkpoint: ", key)
del origin_dict[key]
break
def trans_char_to_bool(str_):
    """Convert a command-line flag string ("true"/"false") to a bool."""
    return str_.lower() == "true"
if __name__ == '__main__':
    device_id = int(os.getenv('DEVICE_ID', '0'))
    device_num = int(os.getenv("RANK_SIZE", "1"))
ckpt_save_dir = config.save_checkpoint_path
local_train_data_url = args_opt.data_url
if args_opt.is_modelarts == "True":
local_summary_dir = "/cache/summary"
local_data_url = "/cache/data"
local_train_url = "/cache/ckpt"
local_zipfolder_url = "/cache/tarzip"
ckpt_save_dir = local_train_url
mox.file.make_dirs(local_train_url)
mox.file.make_dirs(local_summary_dir)
filename = "imagenet_original.tar.gz"
# transfer dataset
local_data_url = os.path.join(local_data_url, str(device_id))
mox.file.make_dirs(local_data_url)
local_zip_path = os.path.join(local_zipfolder_url, str(device_id), filename)
obs_zip_path = os.path.join(args_opt.data_url, filename)
mox.file.copy(obs_zip_path, local_zip_path)
unzip_command = "tar -xvf %s -C %s" % (local_zip_path, local_data_url)
os.system(unzip_command)
local_train_data_url = os.path.join(local_data_url, "imagenet_original", "train")
target = args_opt.device_target
if target != 'Ascend':
raise ValueError("Unsupported device target.")
run_distribute = False
if device_num > 1:
run_distribute = True
# init context
context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False)
if run_distribute:
context.set_context(device_id=device_id,
enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160])
init()
# create dataset
dataset = create_dataset(dataset_path=local_train_data_url, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target, distribute=run_distribute)
step_size = dataset.get_dataset_size()
    # parse the string flags into bools and define the net
    mlp = trans_char_to_bool(args_opt.mlp)
    extra = trans_char_to_bool(args_opt.extra)
    net = Net(class_num=config.class_num, extra=extra, mlp=mlp)
# init weight
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_param_into_net(net, param_dict)
else:
for _, cell in net.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(weight_init.initializer(weight_init.HeUniform(),
cell.weight.shape,
cell.weight.dtype))
if isinstance(cell, nn.Dense):
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.shape,
cell.weight.dtype))
lr = get_lr(config.lr_init, config.lr_end, config.epoch_size, step_size, config.decay_mode)
lr = Tensor(lr)
# define opt
decayed_params = []
no_decayed_params = []
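    # apply weight decay only to conv/dense weights; skip BN gamma/beta and bias parameters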
for param in net.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainable_params()}]
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
# define loss, model
if target == "Ascend":
if not config.use_label_smooth:
config.label_smooth_factor = 0.0
loss = CrossEntropySmooth(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor,
num_classes=config.class_num)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False)
else:
raise ValueError("Unsupported device target.")
# define callbacks
time_cb = TimeMonitor(data_size=step_size)
loss_cb = LossMonitor()
    rank_id = int(os.getenv("RANK_ID", "0"))
cb = [time_cb, loss_cb]
if rank_id == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs*step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="GENet", directory=ckpt_save_dir, config=config_ck)
cb += [ckpt_cb]
dataset_sink_mode = target != "CPU"
model.train(config.epoch_size, dataset, callbacks=cb,
sink_size=dataset.get_dataset_size(), dataset_sink_mode=dataset_sink_mode)
if device_id == 0 and args_opt.is_modelarts == "True":
mox.file.copy_parallel(ckpt_save_dir, args_opt.train_url)