hardnet_master

This commit is contained in:
warrior0 2021-03-26 11:38:40 +08:00 committed by warrior
parent dd607fbb65
commit cf5157b89f
13 changed files with 1532 additions and 0 deletions

View File

@ -0,0 +1,333 @@
# Table of Contents

<!-- TOC -->

- [Table of Contents](#table-of-contents)
- [HarDNet Description](#hardnet-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Loading Pretrained Weights](#loading-pretrained-weights)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
            - [HarDNet on ImageNet](#hardnet-on-imagenet)
        - [Inference Performance](#inference-performance)
            - [HarDNet on ImageNet](#hardnet-on-imagenet-1)
    - [Usage](#usage)
        - [Inference](#inference)
        - [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

<!-- /TOC -->
# HarDNet Description

HarDNet (Harmonic DenseNet: a low memory traffic network) is a network whose distinguishing feature is low memory traffic. With the stronger compute and larger datasets of recent years we can train increasingly complex networks, but for real-time applications the question becomes how to raise computational efficiency while lowering power consumption. In this setting the authors propose HarDNet to strike the best balance between the two.

[Paper](https://arxiv.org/abs/1909.00948): Chao P., Kao C. Y., Ruan Y., et al. HarDNet: A Low Memory Traffic Network[C]// 2019 IEEE/CVF International Conference on Computer Vision (ICCV). IEEE, 2020.
# Model Architecture

The authors apply a soft constraint on the MoC of every layer to design a low-CIO network while increasing MACs only moderately, avoiding layers with very low MoC such as Conv1x1 layers with a very large input/output channel ratio. Inspired by densely connected networks, they propose the Harmonic Densely Connected Network (HarDNet): most of the layer connections of DenseNet are first removed to reduce the concatenation cost, and the channel width of the layers is then increased to balance the input/output channel ratio.
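As a quick illustration of the harmonic connection pattern, the sketch below mirrors the link rule implemented by `_HarDBlock.get_link` in `src/HarDNet.py` from this commit: layer `k` connects back to layer `k - 2^i` for every power of two that divides `k`, and the output width is multiplied by `grmul` for each additional hop.

```python
def harmonic_links(layer, growth_rate, grmul):
    """Return (out_channels, links) for HarDBlock layer `layer` (1-based within the block)."""
    if layer == 0:
        # layer 0 is the block input; its width is the block's input channel count
        return 0, []
    out_channels = growth_rate
    links = []
    for i in range(10):
        dv = 2 ** i
        if layer % dv == 0:
            links.append(layer - dv)       # connect back 2^i layers
            if i > 0:
                out_channels *= grmul      # widen the output for each longer-range link
    out_channels = int(int(out_channels + 1) / 2) * 2  # round to an even channel count
    return out_channels, links

# For example, layer 8 with growth_rate=24 and grmul=1.7 links back to layers 7, 6, 4 and 0.
print(harmonic_links(8, 24, 1.7))
```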
# Dataset

Dataset used: ImageNet2012

- Dataset size: 125 GB, about 1.28 million color images in 1000 classes
    - Training set: 120 GB, about 1.28 million images
    - Test set: 5 GB, 50,000 images
- Data format: RGB
- Note: the data is processed in src/dataset.py (see the usage sketch below).
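The same helper is used by both train.py and eval.py in this commit to build the input pipeline; a minimal single-device sketch (the dataset path is a placeholder for your local ImageNet train directory):

```python
from src.dataset import create_dataset_ImageNet

# build the training pipeline exactly as train.py does; batch_size follows config.py
train_dataset = create_dataset_ImageNet(dataset_path="/path/imagenet/train",
                                        do_train=True,
                                        repeat_num=1,
                                        batch_size=256,
                                        target="Ascend")
print("steps per epoch:", train_dataset.get_dataset_size())
```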
# Features

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/training/zh-CN/master/advanced_use/enable_mixed_precision.html) training method uses both single-precision and half-precision data to speed up the training of deep neural networks while preserving the accuracy achievable with pure single-precision training. Mixed precision raises computing speed and reduces memory usage, making it possible to train larger models or use larger batch sizes on specific hardware.

Taking the FP16 operator as an example, if the input data type is FP32, the MindSpore backend automatically lowers the precision to process the data. Users can open the INFO log and search for "reduce precision" to see which operators had their precision reduced.
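For reference, train.py in this commit enables mixed precision through the `Model` wrapper; the minimal sketch below shows that wiring with a toy network standing in for HarDNet85:

```python
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# toy stand-ins; in train.py these are HarDNet85, CrossEntropySmooth and the grouped Momentum optimizer
network = nn.Dense(16, 10)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=1024)

# fixed loss scaling plus amp_level="O3" (cast the network to FP16), as in train.py
loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
model = Model(network, loss_fn=loss, optimizer=net_opt,
              loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O3")
```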
# 环境要求
- 硬件Ascend/GPU
- 使用Ascend或GPU处理器来搭建硬件环境。
- 框架
- [MindSpore](https://www.mindspore.cn/install/en)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# Quick Start

After installing MindSpore via the official website, you can start training and evaluation as follows:

- Running on Ascend

```bash
# run the training example
python3 train.py --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path --isModelArts False --distribute False > train.log 2>&1 &
OR
bash run_single_train.sh /path/dataset 0 /path/pretrained_path

# run the distributed training example
python3 train.py --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path --isModelArts False > train.log 2>&1 &
OR
bash run_distribute_train.sh /path/dataset /path/pretrain_path 8

# run the evaluation example
python3 eval.py --dataset_path /path/dataset --ckpt_path /path/ckpt > eval.log 2>&1 &
OR
bash run_eval.sh /path/dataset 0 /path/ckpt
```

For distributed training, an hccl configuration file in JSON format needs to be created in advance. Please follow the instructions in the link below:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>

- The ImageNet2012 dataset is used by default. You can also pass `$dataset_type` to the scripts to select another dataset. Please refer to the corresponding scripts for more details.
# Script Description

## Script and Sample Code

```bash
├── model_zoo
    ├── README.md                        // description of all the models
    ├── hardnet
        ├── README.md                    // description of hardnet
        ├── scripts
        │   ├── run_single_train.sh      // shell script for single-device training on Ascend
        │   ├── run_distribute_train.sh  // shell script for distributed training on Ascend
        │   ├── run_eval.sh              // shell script for evaluation on Ascend
        ├── src
        │   ├── dataset.py               // dataset creation
        │   ├── HarDNet.py               // hardnet architecture
        │   ├── EntropyLoss.py           // loss function
        │   ├── config.py                // parameter configuration
        │   ├── lr_scheduler.py          // learning rate generation
        │   ├── HarDNet85.ckpt           // pretrained weights
        │   ├── pth2ckpt.py              // convert the authors' pretrained weights into a .ckpt file
        ├── train.py                     // training script
        ├── eval.py                      // evaluation script
        ├── export.py                    // export checkpoint file to air/onnx
```
## Script Parameters

Training parameters and evaluation parameters can both be configured in config.py.

- Configuration for the ImageNet dataset:

```python
"class_num": 1000                                  # number of dataset classes
"batch_size": 256                                  # training batch size
"loss_scale": 1024                                 # loss scale value
"momentum": 0.9                                    # momentum
"weight_decay": 6e-5                               # weight decay value
"epoch_size": 150                                  # total number of training epochs
"pretrain_epoch_size": 0                           # number of pretraining epochs
"save_checkpoint": True                            # whether to save checkpoint files
"save_checkpoint_epochs": 10                       # epoch interval between two checkpoints
"keep_checkpoint_max": 10                          # keep only the last keep_checkpoint_max checkpoints
"save_checkpoint_path": "/home/hardnet/result/"    # absolute path where checkpoint files are saved
"warmup_epochs": 5                                 # number of warmup epochs
"lr_decay_mode": "cosine"                          # learning rate decay mode: steps, poly, cosine or default
"lr_init": 0.05                                    # initial learning rate
"lr_end": 0.00001                                  # final learning rate
"lr_max": 0.1                                      # maximum learning rate
```

For more configuration details, please refer to the script `config.py`.
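A minimal sketch of how these values become the per-step learning-rate array in train.py (using `get_lr` from `src/lr_scheduler.py`, as imported by train.py in this commit; 625 steps per epoch matches the training logs below and is only an example):

```python
from mindspore import Tensor
from src.config import config
from src.lr_scheduler import get_lr

# one learning-rate value per training step: linear warmup followed by cosine decay
lr = get_lr(lr_init=config.lr_init,
            lr_end=config.lr_end,
            lr_max=config.lr_max,
            warmup_epochs=config.warmup_epochs,
            total_epochs=config.epoch_size,
            steps_per_epoch=625,
            lr_decay_mode=config.lr_decay_mode)
lr = Tensor(lr)
```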
## Training Process

### Loading Pretrained Weights

Pretrained weights provided by the paper authors: [HarDNet85.pth](https://ping-chao.com/hardnet/hardnet85-a28faa00.pth)

```bash
python3 pth2ckpt.py --pth_path /path/pthfile
```
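After conversion, the resulting HarDNet85.ckpt can be loaded into the MindSpore network the same way train.py does when a pretrained checkpoint path is given; a minimal sketch:

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.HarDNet import HarDNet85
from src.config import config

# load the converted weights into the backbone before training on ImageNet
network = HarDNet85(num_classes=config.class_num)
param_dict = load_checkpoint("HarDNet85.ckpt")
load_param_into_net(network, param_dict)
```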
### Training

- Running on Ascend

```bash
python3 train.py --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path --isModelArts False --distribute False > train.log 2>&1 &
OR
bash run_single_train.sh /path/dataset 0 /path/pretrained_path
```

The python command above runs in the background; you can view the results in the train.log file.

After training, you can find the checkpoint files in the default script folder. The loss values are reported as follows:

```bash
# grep "loss is " train.log
epoch:1 step:625, loss is 2.4842823
epoch:2 step:625, loss is 3.0897788
...
```

The model checkpoints are saved in the current directory.
### Distributed Training

- Running on Ascend

```bash
python3 train.py --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path --isModelArts False > train.log 2>&1 &
OR
bash run_distribute_train.sh /path/dataset /path/pretrain_path 8
```

The shell script above runs distributed training in the background. You can view the results in the train[X].log files under the device[X] directories. The loss values are reported as follows:

```bash
# grep "result:" device*/log
device0/log:epoch:1 step:625, loss is 2.4302931
device0/log:epoch:2 step:625, loss is 2.4023874
...
device1/log:epoch:1 step:625, loss is 2.3458025
device1/log:epoch:2 step:625, loss is 2.3729336
...
...
```
## Evaluation Process

### Evaluation

- Evaluating the ImageNet dataset in the Ascend environment

Before running the commands below, check the checkpoint path used for evaluation. Please set the checkpoint path to an absolute path, e.g. "username/hardnet/train_hardnet_390.ckpt".

```bash
python3 eval.py --dataset_path /path/dataset --ckpt_path /path/ckpt > eval.log 2>&1 &
OR
bash run_eval.sh /path/dataset 0 /path/ckpt
```

The python command above runs in the background; you can view the results in the eval.log file. The accuracy on the test dataset is as follows:

```bash
# grep "accuracy:" eval.log
accuracy:{'acc':0.774}
```

For evaluation after distributed training, please set checkpoint_path to the last saved checkpoint file, e.g. "username/hardnet/device0/train_hardnet-150-625.ckpt". The accuracy on the test dataset is as follows:

```bash
# grep "accuracy:" dist.eval.log
accuracy:{'acc':0.777}
```
# Model Description

## Performance

### Evaluation Performance

#### HarDNet on ImageNet

| Parameter                  | Ascend                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model version              | HarDNet85                                                     |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB            |
| Upload date                | 2021-03-22                                                    |
| MindSpore version          | 1.1.1-aarch64                                                 |
| Dataset                    | ImageNet2012                                                  |
| Training parameters        | epoch=150, steps=625, batch_size=256, lr=0.1                  |
| Optimizer                  | Momentum                                                      |
| Loss function              | Softmax cross entropy                                         |
| Output                     | Probability                                                   |
| Loss                       | 0.0016                                                        |
| Speed                      | 347 ms/step (1 pc); 358 ms/step (8 pcs)                       |
| Total time                 | 72 h 50 min (1 pc); 10 h 14 min (8 pcs)                       |
| Parameters (M)             | 13.0                                                          |
| Checkpoint for fine tuning | 280M (.ckpt file)                                             |
| Script                     | [hardnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/hardnet) |

### Inference Performance

#### HarDNet on ImageNet

| Parameter          | Ascend                      |
| ------------------ | --------------------------- |
| Model version      | HarDNet85                   |
| Resource           | Ascend 910                  |
| Upload date        | 2020-09-20                  |
| MindSpore version  | 1.1.1-aarch64               |
| batch_size         | 256                         |
| Output             | Probability                 |
| Accuracy           | 8 pcs: 78%                  |
## Usage

### Inference

If you need to use this trained model for inference on Ascend 910, refer to this [link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). Below is an example of the steps:

- Running on Ascend

```python
# set context
context.set_context(mode=context.GRAPH_MODE,
                    device_target=target,
                    save_graphs=False,
                    device_id=device_id)

# load the unseen dataset for inference
predict_data = create_dataset_ImageNet(dataset_path=args.dataset_path,
                                       do_train=False,
                                       repeat_num=1,
                                       batch_size=config.batch_size,
                                       target=target)

# define the network
network = HarDNet85(num_classes=config.class_num)

# load checkpoint
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)

# define the loss function
loss = CrossEntropySmooth(smooth_factor=args.label_smooth_factor,
                          num_classes=config.class_num)

# define the model
model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

# run prediction on the unseen dataset
acc = model.eval(predict_data)
print("==============Acc: {} ==============".format(acc))
```
### Transfer Learning

To be added.

# Description of Random Situation

In dataset.py we set the seed inside the "create_dataset" function, and a random seed is also used in train.py.
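For reference, a minimal sketch of the seed setup (the same calls appear at the top of eval.py in this commit):

```python
import random
import numpy as np
from mindspore import dataset

# fix the Python, NumPy and MindSpore dataset seeds
random.seed(1)
np.random.seed(1)
dataset.config.set_seed(1)
```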
# ModelZoo Homepage

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,91 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############test hardnet example on imagenet#################
python3 eval.py
"""
import argparse
import random
import numpy as np
from mindspore import context
from mindspore import dataset
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from src.dataset import create_dataset_ImageNet
from src.HarDNet import HarDNet85
from src.EntropyLoss import CrossEntropySmooth
from src.config import config
random.seed(1)
np.random.seed(1)
dataset.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--use_hardnet', type=bool, default=True, help='Enable HarnetUnit')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--dataset_path', type=str, default='/data/imagenet_original/val/',
help='Dataset path')
parser.add_argument('--ckpt_path', type=str,
default='/home/hardnet/result/HarDNet-150_625.ckpt',
help='path of the trained ckpt file used for evaluation')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='label_smooth_factor')
parser.add_argument('--device_id', type=int, default=0, help='device_id')
args = parser.parse_args()
def test(ckpt_path):
"""run eval"""
target = args.device_target
# init context
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=args.device_id)
# dataset
predict_data = create_dataset_ImageNet(dataset_path=args.dataset_path,
do_train=False,
repeat_num=1,
batch_size=config.batch_size,
target=target)
step_size = predict_data.get_dataset_size()
if step_size == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
# define net
network = HarDNet85(num_classes=config.class_num)
# load checkpoint
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)
# define loss, model
loss = CrossEntropySmooth(smooth_factor=args.label_smooth_factor,
num_classes=config.class_num)
model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
print("Dataset path: {}".format(args.dataset_path))
print("Ckpt path :{}".format(ckpt_path))
print("Class num: {}".format(config.class_num))
print("Backbone hardnet")
print("============== Starting Testing ==============")
acc = model.eval(predict_data)
print("==============Acc: {} ==============".format(acc))
if __name__ == '__main__':
path = args.ckpt_path
test(path)

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into air, onnx, mindir models#################
python export.py
"""
import argparse
import numpy as np
import mindspore.common.dtype as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.config import config
from src.HarDNet import HarDNet85
parser = argparse.ArgumentParser(description='Classification')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="hardnet", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
net = HarDNet85(num_classes=config.class_num)
assert args.ckpt_file is not None, "args.ckpt_file is None."
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.ones([args.batch_size, 3, config.image_height, config.image_width]), ms.float32)
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)

View File

@ -0,0 +1,94 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train.sh DATA_PATH pretrain_path RANK_SIZE"
echo "For example: bash run_distribute_train.sh /path/dataset /path/pretrain_path 8"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$1
PRETRAINED_PATH=$2
export DATA_PATH=${DATA_PATH}
RANK_SIZE=$3
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
test_dist_8pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
export RANK_SIZE=8
}
test_dist_4pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_4pcs.json
export RANK_SIZE=4
}
test_dist_2pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
export RANK_SIZE=2
}
test_dist_${RANK_SIZE}pcs
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
for((i=1;i<${RANK_SIZE};i++))
do
rm -rf device$i
mkdir device$i
cd ./device$i
mkdir src
cd ../
cp ../*.py ./device$i
cp ../src/*.py ./device$i/src
cd ./device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
nohup python3 -u train.py --dataset_path ${DATA_PATH} --isModelArts False --distribute True --pre_ckpt_path ${PRETRAINED_PATH} > train$i.log 2>&1 &
echo "$i finish"
cd ../
done
rm -rf device0
mkdir device0
cd ./device0
mkdir src
cd ../
cp ../*.py ./device0
cp ../src/*.py ./device0/src
cd ./device0
export DEVICE_ID=0
export RANK_ID=0
echo "start training for device 0"
env > env0.log
nohup python3 -u train.py --dataset_path ${DATA_PATH} --isModelArts False --distribute True --pre_ckpt_path ${PRETRAINED_PATH} > train0.log 2>&1
if [ $? -eq 0 ];then
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../

View File

@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval.sh DATA_PATH DEVICE_ID CKPT_PATH"
echo "For example: bash run_eval.sh /path/dataset 0 /path/ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$1
DEVICE_ID=$2
export DATA_PATH=${DATA_PATH}
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
export DEVICE_ID=$2
export RANK_ID=$2
env > env0.log
python3 eval.py --dataset_path $1 --device_id $2 --ckpt_path $3 > eval.log 2>&1
if [ $? -eq 0 ];then
echo "testing success"
else
echo "testing failed"
exit 2
fi
echo "finish"
cd ../

View File

@ -0,0 +1,43 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_standalone_train.sh DATA_PATH DEVICE_ID PRETRAINED_PATH"
echo "For example: bash run_single_train.sh /path/dataset 0 /path/pretrained_path"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$1
DEVICE_ID=$2
export DATA_PATH=${DATA_PATH}
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
export DEVICE_ID=$2
export RANK_ID=$2
env > env0.log
python3 train.py --dataset_path $1 --isModelArts False --distribute False --device_id $2 --pre_ckpt_path $3 > train.log 2>&1
if [ $? -eq 0 ];then
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()
self.onehot = P.OneHot()
self.sparse = sparse
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
def construct(self, logit, label):
if self.sparse:
label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, label)
return loss

View File

@ -0,0 +1,311 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""HarDNet"""
import mindspore.nn as nn
from mindspore.ops import operations as P
class GlobalAvgpooling(nn.Cell):
"""
GlobalAvgpooling function
"""
def __init__(self):
super(GlobalAvgpooling, self).__init__()
self.mean = P.ReduceMean(True)
self.shape = P.Shape()
self.reshape = P.Reshape()
def construct(self, x):
x = self.mean(x, (2, 3))
b, c, _, _ = self.shape(x)
x = self.reshape(x, (b, c))
return x
class _ConvLayer(nn.Cell):
"""
convlayer
"""
def __init__(self, in_channels, out_channels, kernel=3, stride=1, dropout=0.9, bias=False):
super(_ConvLayer, self).__init__()
self.ConvLayer_Conv = nn.Conv2d(in_channels, out_channels,
kernel_size=kernel,
stride=stride,
has_bias=bias,
padding=kernel // 2,
pad_mode="pad")
self.ConvLayer_BN = nn.BatchNorm2d(out_channels)
self.ConvLayer_RE = nn.ReLU6()
def construct(self, x):
out = self.ConvLayer_Conv(x)
out = self.ConvLayer_BN(out)
out = self.ConvLayer_RE(out)
return out
class _DWConvLayer(nn.Cell):
"""
dwconvlayer
"""
def __init__(self, in_channels, out_channels, stride=1, bias=False):
super(_DWConvLayer, self).__init__()
self.DWConvLayer_Conv = nn.Conv2d(in_channels, in_channels,
kernel_size=3,
stride=stride,
has_bias=bias,
padding=1,
pad_mode="pad")
self.DWConvLayer_BN = nn.BatchNorm2d(in_channels)
def construct(self, x):
out = self.DWConvLayer_Conv(x)
out = self.DWConvLayer_BN(out)
return out
class _CombConvLayer(nn.Cell):
"""
combconvlayer
"""
def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.9, bias=False):
super(_CombConvLayer, self).__init__()
self.CombConvLayer_Conv = _ConvLayer(in_channels, out_channels, kernel=kernel)
self.CombConvLayer_DWConv = _DWConvLayer(out_channels, out_channels, stride=stride)
def construct(self, x):
out = self.CombConvLayer_Conv(x)
out = self.CombConvLayer_DWConv(out)
return out
class _HarDBlock(nn.Cell):
"""the HarDBlock function"""
def get_link(self, layer, bash_ch, growth_rate, grmul):
"""
link all layers
"""
if layer == 0:
return bash_ch, 0, []
out_channels = growth_rate
link = []
for i in range(10):
dv = 2 ** i
if layer % dv == 0:
k = layer - dv
link.append(k)
if i > 0:
out_channels *= grmul
out_channels = int(int(out_channels + 1) / 2) * 2
in_channels = 0
for i in link:
ch, _, _ = self.get_link(i, bash_ch, growth_rate, grmul)
in_channels += ch
return out_channels, in_channels, link
def get_out_ch(self):
return self.out_channels
def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, residual_out=False, dwconv=False):
super(_HarDBlock, self).__init__()
self.keepBase = keepBase
self.links = []
self.layer_list = nn.CellList()
self.out_channels = 0
for i in range(n_layers):
outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, grmul)
self.links.append(link)
if dwconv:
layer = _CombConvLayer(inch, outch)
self.layer_list.append(layer)
else:
layer = _ConvLayer(inch, outch)
self.layer_list.append(layer)
if (i % 2 == 0) or (i == n_layers - 1):
self.out_channels += outch
self.concate = P.Concat(axis=1)
def construct(self, x):
""""
construct all parameters
"""
layers_ = [x]
for layer in range(len(self.layer_list)):
link = self.links[layer]
tin = []
for i in link:
tin.append(layers_[i])
if len(tin) > 1:
input_ = tin[0]
for j in range(len(tin) - 1):
input_ = self.concate((input_, tin[j + 1]))
else:
input_ = tin[0]
out = self.layer_list[layer](input_)
layers_.append(out)
t = len(layers_)
out_ = []
for j in range(t):
if (j == 0 and self.keepBase) or (j == t - 1) or (j % 2 == 1):
out_.append(layers_[j])
output = out_[0]
for k in range(len(out_) - 1):
output = self.concate((output, out_[k + 1]))
return output
class _CommenHead(nn.Cell):
"""
the transition layer
"""
def __init__(self, num_classes, out_channels, drop_rate):
super(_CommenHead, self).__init__()
self.avgpool = GlobalAvgpooling()
self.flat = nn.Flatten()
self.drop = nn.Dropout(keep_prob=drop_rate)
self.dense = nn.Dense(out_channels, num_classes, has_bias=True)
def construct(self, x):
x = self.avgpool(x)
x = self.flat(x)
x = self.drop(x)
x = self.dense(x)
return x
class HarDNet(nn.Cell):
"""
the HarDNet layers
"""
__constants__ = ['layers']
def __init__(self, depth_wise=False, arch=68, pretrained=False):
super(HarDNet, self).__init__()
first_ch = [32, 64]
second_kernel = 3
max_pool = True
grmul = 1.7
drop_rate = 0.9
# HarDNet68
ch_list = [128, 256, 320, 640, 1024]
gr = [14, 16, 20, 40, 160]
n_layers = [8, 16, 16, 16, 4]
downSamp = [1, 0, 1, 1, 0]
if arch == 85:
# HarDNet85
first_ch = [48, 96]
ch_list = [192, 256, 320, 480, 720, 1280]
gr = [24, 24, 28, 36, 48, 256]
n_layers = [8, 16, 16, 16, 16, 4]
downSamp = [1, 0, 1, 0, 1, 0]
drop_rate = 0.2
elif arch == 39:
# HarDNet39
first_ch = [24, 48]
ch_list = [96, 320, 640, 1024]
grmul = 1.6
gr = [16, 20, 64, 160]
n_layers = [4, 16, 8, 4]
downSamp = [1, 1, 1, 0]
if depth_wise:
second_kernel = 1
max_pool = False
drop_rate = 0.05
blks = len(n_layers)
self.layers = nn.CellList()
self.layers.append(_ConvLayer(3, first_ch[0], kernel=3, stride=2, bias=False))
self.layers.append(_ConvLayer(first_ch[0], first_ch[1], kernel=second_kernel))
if max_pool:
self.layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
else:
self.layers.append(_DWConvLayer(first_ch[1], first_ch[1], stride=2))
ch = first_ch[1]
for i in range(blks):
blk = _HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise)
ch = blk.get_out_ch()
self.layers.append(blk)
if i == blks - 1 and arch == 85:
self.layers.append(nn.Dropout(keep_prob=0.9))
self.layers.append(_ConvLayer(ch, ch_list[i], kernel=1))
ch = ch_list[i]
if downSamp[i] == 1:
if max_pool:
self.layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
else:
self.layers.append(_DWConvLayer(ch, ch, stride=2))
self.out_channels = ch_list[blks - 1]
self.droprate = drop_rate
def construct(self, x):
for layer in self.layers:
x = layer(x)
return x
def get_out_channels(self):
return self.out_channels
def get_drop_rate(self):
return self.droprate
class HarDNet68(nn.Cell):
"""
hardnet68
"""
def __init__(self, num_classes):
super(HarDNet68, self).__init__()
self.net = HarDNet(depth_wise=False, arch=68, pretrained=False)
out_channels = self.net.get_out_channels()
drop_rate = self.net.get_drop_rate()
self.head = _CommenHead(num_classes, out_channels, drop_rate)
def construct(self, x):
x = self.net(x)
x = self.head(x)
return x
class HarDNet85(nn.Cell):
"""
hardnet85
"""
def __init__(self, num_classes):
super(HarDNet85, self).__init__()
self.net = HarDNet(depth_wise=False, arch=85, pretrained=False)
out_channels = self.net.get_out_channels()
drop_rate = self.net.get_drop_rate()
self.head = _CommenHead(num_classes, out_channels, drop_rate)
def construct(self, x):
x = self.net(x)
x = self.head(x)
return x

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict
config = EasyDict({
"class_num": 1000,
"batch_size": 256,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 6e-5,
"epoch_size": 150,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 10,
"image_height": 224,
"image_width": 224,
"keep_checkpoint_max": 10,
"save_checkpoint_path": '/home/hardnet/result/',
"warmup_epochs": 5,
"lr_decay_mode": "cosine",
"lr_init": 0.05,
"lr_end": 0.00001,
"lr_max": 0.1
})

View File

@ -0,0 +1,101 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset_ImageNet(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval imagenet2012 dataset for hardnet
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
else:
init("nccl")
rank_id = get_rank()
device_num = get_group_size()
image_size = 224
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if device_num == 1:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
# define map operations
if do_train:
trans = [
C.Decode(),
C.RandomResizedCrop(image_size),
C.RandomHorizontalFlip(),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def _get_rank_info():
"""
get rank size and rank id
"""
# default to a single device when RANK_SIZE is not set in the environment
rank_size = int(os.environ.get("RANK_SIZE", "1"))
if rank_size > 1:
rank_id = int(os.environ.get("RANK_ID", "0"))
else:
rank_size = 1
rank_id = 0
return rank_size, rank_id

View File

@ -0,0 +1,140 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np
def power_lr(lr_init, total_epochs, steps_per_epoch):
"""generate a per-step learning rate array that decays the rate by 10% after each epoch"""
lr_each_step = []
for _ in range(steps_per_epoch):
lr_each_step.append(lr_init)
for _ in range(total_epochs - 1):
lr_init = lr_init - lr_init * 0.1
for _ in range(steps_per_epoch):
lr_each_step.append(lr_init)
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly or default
Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
decay_steps = total_steps - warmup_steps
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)
lr_each_step = np.array(lr_each_step).astype(np.float32)
return lr_each_step
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
lr = float(init_lr) + lr_inc * current_step
return lr
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
"""
generate learning rate array with cosine
Args:
lr(float): base learning rate
steps_per_epoch(int): steps size of one epoch
warmup_epochs(int): number of warmup epochs
max_epoch(int): total epochs of training
global_step(int): the current start index of lr array
Returns:
np.array, learning rate array
"""
base_lr = lr
warmup_init_lr = 0
total_steps = int(max_epoch * steps_per_epoch)
warmup_steps = int(warmup_epochs * steps_per_epoch)
decay_steps = total_steps - warmup_steps
lr_each_step = []
for i in range(total_steps):
if i < warmup_steps:
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = base_lr * decayed
lr_each_step.append(lr)
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[global_step:]
return learning_rate

View File

@ -0,0 +1,67 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pth --> ckpt"""
import argparse
import torch
from mindspore.train.serialization import save_checkpoint
from mindspore import Tensor
def replace_self(name1, str1, str2):
return name1.replace(str1, str2)
parser = argparse.ArgumentParser(description='')
parser.add_argument('--pth_path', type=str, default='/disk3/pth2ckpt/hardnet85.pth',
help='pth path')
parser.add_argument('--device_target', type=str, default='cpu',
help='device target')
args = parser.parse_args()
print(args)
if __name__ == '__main__':
par_dict = torch.load(args.pth_path, map_location=args.device_target)
new_params_list = []
for name in par_dict:
param_dict = {}
parameter = par_dict[name]
print(name)
name = replace_self(name, ".layers.", ".layer_list.")
name = replace_self(name, "base.", "net.layers.")
name = replace_self(name, "conv", "ConvLayer_Conv")
name = replace_self(name, "norm", "ConvLayer_BN")
name = replace_self(name, "base.16.3.weight", "head.dense.weight")
name = replace_self(name, "base.16.3.bias", "head.dense.bias")
if name.endswith('ConvLayer_BN.weight'):
name = name[:name.rfind('ConvLayer_BN.weight')]
name = name + 'ConvLayer_BN.gamma'
elif name.endswith('ConvLayer_BN.bias'):
name = name[:name.rfind('ConvLayer_BN.bias')]
name = name + 'ConvLayer_BN.beta'
elif name.endswith('.running_mean'):
name = name[:name.rfind('.running_mean')]
name = name + '.moving_mean'
elif name.endswith('.running_var'):
name = name[:name.rfind('.running_var')]
name = name + '.moving_variance'
print('========================hardnet_name', name)
param_dict['name'] = name
param_dict['data'] = Tensor(parameter.numpy())
new_params_list.append(param_dict)
save_checkpoint(new_params_list, 'HarDNet85.ckpt')

View File

@ -0,0 +1,178 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
train hardnet
"""
import os
import math
import argparse
import ast
from mindspore import context
from mindspore import Tensor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from src.dataset import create_dataset_ImageNet
from src.lr_scheduler import get_lr
from src.HarDNet import HarDNet85
from src.EntropyLoss import CrossEntropySmooth
from src.config import config
parser = argparse.ArgumentParser(description='Image classification with HarDNet on Imagenet')
parser.add_argument('--dataset_path', type=str, default='/home/hardnet/imagenet_original/train/',
help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_num', type=int, default=8, help='Device num')
parser.add_argument('--pre_trained', type=ast.literal_eval, default=True, help='whether to load the pretrained checkpoint')
parser.add_argument('--train_url', type=str)
parser.add_argument('--data_url', type=str)
parser.add_argument('--pre_ckpt_path', type=str, default='/home/work/user-job-dir/hardnet/src/HarDNet85.ckpt')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='label_smooth_factor')
parser.add_argument('--isModelArts', type=ast.literal_eval, default=True)
parser.add_argument('--distribute', type=ast.literal_eval, default=True)
parser.add_argument('--device_id', type=int, default=0, help='device_id')
args = parser.parse_args()
if args.isModelArts:
import moxing as mox
if __name__ == '__main__':
target = args.device_target
if args.distribute:
# init context
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
device_id = args.device_id
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=args.device_id)
if args.isModelArts:
# download dataset from obs to cache
mox.file.copy_parallel(src_url=args.data_url, dst_url='/cache/dataset/device_' + os.getenv('DEVICE_ID'))
train_dataset_path = '/cache/dataset/device_' + os.getenv('DEVICE_ID')
# create dataset
train_dataset = create_dataset_ImageNet(dataset_path=train_dataset_path,
do_train=True,
repeat_num=1,
batch_size=config.batch_size,
target=target)
else:
train_dataset = create_dataset_ImageNet(dataset_path=args.dataset_path,
do_train=True,
repeat_num=1,
batch_size=config.batch_size,
target=target)
step_size = train_dataset.get_dataset_size()
# init lr
lr = get_lr(lr_init=config.lr_init,
lr_end=config.lr_end,
lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size,
steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
lr = Tensor(lr)
# define net
network = HarDNet85(num_classes=config.class_num)
print("----network----")
# init weight
if args.pre_trained:
param_dict = load_checkpoint(args.pre_ckpt_path)
load_param_into_net(network, param_dict)
else:
for _, cell in network.cells_and_names():
if isinstance(cell, nn.Conv2d):
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(gain=1 / math.sqrt(3)),
cell.weight.shape,
cell.weight.dtype))
if isinstance(cell, nn.BatchNorm2d):
cell.gamma.set_data(weight_init.initializer('ones', cell.gamma.shape))
cell.beta.set_data(weight_init.initializer('zeros', cell.beta.shape))
if isinstance(cell, nn.Dense):
cell.bias.set_data(weight_init.initializer('zeros', cell.bias.shape, cell.bias.dtype))
# define opt
decayed_params = []
no_decayed_params = []
for param in network.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': network.trainable_params()}]
net_opt = nn.Momentum(group_params, lr, config.momentum,
weight_decay=config.weight_decay,
loss_scale=config.loss_scale)
# define loss
loss = CrossEntropySmooth(smooth_factor=args.label_smooth_factor,
num_classes=config.class_num)
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(network, loss_fn=loss, optimizer=net_opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O3")
# define callbacks
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
loss_cb = LossMonitor()
cb = [time_cb, loss_cb]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
if args.isModelArts:
save_checkpoint_path = '/cache/train_output/device_' + os.getenv('DEVICE_ID') + '/'
else:
save_checkpoint_path = config.save_checkpoint_path
ckpt_cb = ModelCheckpoint(prefix="HarDNet85",
directory=save_checkpoint_path,
config=config_ck)
cb += [ckpt_cb]
print("\n\n========================")
print("Dataset path: {}".format(args.dataset_path))
print("Total epoch: {}".format(config.epoch_size))
print("Batch size: {}".format(config.batch_size))
print("Class num: {}".format(config.class_num))
print("======= Multiple Training begin========")
model.train(config.epoch_size, train_dataset,
callbacks=cb, dataset_sink_mode=True)
if args.isModelArts:
mox.file.copy_parallel(src_url='/cache/train_output', dst_url=args.train_url)