This commit is contained in:
kqzhang 2021-06-24 16:57:30 +08:00
parent d621cdbc97
commit 9e7114e6f5
7 changed files with 307 additions and 74 deletions

View File

@ -103,6 +103,25 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network; its most notable
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>
- Running in a GPU environment
```bash
# Run the training example
export CUDA_VISIBLE_DEVICES=0
python3 train.py --device_target 'GPU' --distribute False --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path > train.log 2>&1 &
bash run_single_train_gpu.sh 0 /path/dataset /path/pretrain_path
# Run the distributed training example
bash run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /path/dataset /path/pretrain_path
# Run the evaluation example
export CUDA_VISIBLE_DEVICES=0
python3 eval.py --device_target 'GPU' --dataset_path /path/dataset --ckpt_path /path/ckpt_path > eval.log 2>&1 &
bash run_eval_gpu.sh /path/dataset 0 /path/ckpt
```
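- The ImageNet2012 dataset is used by default. You can also pass `$dataset_type` to the scripts to select a different dataset. For more details, see the specific script.
# Script Description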
@ -118,6 +137,9 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network; its most notable
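│ ├──run_single_train.sh // shell script for single-device training on Ascend
│ ├──run_distribute_train.sh // shell script for distributed training on Ascend
│ ├──run_eval.sh // shell script for evaluation on Ascend
│ ├──run_single_train_gpu.sh // shell script for single-device training on GPU
│ ├──run_distribute_train_gpu.sh // shell script for distributed training on GPU
│ ├──run_eval_gpu.sh // shell script for evaluation on GPU
├── src
│ ├──dataset.py // creates the dataset
│ ├──hardnet.py // HarDNet architecture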
@ -191,6 +213,28 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network; its most notable
Model checkpoints are saved in the current directory.
- Running on GPU
```bash
export CUDA_VISIBLE_DEVICES=0
python3 train.py --device_target 'GPU' --isModelArts False --distribute False --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path > train.log 2>&1 &
bash run_single_train_gpu.sh 0 /path/dataset /path/pretrain_path
```
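The python command above runs in the background; you can view the results in the train.log file.
After training, the checkpoint files can be found in the default script folder. The loss values can be obtained as follows: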
```bash
# grep "loss is " train.log
epoch:1 step:5000, loss is 3.0897788
epoch:2 step:5000, loss is 2.4842823
...
```
Model checkpoints are saved in the current directory.
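The `loss is` lines can also be collected programmatically instead of with `grep`. The following is a minimal sketch, assuming the log lines have the `epoch:N step:M, loss is X` shape shown in the excerpt above; the helper name `parse_losses` and the in-memory sample lines are illustrative and not part of the repository.

```python
import re

# Matches lines such as "epoch:1 step:5000, loss is 3.0897788".
# Optional whitespace after the colons also covers "epoch: 1 step: 625, ...".
LOSS_LINE = re.compile(r"epoch:\s*(\d+)\s+step:\s*(\d+),\s*loss is\s+([0-9.eE+-]+)")


def parse_losses(lines):
    """Return a list of (epoch, step, loss) tuples found in the given lines."""
    records = []
    for line in lines:
        match = LOSS_LINE.search(line)
        if match:
            epoch, step, loss = match.groups()
            records.append((int(epoch), int(step), float(loss)))
    return records


# Sample lines in the format of the log excerpt (stand-in for reading train.log).
sample = [
    "epoch:1 step:5000, loss is 3.0897788",
    "some unrelated log line",
    "epoch:2 step:5000, loss is 2.4842823",
]
records = parse_losses(sample)
for epoch, step, loss in records:
    print(f"epoch {epoch:3d}  step {step:5d}  loss {loss:.4f}")
```

To run it on a real log, the `sample` list could be replaced by the lines of an opened `train.log` file.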
### Distributed Training
- Running on Ascend
@ -214,6 +258,35 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network; its most notable
...
```
- Running on GPU
```bash
bash run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /path/dataset /path/pretrain_path
```
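The shell script above runs distributed training in the background. You can view the results in the train.log file. The loss values can be obtained as follows: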
```bash
# grep "result:" train.log
epoch: 1 step: 625, loss is 2.7857578
epoch: 1 step: 625, loss is 2.7340727
epoch: 1 step: 625, loss is 2.7651663
epoch: 1 step: 625, loss is 2.8074665
epoch: 1 step: 625, loss is 2.8567638
epoch: 1 step: 625, loss is 2.768191
epoch: 1 step: 625, loss is 3.0651402
epoch: 1 step: 625, loss is 3.039652
epoch time: 1753885.943 ms, per step time: 2806.218 ms
epoch time: 1753861.017 ms, per step time: 2806.178 ms
epoch time: 1753959.524 ms, per step time: 2806.335 ms
epoch time: 1753182.479 ms, per step time: 2805.092 ms
epoch time: 1753981.462 ms, per step time: 2806.370 ms
epoch time: 1753181.926 ms, per step time: 2805.091 ms
epoch time: 1753266.931 ms, per step time: 2805.227 ms
epoch time: 1753218.315 ms, per step time: 2805.149 ms
...
```
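The two kinds of log line above are related: the per-step time is the epoch time divided by the number of steps in the epoch. The following is a minimal sketch of that arithmetic, assuming 625 steps per epoch and 150 epochs as listed in the training parameters, and using the first `epoch time` line above as sample input; the variable names are illustrative.

```python
import re

# Values taken from this README: 625 steps per epoch, 150 epochs.
STEPS_PER_EPOCH = 625
EPOCHS = 150

# One "epoch time" line in the format shown in the log excerpt.
line = "epoch time: 1753885.943 ms, per step time: 2806.218 ms"

match = re.search(r"epoch time:\s*([\d.]+) ms, per step time:\s*([\d.]+) ms", line)
epoch_ms = float(match.group(1))
step_ms = float(match.group(2))

# The logged per-step time should equal epoch time / steps, up to rounding.
derived_step_ms = epoch_ms / STEPS_PER_EPOCH

# Rough extrapolation: assumes every epoch takes as long as this one.
estimated_hours = epoch_ms * EPOCHS / 1000 / 3600

print(f"derived per-step time: {derived_step_ms:.3f} ms (logged: {step_ms:.3f} ms)")
print(f"estimated total for {EPOCHS} epochs: {estimated_hours:.1f} h")
```

The extrapolation is only an estimate based on a single epoch; the measured wall-clock time of a full run may differ.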
## Evaluation Process
### Evaluation
@ -242,6 +315,31 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network; its most notable
accuracy:{'acc':0.777}
```
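- Evaluating the ImageNet dataset in a GPU environment
Before running the command below, check the checkpoint path used for evaluation. Set the checkpoint path to an absolute full path, for example "username/hardnet/train_hardnet_390.ckpt".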
```bash
export CUDA_VISIBLE_DEVICES=0
python3 eval.py --device_target 'GPU' --dataset_path /path/dataset --ckpt_path /path/ckpt_path > eval.log 2>&1 &
bash run_eval_gpu.sh /path/dataset 0 /path/ckpt
```
The python command above runs in the background; you can view the results in the eval.log file. The accuracy on the test dataset is as follows:
```bash
# grep "accuracy:" eval.log
accuracy:{'acc':0.775}
```
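For evaluation after distributed training, set checkpoint_path to the last saved checkpoint file, for example "username/hardnet/result/train_hardnet-150-625.ckpt". The accuracy on the test dataset is as follows: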
```bash
# grep "accuracy:" dist.eval.log
accuracy:{'acc':0.777}
```
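The accuracy line can also be read programmatically. The following is a minimal sketch, assuming the `accuracy:{...}` line format shown above, where the part after the prefix is a Python dict literal; `parse_accuracy` is a hypothetical helper, not part of the repository.

```python
import ast


def parse_accuracy(line):
    """Parse a line like "accuracy:{'acc':0.775}" into a dict, or return None."""
    prefix = "accuracy:"
    if not line.startswith(prefix):
        return None
    # The remainder is a dict literal; literal_eval parses it without executing code.
    return ast.literal_eval(line[len(prefix):].strip())


# Sample lines copied from the two log excerpts above.
single = parse_accuracy("accuracy:{'acc':0.775}")
distributed = parse_accuracy("accuracy:{'acc':0.777}")
print(f"single-device acc: {single['acc']:.1%}")
print(f"distributed acc:   {distributed['acc']:.1%}")
```

Using `ast.literal_eval` rather than `eval` keeps the parsing safe if the log contains unexpected content.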
## Inference Process
### Export MindIR
@ -284,38 +382,38 @@ bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]
#### HarDNet on ImageNet
| Parameter | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model version | Inception V1 |
| Resource | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755G |
| Upload date | 2021-3-22 |
| MindSpore version | 1.1.1-aarch64 |
| Dataset | ImageNet2012 |
| Training parameters | epoch=150, steps=625, batch_size = 256, lr=0.1 |
| Optimizer | Momentum |
| Loss function | Softmax cross-entropy |
| Output | Probability |
| Loss | 0.0016 |
| Speed | 1 device: 347 ms/step; 8 devices: 358 ms/step |
| Total time | 1 device: 72 h 50 min; 8 devices: 10 h 14 min |
| Parameters (M) | 13.0 |
| Checkpoint for fine-tuning | 280M (.ckpt file) |
| Script | [hardnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/hardnet) |
| Parameter | Ascend | GPU |
| -------------------------- | ------------------------- | -------------------------- |
| Model version | Inception V1 | Inception V1 |
| Resource | Ascend 910 | Tesla V100 |
| Upload date | 2021-3-22 | 2021-4-21 |
| MindSpore version | 1.1.1-aarch64 | 1.1.1-aarch64 |
| Dataset | ImageNet2012 | ImageNet2012 |
| Training parameters | epoch=150, steps=625, batch_size = 256, lr=0.1 | epoch=150, steps=625, batch_size = 256, lr=0.1 |
| Optimizer | Momentum | Momentum |
| Loss function | Softmax cross-entropy | Softmax cross-entropy |
| Output | Probability | Probability |
| Loss | 0.0016 | 0.0016 |
| Speed | 1 device: 347 ms/step; 8 devices: 358 ms/step | 8 devices: 2806 ms/step |
| Total time | 1 device: 72 h 50 min; 8 devices: 10 h 14 min | 8 devices: 71 h 14 min |
| Parameters (M) | 13.0 | 13.0 |
| Checkpoint for fine-tuning | 280M (.ckpt file) | 281M (.ckpt file) |
| Script | [hardnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/hardnet) | [hardnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/hardnet) |
### Inference Performance
#### HarDNet on ImageNet
| Parameter | Ascend |
| ------------------- | --------------------------- |
| Model version | Inception V1 |
| Resource | Ascend 910 |
| Upload date | 2020-09-20 |
| MindSpore version | 1.1.1-aarch64 |
| Dataset | ImageNet2012 |
| batch_size | 256 |
| Output | Probability |
| Accuracy | 8 devices: 78% |
| Parameter | Ascend | GPU |
| ------------------- | --------------------------- | --------------------------- |
| Model version | Inception V1 | Inception V1 |
| Resource | Ascend 910 | Tesla V100 |
| Upload date | 2021-03-22 | 2021-04-21 |
| MindSpore version | 1.1.1-aarch64 | 1.1.1-aarch64 |
| Dataset | ImageNet2012 | ImageNet2012 |
| batch_size | 256 | 256 |
| Output | Probability | Probability |
| Accuracy | 8 devices: 78% | 8 devices: 77.7% |
## Usage
@ -328,9 +426,9 @@ bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]
```python
# Set the context
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=device_id)
device_target="Ascend",
save_graphs=False,
device_id=device_id)
# Load the unseen dataset for inference
predict_data = create_dataset_ImageNet(dataset_path=args.dataset_path,
@ -358,6 +456,38 @@ bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]
print("==============Acc: {} ==============".format(acc))
```
If you need to use this trained model for inference on GPU, refer to this [link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). An example of the steps follows:
- Running on GPU
```python
# Set the context
context.set_context(mode=context.GRAPH_MODE,
                    device_target="GPU",
                    save_graphs=False)
# Load the unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)
# Define the network
network = HarDNet85(num_classes=config.class_num)
# Load the checkpoint
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)
# Define the loss function
loss = CrossEntropySmooth(smooth_factor=args.label_smooth_factor,
                          num_classes=config.class_num)
# Define the model
model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
# Evaluate on the unseen dataset
acc = model.eval(dataset)
print("==============Acc: {} ==============".format(acc))
```
### Transfer Learning
To be added.

View File

@ -34,13 +34,10 @@ np.random.seed(1)
dataset.config.set_seed(1)
parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--use_hardnet', type=bool, default=True, help='Enable HarnetUnit')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--dataset_path', type=str, default='/data/imagenet_original/val/',
help='Dataset path')
parser.add_argument('--ckpt_path', type=str,
default='/home/hardnet/result/HarDNet-150_625.ckpt',
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
parser.add_argument('--ckpt_path', type=str, default='',
help='Path to the trained checkpoint file; required for evaluation')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='label_smooth_factor')
parser.add_argument('--device_id', type=int, default=0, help='device_id')
@ -52,8 +49,10 @@ def test(ckpt_path):
# init context
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=args.device_id)
save_graphs=False)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
# dataset
predict_data = create_dataset_ImageNet(dataset_path=args.dataset_path,

View File

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_gpu.sh DEVICE_NUM DEVICE_ID(0,1,2,3,4,5,6,7) DATA_PATH PRETRAINED_PATH"
echo "For example: bash run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /path/dataset /path/pretrain_path"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$3
PRETRAINED_PATH=$4
if [ "$1" -lt 1 ] || [ "$1" -gt 8 ]
then
echo "error: DEVICE_NUM=$1 is not in (1-8)"
exit 1
fi
export DEVICE_NUM=$1
export RANK_SIZE=$1
export CUDA_VISIBLE_DEVICES="$2"
cd ../
mpirun -n $1 --allow-run-as-root python3 train.py --device_target 'GPU' --isModelArts False --dataset_path ${DATA_PATH} --pre_ckpt_path ${PRETRAINED_PATH} > train.log 2>&1 &

View File

@ -0,0 +1,27 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval_gpu.sh DATA_PATH DEVICE_ID CKPT_PATH"
echo "For example: bash run_eval_gpu.sh /path/dataset 0 /path/ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
export CUDA_VISIBLE_DEVICES="$2"
cd ../
python3 eval.py --device_target 'GPU' --dataset_path $1 --ckpt_path $3 > eval.log 2>&1 &

View File

@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_single_train_gpu.sh DEVICE_ID DATA_PATH PRETRAINED_PATH"
echo "For example: bash run_single_train_gpu.sh 0 /path/dataset /path/pretrain_path"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$2
PRETRAINED_PATH=$3
export CUDA_VISIBLE_DEVICES=$1
cd ../
python3 train.py --device_target 'GPU' --isModelArts False --distribute False --dataset_path ${DATA_PATH} --pre_ckpt_path ${PRETRAINED_PATH} > train.log 2>&1 &

View File

@ -81,13 +81,13 @@ class _CombConvLayer(nn.Cell):
combconvlayer
"""
def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.9, bias=False):
super(CombConvLayer, self).__init__()
super(_CombConvLayer, self).__init__()
self.CombConvLayer_Conv = _ConvLayer(in_channels, out_channels, kernel=kernel)
self.CombConvLayer_DWConv = _DWConvLayer(out_channels, out_channels, stride=stride)
def construct(self, x):
out = CombConvLayer_Conv(x)
out = CombConvLayer_DWConv(out)
out = self.CombConvLayer_Conv(x)
out = self.CombConvLayer_DWConv(out)
return out
@ -178,11 +178,11 @@ class _CommenHead(nn.Cell):
"""
the transition layer
"""
def __init__(self, num_classes, out_channels, drop_rate):
def __init__(self, num_classes, out_channels, keep_rate):
super(_CommenHead, self).__init__()
self.avgpool = GlobalAvgpooling()
self.flat = nn.Flatten()
self.drop = nn.Dropout(keep_prob=drop_rate)
self.drop = nn.Dropout(keep_prob=keep_rate)
self.dense = nn.Dense(out_channels, num_classes, has_bias=True)
def construct(self, x):
@ -204,7 +204,7 @@ class HarDNet(nn.Cell):
second_kernel = 3
max_pool = True
grmul = 1.7
drop_rate = 0.9
keep_rate = 0.9
# HarDNet68
ch_list = [128, 256, 320, 640, 1024]
@ -219,7 +219,7 @@ class HarDNet(nn.Cell):
gr = [24, 24, 28, 36, 48, 256]
n_layers = [8, 16, 16, 16, 16, 4]
downSamp = [1, 0, 1, 0, 1, 0]
drop_rate = 0.2
keep_rate = 0.8
elif arch == 39:
# HarDNet39
first_ch = [24, 48]
@ -232,7 +232,7 @@ class HarDNet(nn.Cell):
if depth_wise:
second_kernel = 1
max_pool = False
drop_rate = 0.05
keep_rate = 0.95
blks = len(n_layers)
self.layers = nn.CellList()
@ -261,7 +261,7 @@ class HarDNet(nn.Cell):
else:
self.layers.append(_DWConvLayer(ch, ch, stride=2))
self.out_channels = ch_list[blks - 1]
self.droprate = drop_rate
self.keeprate = keep_rate
def construct(self, x):
for layer in self.layers:
@ -272,8 +272,8 @@ class HarDNet(nn.Cell):
def get_out_channels(self):
return self.out_channels
def get_drop_rate(self):
return self.droprate
def get_keep_rate(self):
return self.keeprate
class HarDNet68(nn.Cell):
"""
@ -283,9 +283,9 @@ class HarDNet68(nn.Cell):
super(HarDNet68, self).__init__()
self.net = HarDNet(depth_wise=False, arch=68, pretrained=False)
out_channels = self.net.get_out_channels()
drop_rate = self.net.get_drop_rate()
keep_rate = self.net.get_keep_rate()
self.head = _CommenHead(num_classes, out_channels, drop_rate)
self.head = _CommenHead(num_classes, out_channels, keep_rate)
def construct(self, x):
x = self.net(x)
@ -301,9 +301,9 @@ class HarDNet85(nn.Cell):
super(HarDNet85, self).__init__()
self.net = HarDNet(depth_wise=False, arch=85, pretrained=False)
out_channels = self.net.get_out_channels()
drop_rate = self.net.get_drop_rate()
keep_rate = self.net.get_keep_rate()
self.head = _CommenHead(num_classes, out_channels, drop_rate)
self.head = _CommenHead(num_classes, out_channels, keep_rate)
def construct(self, x):
x = self.net(x)

View File

@ -25,7 +25,7 @@ from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import init, get_group_size, get_rank
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
@ -37,14 +37,12 @@ from src.config import config
parser = argparse.ArgumentParser(description='Image classification with HarDNet on Imagenet')
parser.add_argument('--dataset_path', type=str, default='/home/hardnet/imagenet_original/train/',
help='Dataset path')
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_num', type=int, default=8, help='Device num')
parser.add_argument('--pre_trained', type=str, default=True)
parser.add_argument('--train_url', type=str)
parser.add_argument('--data_url', type=str)
parser.add_argument('--pre_ckpt_path', type=str, default='/home/work/user-job-dir/hardnet/src/HarDNet85.ckpt')
parser.add_argument('--pre_ckpt_path', type=str, default='', help='Pretrain path')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='label_smooth_factor')
parser.add_argument('--isModelArts', type=ast.literal_eval, default=True)
parser.add_argument('--distribute', type=ast.literal_eval, default=True)
@ -57,22 +55,25 @@ if args.isModelArts:
if __name__ == '__main__':
target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=target,
enable_auto_mixed_precision=True, save_graphs=False)
if args.distribute:
# init context
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
if target == "Ascend":
init()
device_id = int(os.getenv('DEVICE_ID'))
context.set_auto_parallel_context(device_id=device_id,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
if target == "GPU":
init()
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
device_id = args.device_id
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=args.device_id)
if target == "Ascend":
device_id = args.device_id
context.set_context(device_id=args.device_id)
if args.isModelArts:
import moxing as mox
@ -146,8 +147,12 @@ if __name__ == '__main__':
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
model = Model(network, loss_fn=loss, optimizer=net_opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O3")
if target == "Ascend":
model = Model(network, loss_fn=loss, optimizer=net_opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O3")
if target == "GPU":
model = Model(network, loss_fn=loss, optimizer=net_opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2")
# define callbacks
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())
@ -156,10 +161,14 @@ if __name__ == '__main__':
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)
if args.isModelArts:
save_checkpoint_path = '/cache/train_output/device_' + os.getenv('DEVICE_ID') + '/'
else:
save_checkpoint_path = config.save_checkpoint_path
if target == "GPU" and args.distribute:
save_checkpoint_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/')
else:
save_checkpoint_path = config.save_checkpoint_path
ckpt_cb = ModelCheckpoint(prefix="HarDNet85",
directory=save_checkpoint_path,
@ -171,7 +180,7 @@ if __name__ == '__main__':
print("Total epoch: {}".format(config.epoch_size))
print("Batch size: {}".format(config.batch_size))
print("Class num: {}".format(config.class_num))
print("======= Multiple Training begin========")
print("=======Training begin========")
model.train(config.epoch_size, train_dataset,
callbacks=cb, dataset_sink_mode=True)
if args.isModelArts: