!18820 HarDNet(GPU) check
Merge pull request !18820 from kqzhang/hardnet
@ -103,6 +103,25 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network, whose outstanding

<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools.>

- Running on GPU

```bash
# Run the training example
export CUDA_VISIBLE_DEVICES=0
python3 train.py --device_target 'GPU' --distribute False --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path > train.log 2>&1 &
# or
bash run_single_train_gpu.sh 0 /path/dataset /path/pretrain_path

# Run the distributed training example
bash run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /path/dataset /path/pretrain_path

# Run the evaluation example
export CUDA_VISIBLE_DEVICES=0
python3 eval.py --device_target 'GPU' --dataset_path /path/dataset --ckpt_path /path/ckpt_path > eval.log 2>&1 &
# or
bash run_eval_gpu.sh /path/dataset 0 /path/ckpt
```

- The ImageNet2012 dataset is used by default. You can also pass `$dataset_type` to the scripts to select a different dataset. See the corresponding scripts for details.

# Script Description

@ -118,6 +137,9 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network, whose outstanding
│   ├──run_single_train.sh          // shell script for single-card training on Ascend
│   ├──run_distribute_train.sh      // shell script for distributed training on Ascend
│   ├──run_eval.sh                  // shell script for evaluation on Ascend
│   ├──run_single_train_gpu.sh      // shell script for single-card training on GPU
│   ├──run_distribute_train_gpu.sh  // shell script for distributed training on GPU
│   ├──run_eval_gpu.sh              // shell script for evaluation on GPU
├── src
│   ├──dataset.py                   // create the dataset
│   ├──hardnet.py                   // HarDNet architecture
@ -191,6 +213,28 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network, whose outstanding

The model checkpoint is saved in the current directory.

- Running on GPU

```bash
export CUDA_VISIBLE_DEVICES=0
python3 train.py --device_target 'GPU' --isModelArts False --distribute False --dataset_path /path/dataset --pre_ckpt_path /path/pretrained_path > train.log 2>&1 &
# or
bash run_single_train_gpu.sh 0 /path/dataset /path/pretrain_path
```

The above python command runs in the background; you can check the results in the train.log file.

After training, you will find the checkpoint files under the default script folder. The loss values are reported as follows:

```bash
# grep "loss is " train.log
epoch:1 step:5000, loss is 3.0897788
epoch:2 step:5000, loss is 2.4842823
...
```

The model checkpoint is saved in the current directory.

### Distributed Training

- Running on Ascend

@ -214,6 +258,35 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network, whose outstanding
...
```

- Running on GPU

```bash
bash run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /path/dataset /path/pretrain_path
```

The above shell script runs distributed training in the background. You can check the results in the train.log file. The loss values are reported as follows:

```bash
# grep "result:" train.log
epoch: 1 step: 625, loss is 2.7857578
epoch: 1 step: 625, loss is 2.7340727
epoch: 1 step: 625, loss is 2.7651663
epoch: 1 step: 625, loss is 2.8074665
epoch: 1 step: 625, loss is 2.8567638
epoch: 1 step: 625, loss is 2.768191
epoch: 1 step: 625, loss is 3.0651402
epoch: 1 step: 625, loss is 3.039652
epoch time: 1753885.943 ms, per step time: 2806.218 ms
epoch time: 1753861.017 ms, per step time: 2806.178 ms
epoch time: 1753959.524 ms, per step time: 2806.335 ms
epoch time: 1753182.479 ms, per step time: 2805.092 ms
epoch time: 1753981.462 ms, per step time: 2806.370 ms
epoch time: 1753181.926 ms, per step time: 2805.091 ms
epoch time: 1753266.931 ms, per step time: 2805.227 ms
epoch time: 1753218.315 ms, per step time: 2805.149 ms
...
```

## Evaluation Process

### Evaluation

@ -242,6 +315,31 @@ HarDNet refers to Harmonic DenseNet: A low memory traffic network, whose outstanding
accuracy:{'acc':0.777}
```

- Evaluating the ImageNet dataset on GPU

Before running the commands below, check the checkpoint path used for evaluation. Set the checkpoint path to an absolute path, e.g. "username/hardnet/train_hardnet_390.ckpt".

```bash
export CUDA_VISIBLE_DEVICES=0
python3 eval.py --device_target 'GPU' --dataset_path /path/dataset --ckpt_path /path/ckpt_path > eval.log 2>&1 &
# or
bash run_eval_gpu.sh /path/dataset 0 /path/ckpt
```

The above python command runs in the background; you can check the results in the eval.log file. The accuracy on the test dataset is as follows:

```bash
# grep "accuracy:" eval.log
accuracy:{'acc':0.775}
```

Note: For evaluation after distributed training, set checkpoint_path to the last saved checkpoint file, e.g. "username/hardnet/result/train_hardnet-150-625.ckpt" (a sketch for locating it follows the log excerpt below). The accuracy on the test dataset is as follows:

```bash
# grep "accuracy:" dist.eval.log
accuracy:{'acc':0.777}
```
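
A minimal sketch for pointing run_eval_gpu.sh at the newest checkpoint written by distributed GPU training is shown below. It assumes the per-rank directories (`ckpt_0/`, `ckpt_1/`, ...) created by train.py sit under the current directory and that the ModelCheckpoint prefix is `HarDNet85`; adjust the paths to your `save_checkpoint_path` setting.

```bash
# Hypothetical helper: evaluate the newest rank-0 checkpoint from distributed GPU training.
# The directory layout used here is an assumption; adapt it to your save_checkpoint_path.
LATEST_CKPT=$(ls -t ./ckpt_0/HarDNet85*.ckpt | head -n 1)
bash run_eval_gpu.sh /path/dataset 0 "${LATEST_CKPT}"
```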

## Inference Process

### Export MindIR

@ -284,38 +382,38 @@ bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]

#### HarDNet on ImageNet

| Parameters            | Ascend                                                       |
| --------------------- | ------------------------------------------------------------ |
| Model version         | Inception V1                                                 |
| Resource              | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G              |
| Upload date           | 2021-3-22                                                    |
| MindSpore version     | 1.1.1-aarch64                                                |
| Dataset               | ImageNet2012                                                 |
| Training parameters   | epoch=150, steps=625, batch_size = 256, lr=0.1               |
| Optimizer             | Momentum                                                     |
| Loss function         | Softmax cross-entropy                                        |
| Output                | Probability                                                  |
| Loss                  | 0.0016                                                       |
| Speed                 | 1pc: 347 ms/step; 8pcs: 358 ms/step                          |
| Total time            | 1pc: 72h50min; 8pcs: 10h14min                                |
| Parameters (M)        | 13.0                                                         |
| Fine-tuned checkpoint | 280M (.ckpt file)                                            |
| Script                | [hardnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/hardnet) |
| Parameters            | Ascend                                       | GPU                                          |
| --------------------- | -------------------------------------------- | -------------------------------------------- |
| Model version         | Inception V1                                 | Inception V1                                 |
| Resource              | Ascend 910                                   | Tesla V100                                   |
| Upload date           | 2021-3-22                                    | 2021-4-21                                    |
| MindSpore version     | 1.1.1-aarch64                                | 1.1.1-aarch64                                |
| Dataset               | ImageNet2012                                 | ImageNet2012                                 |
| Training parameters   | epoch=150, steps=625, batch_size = 256, lr=0.1 | epoch=150, steps=625, batch_size = 256, lr=0.1 |
| Optimizer             | Momentum                                     | Momentum                                     |
| Loss function         | Softmax cross-entropy                        | Softmax cross-entropy                        |
| Output                | Probability                                  | Probability                                  |
| Loss                  | 0.0016                                       | 0.0016                                       |
| Speed                 | 1pc: 347 ms/step; 8pcs: 358 ms/step          | 8pcs: 2806 ms/step                           |
| Total time            | 1pc: 72h50min; 8pcs: 10h14min                | 8pcs: 71h14min                               |
| Parameters (M)        | 13.0                                         | 13.0                                         |
| Fine-tuned checkpoint | 280M (.ckpt file)                            | 281M (.ckpt file)                            |
| Script                | [hardnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/hardnet) | [hardnet script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/hardnet) |

### Inference Performance

#### HarDNet on ImageNet

| Parameters        | Ascend        |
| ----------------- | ------------- |
| Model version     | Inception V1  |
| Resource          | Ascend 910    |
| Upload date       | 2020-09-20    |
| MindSpore version | 1.1.1-aarch64 |
| Dataset           | ImageNet2012  |
| batch_size        | 256           |
| Output            | Probability   |
| Accuracy          | 8pcs: 78%     |
| Parameters        | Ascend        | GPU           |
| ----------------- | ------------- | ------------- |
| Model version     | Inception V1  | Inception V1  |
| Resource          | Ascend 910    | Tesla V100    |
| Upload date       | 2021-03-22    | 2021-04-21    |
| MindSpore version | 1.1.1-aarch64 | 1.1.1-aarch64 |
| Dataset           | ImageNet2012  | ImageNet2012  |
| batch_size        | 256           | 256           |
| Output            | Probability   | Probability   |
| Accuracy          | 8pcs: 78%     | 8pcs: 77.7%   |

## Usage

@ -328,9 +426,9 @@ bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]
```python
# Set context
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=device_id)
device_target="Ascend",
save_graphs=False,
device_id=device_id)

# Load the unseen dataset for inference
predict_data = create_dataset_ImageNet(dataset_path=args.dataset_path,

@ -358,6 +456,38 @@ bash run_infer_310.sh [MINDIR_PATH] [DATASET_PATH] [DEVICE_ID]
print("==============Acc: {} ==============".format(acc))
```

If you want to run inference on GPU with this trained model, refer to this [link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/migrate_3rd_scripts.html). An example of the steps is shown below:

- Running on GPU

```python
# Set context
context.set_context(mode=context.GRAPH_MODE,
device_target="GPU",
save_graphs=False,)

# Load the unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define the network
network = HarDNet85(num_classes=config.class_num)

# Load checkpoint
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)

# Define the loss function
loss = CrossEntropySmooth(smooth_factor=args.label_smooth_factor,
num_classes=config.class_num)

# Define the model
model = Model(network, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})

# Predict on the unseen dataset
acc = model.eval(dataset)
print("==============Acc: {} ==============".format(acc))
```
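
For reference, `model.eval` returns a dictionary keyed by the metric names passed to `Model`; a minimal sketch of reading out individual values (the numbers in the comment are only placeholders):

```python
# acc is a dict keyed by the metric names above,
# e.g. {'top_1_accuracy': 0.777, 'top_5_accuracy': 0.936}  (placeholder values)
print("top-1: {:.4f}, top-5: {:.4f}".format(acc['top_1_accuracy'], acc['top_5_accuracy']))
```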

### Transfer Learning

To be added.

@ -34,13 +34,10 @@ np.random.seed(1)
dataset.config.set_seed(1)

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--use_hardnet', type=bool, default=True, help='Enable HarnetUnit')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')

parser.add_argument('--dataset_path', type=str, default='/data/imagenet_original/val/',
help='Dataset path')
parser.add_argument('--ckpt_path', type=str,
default='/home/hardnet/result/HarDNet-150_625.ckpt',
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
parser.add_argument('--ckpt_path', type=str, default='',
help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='label_smooth_factor')
parser.add_argument('--device_id', type=int, default=0, help='device_id')

@ -52,8 +49,10 @@ def test(ckpt_path):
# init context
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=args.device_id)
save_graphs=False)

if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
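# On GPU the visible device is chosen through CUDA_VISIBLE_DEVICES (see the run_*_gpu.sh scripts),
# so device_id is only set explicitly for Ascend.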

# dataset
predict_data = create_dataset_ImageNet(dataset_path=args.dataset_path,

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distribute_train_gpu.sh DEVICE_NUM DEVICE_ID(0,1,2,3,4,5,6,7) DATA_PATH PRETRAINED_PATH"
echo "For example: bash run_distribute_train_gpu.sh 8 0,1,2,3,4,5,6,7 /path/dataset /path/pretrain_path"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$3
PRETRAINED_PATH=$4

if [ $1 -lt 1 ] || [ $1 -gt 8 ]
then
    echo "error: DEVICE_NUM=$1 is not in (1-8)"
    exit 1
fi

export DEVICE_NUM=$1
export RANK_SIZE=$1
export CUDA_VISIBLE_DEVICES="$2"

cd ../
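# Launch one training process per device: mpirun spawns $1 copies of train.py, and with
# CUDA_VISIBLE_DEVICES set above each rank is expected to use its own GPU from the visible list.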
mpirun -n $1 --allow-run-as-root python3 train.py --device_target 'GPU' --isModelArts False --dataset_path ${DATA_PATH} --pre_ckpt_path ${PRETRAINED_PATH} > train.log 2>&1 &
@ -0,0 +1,27 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_eval_gpu.sh DATA_PATH DEVICE_ID CKPT_PATH"
echo "For example: bash run_eval_gpu.sh /path/dataset 0 /path/ckpt"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
export CUDA_VISIBLE_DEVICES="$2"

cd ../
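# Evaluate on the GPU selected above; eval.py runs one level up and its output is redirected to eval.log there.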
python3 eval.py --device_target 'GPU' --dataset_path $1 --ckpt_path $3 > eval.log 2>&1 &
@ -0,0 +1,30 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_single_train_gpu.sh DEVICE_ID DATA_PATH PRETRAINED_PATH"
echo "For example: bash run_single_train_gpu.sh 0 /path/dataset /path/pretrain_path"
echo "It is better to use the absolute path."
echo "=============================================================================================================="
set -e
DATA_PATH=$2
PRETRAINED_PATH=$3

export CUDA_VISIBLE_DEVICES=$1

cd ../
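# Single-GPU training on the device selected above; train.py runs one level up and its output is redirected to train.log there.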
python3 train.py --device_target 'GPU' --isModelArts False --distribute False --dataset_path ${DATA_PATH} --pre_ckpt_path ${PRETRAINED_PATH} > train.log 2>&1 &
@ -81,13 +81,13 @@ class _CombConvLayer(nn.Cell):
combconvlayer
"""
def __init__(self, in_channels, out_channels, kernel=1, stride=1, dropout=0.9, bias=False):
super(CombConvLayer, self).__init__()
super(_CombConvLayer, self).__init__()
self.CombConvLayer_Conv = _ConvLayer(in_channels, out_channels, kernel=kernel)
self.CombConvLayer_DWConv = _DWConvLayer(out_channels, out_channels, stride=stride)

def construct(self, x):
out = CombConvLayer_Conv(x)
out = CombConvLayer_DWConv(out)
out = self.CombConvLayer_Conv(x)
out = self.CombConvLayer_DWConv(out)

return out

@ -178,11 +178,11 @@ class _CommenHead(nn.Cell):
"""
the transition layer
"""
def __init__(self, num_classes, out_channels, drop_rate):
def __init__(self, num_classes, out_channels, keep_rate):
super(_CommenHead, self).__init__()
self.avgpool = GlobalAvgpooling()
self.flat = nn.Flatten()
self.drop = nn.Dropout(keep_prob=drop_rate)
self.drop = nn.Dropout(keep_prob=keep_rate)
self.dense = nn.Dense(out_channels, num_classes, has_bias=True)

def construct(self, x):

@ -204,7 +204,7 @@ class HarDNet(nn.Cell):
second_kernel = 3
max_pool = True
grmul = 1.7
drop_rate = 0.9
keep_rate = 0.9

# HarDNet68
ch_list = [128, 256, 320, 640, 1024]

@ -219,7 +219,7 @@ class HarDNet(nn.Cell):
gr = [24, 24, 28, 36, 48, 256]
n_layers = [8, 16, 16, 16, 16, 4]
downSamp = [1, 0, 1, 0, 1, 0]
drop_rate = 0.2
keep_rate = 0.8
elif arch == 39:
# HarDNet39
first_ch = [24, 48]

@ -232,7 +232,7 @@ class HarDNet(nn.Cell):
if depth_wise:
second_kernel = 1
max_pool = False
drop_rate = 0.05
keep_rate = 0.95

blks = len(n_layers)
self.layers = nn.CellList()

@ -261,7 +261,7 @@ class HarDNet(nn.Cell):
else:
self.layers.append(_DWConvLayer(ch, ch, stride=2))
self.out_channels = ch_list[blks - 1]
self.droprate = drop_rate
self.keeprate = keep_rate

def construct(self, x):
for layer in self.layers:

@ -272,8 +272,8 @@ class HarDNet(nn.Cell):
def get_out_channels(self):
return self.out_channels

def get_drop_rate(self):
return self.droprate
def get_keep_rate(self):
return self.keeprate

class HarDNet68(nn.Cell):
"""

@ -283,9 +283,9 @@ class HarDNet68(nn.Cell):
super(HarDNet68, self).__init__()
self.net = HarDNet(depth_wise=False, arch=68, pretrained=False)
out_channels = self.net.get_out_channels()
drop_rate = self.net.get_drop_rate()
keep_rate = self.net.get_keep_rate()

self.head = _CommenHead(num_classes, out_channels, drop_rate)
self.head = _CommenHead(num_classes, out_channels, keep_rate)

def construct(self, x):
x = self.net(x)

@ -301,9 +301,9 @@ class HarDNet85(nn.Cell):
super(HarDNet85, self).__init__()
self.net = HarDNet(depth_wise=False, arch=85, pretrained=False)
out_channels = self.net.get_out_channels()
drop_rate = self.net.get_drop_rate()
keep_rate = self.net.get_keep_rate()

self.head = _CommenHead(num_classes, out_channels, drop_rate)
self.head = _CommenHead(num_classes, out_channels, keep_rate)

def construct(self, x):
x = self.net(x)

@ -25,7 +25,7 @@ from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import init, get_group_size, get_rank
import mindspore.nn as nn
import mindspore.common.initializer as weight_init

@ -37,14 +37,12 @@ from src.config import config

parser = argparse.ArgumentParser(description='Image classification with HarDNet on Imagenet')

parser.add_argument('--dataset_path', type=str, default='/home/hardnet/imagenet_original/train/',
help='Dataset path')
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
parser.add_argument('--device_num', type=int, default=8, help='Device num')
parser.add_argument('--pre_trained', type=str, default=True)
parser.add_argument('--train_url', type=str)
parser.add_argument('--data_url', type=str)
parser.add_argument('--pre_ckpt_path', type=str, default='/home/work/user-job-dir/hardnet/src/HarDNet85.ckpt')
parser.add_argument('--pre_ckpt_path', type=str, default='', help='Pretrain path')
parser.add_argument('--label_smooth_factor', type=float, default=0.1, help='label_smooth_factor')
parser.add_argument('--isModelArts', type=ast.literal_eval, default=True)
parser.add_argument('--distribute', type=ast.literal_eval, default=True)

@ -57,22 +55,25 @@ if args.isModelArts:

if __name__ == '__main__':
target = args.device_target
context.set_context(mode=context.GRAPH_MODE, device_target=target,
enable_auto_mixed_precision=True, save_graphs=False)

if args.distribute:
# init context
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()

if target == "Ascend":
init()
device_id = int(os.getenv('DEVICE_ID'))
context.set_auto_parallel_context(device_id=device_id,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
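# For GPU, init() joins the NCCL process group created by the mpirun launch, so the device
# count comes from get_group_size() instead of the DEVICE_ID/RANK_SIZE variables used on Ascend.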
if target == "GPU":
init()
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
device_id = args.device_id
context.set_context(mode=context.GRAPH_MODE,
device_target=target,
save_graphs=False,
device_id=args.device_id)
if target == "Ascend":
device_id = args.device_id
context.set_context(device_id=args.device_id)

if args.isModelArts:
import moxing as mox

@ -146,8 +147,12 @@ if __name__ == '__main__':

loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

model = Model(network, loss_fn=loss, optimizer=net_opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O3")
if target == "Ascend":
model = Model(network, loss_fn=loss, optimizer=net_opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O3")
if target == "GPU":
model = Model(network, loss_fn=loss, optimizer=net_opt,
loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2")
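# amp_level: "O3" casts the whole network to float16 (the setting used on Ascend here),
# while "O2" keeps BatchNorm in float32, the mixed-precision level generally recommended on GPU.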

# define callbacks
time_cb = TimeMonitor(data_size=train_dataset.get_dataset_size())

@ -156,10 +161,14 @@ if __name__ == '__main__':
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
keep_checkpoint_max=config.keep_checkpoint_max)

if args.isModelArts:
save_checkpoint_path = '/cache/train_output/device_' + os.getenv('DEVICE_ID') + '/'
else:
save_checkpoint_path = config.save_checkpoint_path
if target == "GPU" and args.distribute:
save_checkpoint_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(get_rank()) + '/')
else:
save_checkpoint_path = config.save_checkpoint_path

ckpt_cb = ModelCheckpoint(prefix="HarDNet85",
directory=save_checkpoint_path,

@ -171,7 +180,7 @@ if __name__ == '__main__':
print("Total epoch: {}".format(config.epoch_size))
print("Batch size: {}".format(config.batch_size))
print("Class num: {}".format(config.class_num))
print("======= Multiple Training begin========")
print("=======Training begin========")
model.train(config.epoch_size, train_dataset,
callbacks=cb, dataset_sink_mode=True)
if args.isModelArts: